FIX: cancel functionality regressed (#938)

The cancel messaging was not floating correctly to the HTTP call leading to impossible to cancel completions 

This is now fully tested as well.
This commit is contained in:
Sam 2024-11-21 17:51:45 +11:00 committed by GitHub
parent d83248cf68
commit d56ed53eb1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 97 additions and 4 deletions

View File

@ -461,6 +461,10 @@ module DiscourseAi
if stream_reply && !Discourse.redis.get(redis_stream_key) if stream_reply && !Discourse.redis.get(redis_stream_key)
cancel&.call cancel&.call
reply_post.update!(raw: reply, cooked: PrettyText.cook(reply)) reply_post.update!(raw: reply, cooked: PrettyText.cook(reply))
# we do not break out, cause if we do
# we will not get results from bot
# leading to broken context
# we need to trust it to cancel at the endpoint
end end
if post_streamer if post_streamer

View File

@ -3,6 +3,15 @@
module DiscourseAi module DiscourseAi
module AiBot module AiBot
class PostStreamer class PostStreamer
# test only
def self.on_callback=(on_callback)
@on_callback = on_callback
end
def self.on_callback
@on_callback
end
def initialize(delay: 0.5) def initialize(delay: 0.5)
@mutex = Mutex.new @mutex = Mutex.new
@callback = nil @callback = nil
@ -11,6 +20,7 @@ module DiscourseAi
end end
def run_later(&callback) def run_later(&callback)
self.class.on_callback.call(callback) if self.class.on_callback
@mutex.synchronize { @callback = callback } @mutex.synchronize { @callback = callback }
ensure_worker! ensure_worker!
end end

View File

@ -139,15 +139,20 @@ module DiscourseAi
begin begin
cancelled = false cancelled = false
cancel = -> { cancelled = true } cancel = -> do
if cancelled cancelled = true
http.finish http.finish
break
end end
break if cancelled
response.read_body do |chunk| response.read_body do |chunk|
break if cancelled
response_raw << chunk response_raw << chunk
decode_chunk(chunk).each do |partial| decode_chunk(chunk).each do |partial|
break if cancelled
partials_raw << partial.to_s partials_raw << partial.to_s
response_data << partial if partial.is_a?(String) response_data << partial if partial.is_a?(String)
partials = [partial] partials = [partial]

View File

@ -3,7 +3,14 @@
RSpec.describe DiscourseAi::AiBot::Playground do RSpec.describe DiscourseAi::AiBot::Playground do
subject(:playground) { described_class.new(bot) } subject(:playground) { described_class.new(bot) }
fab!(:claude_2) { Fabricate(:llm_model, name: "claude-2") } fab!(:claude_2) do
Fabricate(
:llm_model,
provider: "anthropic",
url: "https://api.anthropic.com/v1/messages",
name: "claude-2",
)
end
fab!(:opus_model) { Fabricate(:anthropic_model) } fab!(:opus_model) { Fabricate(:anthropic_model) }
fab!(:bot_user) do fab!(:bot_user) do
@ -948,6 +955,73 @@ RSpec.describe DiscourseAi::AiBot::Playground do
end end
end end
describe "#canceling a completions" do
after { DiscourseAi::AiBot::PostStreamer.on_callback = nil }
it "should be able to cancel a completion halfway through" do
body = (<<~STRING).strip
event: message_start
data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-opus-20240229", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}}
event: content_block_start
data: {"type": "content_block_start", "index":0, "content_block": {"type": "text", "text": ""}}
event: ping
data: {"type": "ping"}
|event: content_block_delta
data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}}
|event: content_block_delta
data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "1"}}
|event: content_block_delta
data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "2"}}
|event: content_block_delta
data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "3"}}
event: content_block_stop
data: {"type": "content_block_stop", "index": 0}
event: message_delta
data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null, "usage":{"output_tokens": 15}}}
event: message_stop
data: {"type": "message_stop"}
STRING
split = body.split("|")
count = 0
DiscourseAi::AiBot::PostStreamer.on_callback =
proc do |callback|
count += 1
if count == 2
last_post = third_post.topic.posts.order(:id).last
Discourse.redis.del("gpt_cancel:#{last_post.id}")
end
raise "this should not happen" if count > 2
end
require_relative("../../completions/endpoints/endpoint_compliance")
EndpointMock.with_chunk_array_support do
stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(
status: 200,
body: split,
)
# we are going to need to use real data here cause we want to trigger the
# base endpoint to cancel part way through
playground.reply_to(third_post)
end
last_post = third_post.topic.posts.order(:id).last
# not Hello123, we cancelled at 1 which means we may get 2 and then be done
expect(last_post.raw).to eq("Hello12")
end
end
describe "#available_bot_usernames" do describe "#available_bot_usernames" do
it "includes persona users" do it "includes persona users" do
persona = Fabricate(:ai_persona) persona = Fabricate(:ai_persona)