From d56ed53eb1af85f563f5af9d73e9001003cedf83 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 21 Nov 2024 17:51:45 +1100 Subject: [PATCH] FIX: cancel functionality regressed (#938) The cancel messaging was not floating correctly to the HTTP call leading to impossible to cancel completions This is now fully tested as well. --- lib/ai_bot/playground.rb | 4 ++ lib/ai_bot/post_streamer.rb | 10 +++ lib/completions/endpoints/base.rb | 11 +++- spec/lib/modules/ai_bot/playground_spec.rb | 76 +++++++++++++++++++++- 4 files changed, 97 insertions(+), 4 deletions(-) diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb index 222d79c8..33327306 100644 --- a/lib/ai_bot/playground.rb +++ b/lib/ai_bot/playground.rb @@ -461,6 +461,10 @@ module DiscourseAi if stream_reply && !Discourse.redis.get(redis_stream_key) cancel&.call reply_post.update!(raw: reply, cooked: PrettyText.cook(reply)) + # we do not break out, cause if we do + # we will not get results from bot + # leading to broken context + # we need to trust it to cancel at the endpoint end if post_streamer diff --git a/lib/ai_bot/post_streamer.rb b/lib/ai_bot/post_streamer.rb index 57ba3c40..73621a2f 100644 --- a/lib/ai_bot/post_streamer.rb +++ b/lib/ai_bot/post_streamer.rb @@ -3,6 +3,15 @@ module DiscourseAi module AiBot class PostStreamer + # test only + def self.on_callback=(on_callback) + @on_callback = on_callback + end + + def self.on_callback + @on_callback + end + def initialize(delay: 0.5) @mutex = Mutex.new @callback = nil @@ -11,6 +20,7 @@ module DiscourseAi end def run_later(&callback) + self.class.on_callback.call(callback) if self.class.on_callback @mutex.synchronize { @callback = callback } ensure_worker! end diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb index 8c9711e6..7bc8c076 100644 --- a/lib/completions/endpoints/base.rb +++ b/lib/completions/endpoints/base.rb @@ -139,15 +139,20 @@ module DiscourseAi begin cancelled = false - cancel = -> { cancelled = true } - if cancelled + cancel = -> do + cancelled = true http.finish - break end + break if cancelled + response.read_body do |chunk| + break if cancelled + response_raw << chunk + decode_chunk(chunk).each do |partial| + break if cancelled partials_raw << partial.to_s response_data << partial if partial.is_a?(String) partials = [partial] diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb index 07485d1f..052b6800 100644 --- a/spec/lib/modules/ai_bot/playground_spec.rb +++ b/spec/lib/modules/ai_bot/playground_spec.rb @@ -3,7 +3,14 @@ RSpec.describe DiscourseAi::AiBot::Playground do subject(:playground) { described_class.new(bot) } - fab!(:claude_2) { Fabricate(:llm_model, name: "claude-2") } + fab!(:claude_2) do + Fabricate( + :llm_model, + provider: "anthropic", + url: "https://api.anthropic.com/v1/messages", + name: "claude-2", + ) + end fab!(:opus_model) { Fabricate(:anthropic_model) } fab!(:bot_user) do @@ -948,6 +955,73 @@ RSpec.describe DiscourseAi::AiBot::Playground do end end + describe "#canceling a completions" do + after { DiscourseAi::AiBot::PostStreamer.on_callback = nil } + + it "should be able to cancel a completion halfway through" do + body = (<<~STRING).strip + event: message_start + data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-opus-20240229", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}} + + event: content_block_start + data: {"type": "content_block_start", "index":0, "content_block": {"type": "text", "text": ""}} + + event: ping + data: {"type": "ping"} + + |event: content_block_delta + data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}} + + |event: content_block_delta + data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "1"}} + + |event: content_block_delta + data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "2"}} + + |event: content_block_delta + data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "3"}} + + event: content_block_stop + data: {"type": "content_block_stop", "index": 0} + + event: message_delta + data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null, "usage":{"output_tokens": 15}}} + + event: message_stop + data: {"type": "message_stop"} + STRING + + split = body.split("|") + + count = 0 + DiscourseAi::AiBot::PostStreamer.on_callback = + proc do |callback| + count += 1 + if count == 2 + last_post = third_post.topic.posts.order(:id).last + Discourse.redis.del("gpt_cancel:#{last_post.id}") + end + raise "this should not happen" if count > 2 + end + + require_relative("../../completions/endpoints/endpoint_compliance") + EndpointMock.with_chunk_array_support do + stub_request(:post, "https://api.anthropic.com/v1/messages").to_return( + status: 200, + body: split, + ) + # we are going to need to use real data here cause we want to trigger the + # base endpoint to cancel part way through + playground.reply_to(third_post) + end + + last_post = third_post.topic.posts.order(:id).last + + # not Hello123, we cancelled at 1 which means we may get 2 and then be done + expect(last_post.raw).to eq("Hello12") + end + end + describe "#available_bot_usernames" do it "includes persona users" do persona = Fabricate(:ai_persona)