From ef7d4cc5090e54491ab00d5bdd0ef3ad85c499de Mon Sep 17 00:00:00 2001 From: Roman Rizzi Date: Tue, 12 Dec 2023 17:06:53 -0300 Subject: [PATCH] FIX: Recover from Bedrock returning invalid base64 payloads during streaming (#352) --- lib/completions/endpoints/aws_bedrock.rb | 2 +- .../completions/endpoints/aws_bedrock_spec.rb | 51 +++++++++++-------- 2 files changed, 32 insertions(+), 21 deletions(-) diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb index 84fe4f49..7624b7c8 100644 --- a/lib/completions/endpoints/aws_bedrock.rb +++ b/lib/completions/endpoints/aws_bedrock.rb @@ -71,7 +71,7 @@ module DiscourseAi .string .then { JSON.parse(_1) } .dig("bytes") - .then { Base64.decode64(_1) } + .then { Base64.decode64(_1.to_s) } rescue JSON::ParserError, Aws::EventStream::Errors::MessageChecksumError, Aws::EventStream::Errors::PreludeChecksumError => e diff --git a/spec/lib/completions/endpoints/aws_bedrock_spec.rb b/spec/lib/completions/endpoints/aws_bedrock_spec.rb index 5c0cb8cc..c9f9f037 100644 --- a/spec/lib/completions/endpoints/aws_bedrock_spec.rb +++ b/spec/lib/completions/endpoints/aws_bedrock_spec.rb @@ -78,26 +78,24 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do def stream_line(delta, finish_reason: nil) encoder = Aws::EventStream::Encoder.new - message = - Aws::EventStream::Message.new( - payload: - StringIO.new( - { - bytes: - Base64.encode64( - { - completion: delta, - stop: finish_reason ? "\n\nHuman:" : nil, - stop_reason: finish_reason, - truncated: false, - log_id: "12b029451c6d18094d868bc04ce83f63", - model: "claude-2", - exception: nil, - }.to_json, - ), - }.to_json, - ), - ) + bytes = + if delta.nil? + nil + else + Base64.encode64( + { + completion: delta, + stop: finish_reason ? "\n\nHuman:" : nil, + stop_reason: finish_reason, + truncated: false, + log_id: "12b029451c6d18094d868bc04ce83f63", + model: "claude-2", + exception: nil, + }.to_json, + ) + end + + message = Aws::EventStream::Message.new(payload: StringIO.new({ bytes: bytes }.to_json)) encoder.encode(message) end @@ -122,4 +120,17 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do end it_behaves_like "an endpoint that can communicate with a completion service" + + it "skips empty deltas" do + completion_deltas = [nil, "Mount", "ain", " ", "Tree ", "Frog"] + stub_streamed_response(prompt, completion_deltas) + completion_response = +"" + + model.perform_completion!(prompt, Fabricate(:user)) do |partial, cancel| + completion_response << partial + cancel.call if completion_response.split(" ").length == 2 + end + + expect(completion_response).to eq(completion_deltas[0...-1].join) + end end