FIX: streaming broken in Bedrock when chunks are not aligned (#609)

Also:

- Stop caching the LLM list - caching it caused the LLM list in personas to be incorrect
- Add more UI to the debug screen so you can properly see the raw response
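
For context, Bedrock streams completions as AWS event-stream frames, and a single frame can be split across (or span) several HTTP chunks, so the decoder has to keep state between chunks instead of being rebuilt for each one. Below is a minimal sketch of that failure mode using the aws-eventstream gem directly; it is illustrative only and not code from this commit.

```ruby
# Sketch: one event-stream frame arriving split across two network chunks.
# Assumes the aws-eventstream gem; names here are illustrative only.
require "aws-eventstream"
require "stringio"
require "json"
require "base64"

payload = { "bytes" => Base64.strict_encode64("hello") }.to_json
message = Aws::EventStream::Message.new(payload: StringIO.new(payload))
frame = Aws::EventStream::Encoder.new.encode(message)

# Split the encoded frame at an arbitrary, non-aligned byte boundary.
half = frame.bytesize / 2
first_chunk, second_chunk = frame[0...half], frame[half..]

decoder = Aws::EventStream::Decoder.new # reused across chunks, like @decoder in the fix

decoded, _eof = decoder.decode_chunk(first_chunk)
puts decoded.nil? # => true, the frame is incomplete so far; bytes stay buffered

decoded, _eof = decoder.decode_chunk(second_chunk)
parsed = JSON.parse(decoded.payload.string)
puts Base64.decode64(parsed["bytes"]) # => "hello"
```

The aws_bedrock.rb change below applies the same idea: keep a single decoder per streamed response (`@decoder ||= ...`) and drain every complete message out of it, rather than building a fresh decoder per chunk and assuming each chunk holds exactly one complete message.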
Sam 2024-05-09 12:11:50 +10:00 committed by GitHub
parent cf34838a09
commit 514823daca
8 changed files with 111 additions and 28 deletions

@@ -1,5 +1,6 @@
 import Component from "@glimmer/component";
 import { tracked } from "@glimmer/tracking";
+import { on } from "@ember/modifier";
 import { action } from "@ember/object";
 import { next } from "@ember/runloop";
 import { htmlSafe } from "@ember/template";
@@ -14,6 +15,7 @@ import I18n from "discourse-i18n";
 export default class DebugAiModal extends Component {
   @tracked info = null;
   @tracked justCopiedText = "";
+  @tracked activeTab = "request";

   constructor() {
     super(...arguments);
@@ -30,7 +32,11 @@
     let parsed;

     try {
-      parsed = JSON.parse(this.info.raw_request_payload);
+      if (this.activeTab === "request") {
+        parsed = JSON.parse(this.info.raw_request_payload);
+      } else {
+        return this.formattedResponse(this.info.raw_response_payload);
+      }
     } catch (e) {
       return this.info.raw_request_payload;
     }
@@ -38,6 +44,14 @@ export default class DebugAiModal extends Component {
     return htmlSafe(this.jsonToHtml(parsed));
   }

+  formattedResponse(response) {
+    // we need to replace the new lines with <br> to make it look good
+    const split = response.split("\n");
+    const safe = split.map((line) => escapeExpression(line)).join("<br>");
+    return htmlSafe(safe);
+  }
+
   jsonToHtml(json) {
     let html = "<ul>";
     for (let key in json) {
@@ -94,6 +108,26 @@ export default class DebugAiModal extends Component {
     });
   }

+  get requestActive() {
+    return this.activeTab === "request" ? "active" : "";
+  }
+
+  get responseActive() {
+    return this.activeTab === "response" ? "active" : "";
+  }
+
+  @action
+  requestClicked(e) {
+    this.activeTab = "request";
+    e.preventDefault();
+  }
+
+  @action
+  responseClicked(e) {
+    this.activeTab = "response";
+    e.preventDefault();
+  }
+
   <template>
     <DModal
       class="ai-debug-modal"
@@ -101,6 +135,18 @@ export default class DebugAiModal extends Component {
       @closeModal={{@closeModal}}
     >
       <:body>
+        <ul class="nav nav-pills ai-debug-modal__nav">
+          <li><a
+              href=""
+              class={{this.requestActive}}
+              {{on "click" this.requestClicked}}
+            >{{i18n "discourse_ai.ai_bot.debug_ai_modal.request"}}</a></li>
+          <li><a
+              href=""
+              class={{this.responseActive}}
+              {{on "click" this.responseClicked}}
+            >{{i18n "discourse_ai.ai_bot.debug_ai_modal.response"}}</a></li>
+        </ul>
         <div class="ai-debug-modal__tokens">
           <span>
             {{i18n "discourse_ai.ai_bot.debug_ai_modal.request_tokens"}}

@@ -161,3 +161,9 @@ span.onebox-ai-llm-title {
 .ai-debug-modal__tokens span {
   display: block;
 }
+
+.d-modal ul.ai-debug-modal__nav {
+  margin: 0 0 1em;
+  padding: 0;
+  border-bottom: none;
+}

@@ -271,6 +271,8 @@ en:
           copy_response: "Copy response"
           request_tokens: "Request tokens:"
           response_tokens: "Response tokens:"
+          request: "Request"
+          response: "Response"
         share_full_topic_modal:
           title: "Share Conversation Publicly"

@@ -111,23 +111,30 @@ module DiscourseAi
        end

        def decode(chunk)
-          parsed =
-            Aws::EventStream::Decoder
-              .new
-              .decode_chunk(chunk)
-              .first
-              .payload
-              .string
-              .then { JSON.parse(_1) }
-
-          bytes = parsed.dig("bytes")
-          if !bytes
-            Rails.logger.error("#{self.class.name}: #{parsed.to_s[0..500]}")
-            nil
-          else
-            Base64.decode64(parsed.dig("bytes"))
+          @decoder ||= Aws::EventStream::Decoder.new
+
+          decoded, _done = @decoder.decode_chunk(chunk)
+
+          messages = []
+          return messages if !decoded
+
+          i = 0
+          while decoded
+            parsed = JSON.parse(decoded.payload.string)
+            messages << Base64.decode64(parsed["bytes"])
+
+            decoded, _done = @decoder.decode_chunk
+            i += 1
+            if i > 10_000
+              Rails.logger.error(
+                "DiscourseAI: Stream decoder looped too many times, logic error needs fixing",
+              )
+              break
+            end
          end
+
+          messages
        rescue JSON::ParserError,
               Aws::EventStream::Errors::MessageChecksumError,
               Aws::EventStream::Errors::PreludeChecksumError => e
@@ -161,8 +168,14 @@ module DiscourseAi
          result
        end

-        def partials_from(decoded_chunk)
-          [decoded_chunk]
+        def partials_from(decoded_chunks)
+          decoded_chunks
+        end
+
+        def chunk_to_string(chunk)
+          joined = +chunk.join("\n")
+          joined << "\n" if joined.length > 0
+          joined
        end
      end
    end

@@ -168,9 +168,15 @@ module DiscourseAi
            if decoded_chunk.nil?
              raise CompletionFailed, "#{self.class.name}: Failed to decode LLM completion"
            end

-            response_raw << decoded_chunk
+            response_raw << chunk_to_string(decoded_chunk)

-            redo_chunk = leftover + decoded_chunk
+            if decoded_chunk.is_a?(String)
+              redo_chunk = leftover + decoded_chunk
+            else
+              # custom implementation for endpoint
+              # no implicit leftover support
+              redo_chunk = decoded_chunk
+            end

            raw_partials = partials_from(redo_chunk)
@@ -347,6 +353,14 @@
          response.include?("<function_calls>")
        end

+        def chunk_to_string(chunk)
+          if chunk.is_a?(String)
+            chunk
+          else
+            chunk.to_s
+          end
+        end
+
        def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
          if payload&.include?("</invoke>")
            matches = payload.match(%r{<function_calls>.*</invoke>}m)

@@ -26,9 +26,7 @@ module DiscourseAi
        def normalize_model_params(model_params)
          model_params = model_params.dup

          model_params[:p] = model_params.delete(:top_p) if model_params[:top_p]

          model_params
        end

@@ -10,15 +10,15 @@ module DiscourseAi
    end

    def self.values
-      @values ||=
-        DiscourseAi::Completions::Llm.models_by_provider.flat_map do |provider, models|
+      # do not cache cause settings can change this
+      DiscourseAi::Completions::Llm.models_by_provider.flat_map do |provider, models|
        endpoint =
          DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider.to_s, models.first)

        models.map do |model_name|
          { name: endpoint.display_name(model_name), value: "#{provider}:#{model_name}" }
        end
      end
    end
  end
end

@@ -86,6 +86,10 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
          Aws::EventStream::Encoder.new.encode(aws_message)
        end

+      # stream 1 letter at a time
+      # cause we need to handle this case
+      messages = messages.join("").split("")
+
      bedrock_mock.with_chunk_array_support do
        stub_request(
          :post,