Revert "FIX: Recover from Bedrock returning invalid base64 payloads during streaming (#352)" (#353)

This reverts commit ef7d4cc509.
This commit is contained in:
Roman Rizzi 2023-12-12 17:22:44 -03:00 committed by GitHub
parent ef7d4cc509
commit 031c2a6b46
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 32 deletions

View File

@ -71,7 +71,7 @@ module DiscourseAi
.string
.then { JSON.parse(_1) }
.dig("bytes")
.then { Base64.decode64(_1.to_s) }
.then { Base64.decode64(_1) }
rescue JSON::ParserError,
Aws::EventStream::Errors::MessageChecksumError,
Aws::EventStream::Errors::PreludeChecksumError => e

View File

@ -78,24 +78,26 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
def stream_line(delta, finish_reason: nil)
encoder = Aws::EventStream::Encoder.new
bytes =
if delta.nil?
nil
else
Base64.encode64(
{
completion: delta,
stop: finish_reason ? "\n\nHuman:" : nil,
stop_reason: finish_reason,
truncated: false,
log_id: "12b029451c6d18094d868bc04ce83f63",
model: "claude-2",
exception: nil,
}.to_json,
)
end
message = Aws::EventStream::Message.new(payload: StringIO.new({ bytes: bytes }.to_json))
message =
Aws::EventStream::Message.new(
payload:
StringIO.new(
{
bytes:
Base64.encode64(
{
completion: delta,
stop: finish_reason ? "\n\nHuman:" : nil,
stop_reason: finish_reason,
truncated: false,
log_id: "12b029451c6d18094d868bc04ce83f63",
model: "claude-2",
exception: nil,
}.to_json,
),
}.to_json,
),
)
encoder.encode(message)
end
@ -120,17 +122,4 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
end
it_behaves_like "an endpoint that can communicate with a completion service"
it "skips empty deltas" do
completion_deltas = [nil, "Mount", "ain", " ", "Tree ", "Frog"]
stub_streamed_response(prompt, completion_deltas)
completion_response = +""
model.perform_completion!(prompt, Fabricate(:user)) do |partial, cancel|
completion_response << partial
cancel.call if completion_response.split(" ").length == 2
end
expect(completion_response).to eq(completion_deltas[0...-1].join)
end
end