diff --git a/lib/shared/inference/openai_completions.rb b/lib/shared/inference/openai_completions.rb index 17306a3b..c7553c8f 100644 --- a/lib/shared/inference/openai_completions.rb +++ b/lib/shared/inference/openai_completions.rb @@ -3,31 +3,88 @@ module ::DiscourseAi module Inference class OpenAiCompletions + TIMEOUT = 60 + CompletionFailed = Class.new(StandardError) - def self.perform!(messages, model = SiteSetting.ai_helper_model) + def self.perform!( + messages, + model = SiteSetting.ai_helper_model, + temperature: nil, + top_p: nil, + max_tokens: nil, + stream: false, + &blk + ) + raise ArgumentError, "block must be supplied in streaming mode" if stream && !blk + + url = URI("https://api.openai.com/v1/chat/completions") headers = { - "Authorization" => "Bearer #{SiteSetting.ai_openai_api_key}", - "Content-Type" => "application/json", + "Content-Type": "application/json", + Authorization: "Bearer #{SiteSetting.ai_openai_api_key}", } + payload = { model: model, messages: messages } - connection_opts = { request: { write_timeout: 60, read_timeout: 60, open_timeout: 60 } } + payload[:temperature] = temperature if temperature + payload[:top_p] = top_p if top_p + payload[:max_tokens] = max_tokens if max_tokens + payload[:stream] = true if stream - response = - Faraday.new(nil, connection_opts).post( - "https://api.openai.com/v1/chat/completions", - { model: model, messages: messages }.to_json, - headers, - ) + Net::HTTP.start( + url.host, + url.port, + use_ssl: true, + read_timeout: TIMEOUT, + open_timeout: TIMEOUT, + write_timeout: TIMEOUT, + ) do |http| + request = Net::HTTP::Post.new(url, headers) + request.body = payload.to_json - if response.status != 200 - Rails.logger.error( - "OpenAiCompletions: status: #{response.status} - body: #{response.body}", - ) - raise CompletionFailed + response = http.request(request) + + if response.code.to_i != 200 + Rails.logger.error( + "OpenAiCompletions: status: #{response.code.to_i} - body: #{response.body}", + ) + raise CompletionFailed + end + + if stream + stream(http, response, &blk) + else + JSON.parse(response.read_body, symbolize_names: true) + end end + end - JSON.parse(response.body, symbolize_names: true) + def self.stream(http, response) + cancelled = false + cancel = lambda { cancelled = true } + + response.read_body do |chunk| + if cancelled + http.finish + break + end + + chunk + .split("\n") + .each do |line| + data = line.split("data: ", 2)[1] + + next if data == "[DONE]" + + yield JSON.parse(data, symbolize_names: true), cancel if data + + if cancelled + http.finish + break + end + end + end + rescue IOError + raise if !cancelled end end end diff --git a/spec/shared/inference/openai_completions_spec.rb b/spec/shared/inference/openai_completions_spec.rb new file mode 100644 index 00000000..4624b1fd --- /dev/null +++ b/spec/shared/inference/openai_completions_spec.rb @@ -0,0 +1,92 @@ +# frozen_string_literal: true +require "rails_helper" + +describe DiscourseAi::Inference::OpenAiCompletions do + before { SiteSetting.ai_openai_api_key = "abc-123" } + + it "can complete a trivial prompt" do + body = <<~JSON + {"id":"chatcmpl-74OT0yKnvbmTkqyBINbHgAW0fpbxc","object":"chat.completion","created":1681281718,"model":"gpt-3.5-turbo-0301","usage":{"prompt_tokens":12,"completion_tokens":13,"total_tokens":25},"choices":[{"message":{"role":"assistant","content":"1. Serenity\\n2. Laughter\\n3. Adventure"},"finish_reason":"stop","index":0}]} + JSON + + stub_request(:post, "https://api.openai.com/v1/chat/completions").with( + body: + "{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"write 3 words\"}],\"temperature\":0.5,\"top_p\":0.8,\"max_tokens\":700}", + headers: { + "Authorization" => "Bearer #{SiteSetting.ai_openai_api_key}", + "Content-Type" => "application/json", + }, + ).to_return(status: 200, body: body, headers: {}) + + prompt = [role: "user", content: "write 3 words"] + completions = + DiscourseAi::Inference::OpenAiCompletions.perform!( + prompt, + "gpt-3.5-turbo", + temperature: 0.5, + top_p: 0.8, + max_tokens: 700, + ) + expect(completions[:choices][0][:message][:content]).to eq( + "1. Serenity\n2. Laughter\n3. Adventure", + ) + end + + it "raises an error if attempting to stream without a block" do + expect do + DiscourseAi::Inference::OpenAiCompletions.perform!([], "gpt-3.5-turbo", stream: true) + end.to raise_error(ArgumentError) + end + + def stream_line(finish_reason: nil, delta: {}) + +"data: " << { + id: "chatcmpl-#{SecureRandom.hex}", + object: "chat.completion.chunk", + created: 1_681_283_881, + model: "gpt-3.5-turbo-0301", + choices: [{ delta: delta }], + finish_reason: finish_reason, + index: 0, + }.to_json + end + + it "can operate in streaming mode" do + payload = [ + stream_line(delta: { role: "assistant" }), + stream_line(delta: { content: "Mount" }), + stream_line(delta: { content: "ain" }), + stream_line(delta: { content: " " }), + stream_line(delta: { content: "Tree " }), + stream_line(delta: { content: "Frog" }), + stream_line(finish_reason: "stop"), + "[DONE]", + ].join("\n\n") + + stub_request(:post, "https://api.openai.com/v1/chat/completions").with( + body: + "{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"write 3 words\"}],\"stream\":true}", + headers: { + "Accept" => "*/*", + "Authorization" => "Bearer abc-123", + "Content-Type" => "application/json", + "Host" => "api.openai.com", + }, + ).to_return(status: 200, body: payload, headers: {}) + + prompt = [role: "user", content: "write 3 words"] + + content = +"" + + DiscourseAi::Inference::OpenAiCompletions.perform!( + prompt, + "gpt-3.5-turbo", + stream: true, + ) do |partial, cancel| + data = partial[:choices][0].dig(:delta, :content) + content << data if data + cancel.call if content.split(" ").length == 2 + end + + expect(content).to eq("Mountain Tree ") + end +end diff --git a/spec/support/openai_completions_inference_stubs.rb b/spec/support/openai_completions_inference_stubs.rb index 18e9bcda..369983fa 100644 --- a/spec/support/openai_completions_inference_stubs.rb +++ b/spec/support/openai_completions_inference_stubs.rb @@ -19,13 +19,13 @@ class OpenAiCompletionsInferenceStubs def spanish_text <<~STRING - Para que su horror sea perfecto, César, acosado al pie de la estatua por lo impacientes puñales de sus amigos, + Para que su horror sea perfecto, César, acosado al pie de la estatua por lo impacientes puñales de sus amigos, descubre entre las caras y los aceros la de Marco Bruto, su protegido, acaso su hijo, y ya no se defiende y exclama: ¡Tú también, hijo mío! Shakespeare y Quevedo recogen el patético grito. - - Al destino le agradan las repeticiones, las variantes, las simetrías; diecinueve siglos después, - en el sur de la provincia de Buenos Aires, un gaucho es agredido por otros gauchos y, al caer, - reconoce a un ahijado suyo y le dice con mansa reconvención y lenta sorpresa (estas palabras hay que oírlas, no leerlas): + + Al destino le agradan las repeticiones, las variantes, las simetrías; diecinueve siglos después, + en el sur de la provincia de Buenos Aires, un gaucho es agredido por otros gauchos y, al caer, + reconoce a un ahijado suyo y le dice con mansa reconvención y lenta sorpresa (estas palabras hay que oírlas, no leerlas): ¡Pero, che! Lo matan y no sabe que muere para que se repita una escena. STRING end @@ -59,7 +59,7 @@ class OpenAiCompletionsInferenceStubs where someone is betrayed by a close friend or protege, uttering a similar phrase of surprise and disappointment before their untimely death. The first example refers to Julius Caesar, who upon realizing that one of his own friends and proteges, Marcus Brutus, is among his assassins, exclaims \"You too, my son!\" The second example - is of a gaucho in Buenos Aires, who recognizes his godson among his attackers and utters the words of rebuke + is of a gaucho in Buenos Aires, who recognizes his godson among his attackers and utters the words of rebuke and surprise, \"But, my friend!\" before he is killed. The author suggests that these tragedies occur so that a scene may be repeated, emphasizing the cyclical nature of history and the inevitability of certain events." STRING @@ -94,7 +94,6 @@ class OpenAiCompletionsInferenceStubs end def stub_prompt(type) - prompt_builder = DiscourseAi::AiHelper::LlmPrompt.new text = type == TRANSLATE ? spanish_text : translated_response prompt_messages = CompletionPrompt.find_by(name: type).messages_with_user_input(text)