FIX: support spaces within arguments for OpenAI (#499)

Prior to this fix, if a tool call ever streamed a lone SPACE as an
arguments delta, we would swallow it and ignore it, breaking the params.

Also fixes some tests to ensure they are actually called :)
This commit is contained in:
Sam 2024-02-29 12:47:34 +11:00 committed by GitHub
parent 1b72a00d2c
commit 9fb1430e40
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 84 additions and 13 deletions

View File

@ -136,14 +136,12 @@ module DiscourseAi
# Pulls the useful part out of one raw JSON response body.
#
# Parses +response_raw+ with symbolized keys and looks at the first entry of
# :choices. In streaming mode the payload is read from :delta, otherwise
# from :message.
#
# Returns nil when there is no first choice, the first tool call hash once a
# tool call has been seen (the flag is remembered in @has_function_call),
# and the :content value otherwise.
def extract_completion_from(response_raw)
  choice = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)

  # only part of a line may have been sent so far, nothing to extract yet
  return unless choice

  payload_key = @streaming_mode ? :delta : :message
  payload = choice.dig(payload_key)

  # once a tool call has shown up, keep treating the response as a tool call
  @has_function_call ||= payload.dig(:tool_calls).present?
  return payload.dig(:tool_calls, 0) if @has_function_call

  payload.dig(:content)
end
@ -172,8 +170,11 @@ module DiscourseAi
function_buffer.at("tool_name").content = f_name if f_name
function_buffer.at("tool_id").content = partial[:id] if partial[:id]
if partial.dig(:function, :arguments).present?
@args_buffer << partial.dig(:function, :arguments)
args = partial.dig(:function, :arguments)
# allow for SPACE within arguments
if args && args != ""
@args_buffer << args
begin
json_args = JSON.parse(@args_buffer, symbolize_names: true)

View File

@ -53,6 +53,13 @@ class OpenAiMock < EndpointMock
}.to_json
end
# Stubs every POST to the OpenAI chat completions URL with a 200 response
# whose body is +chunks+, without matching on the request body.
def stub_raw(chunks)
  WebMock
    .stub_request(:post, "https://api.openai.com/v1/chat/completions")
    .to_return(status: 200, body: chunks)
end
def stub_streamed_response(prompt, deltas, tool_call: false)
chunks =
deltas.each_with_index.map do |_, index|
@ -69,6 +76,8 @@ class OpenAiMock < EndpointMock
.stub_request(:post, "https://api.openai.com/v1/chat/completions")
.with(body: request_body(prompt, stream: true, tool_call: tool_call))
.to_return(status: 200, body: chunks)
yield if block_given?
end
def tool_deltas
@ -168,14 +177,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
end
it "will automatically recover from a bad payload" do
called = false
# this should not happen, but let's ensure nothing bad happens
# the row with test1 is invalid json
raw_data = <<~TEXT.strip
d|a|t|a|:| |{|"choices":[{"delta":{"content":"test,"}}]}
data: {"choices":[{"delta":{"content":"test1,"}}]
data: {"choices":[{"delta":{"content":"test|1| |,"}}]
data: {"choices":[{"delta":|{"content":"test2,"}}]}
data: {"choices":[{"delta":|{"content":"test2 ,"}}]}
data: {"choices":[{"delta":{"content":"test3,"}}]|}
@ -187,16 +198,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
chunks = raw_data.split("|")
open_ai_mock.with_chunk_array_support do
open_ai_mock.stub_streamed_response(compliance.dialect.translate, chunks) do
partials = []
open_ai_mock.stub_raw(chunks)
endpoint.perform_completion!(compliance.dialect, user) do |partial|
partials << partial
end
partials = []
expect(partials.join).to eq("test,test1,test2,test3,test4")
end
endpoint.perform_completion!(compliance.dialect, user) { |partial| partials << partial }
called = true
expect(partials.join).to eq("test,test2 ,test3,test4")
end
expect(called).to be(true)
end
end
@ -204,6 +215,65 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
# Calls compliance.streaming_mode_tools with the OpenAI mock; whatever is
# asserted lives in that helper, which is not visible from here.
it "returns a function invocation" do
compliance.streaming_mode_tools(open_ai_mock)
end
# Streams a tool call whose `arguments` deltas include one that is only a
# single space, then checks that the space is kept in the assembled <query>
# value ("Adabas 9.1").
it "properly handles spaces in tools payload" do
# Every `|` in the payload marks a cut point: the text is split on `|` below,
# so several of the JSON lines are delivered in separate pieces.
raw_data = <<~TEXT.strip
data: {"choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"func_id","type":"function","function":{"name":"google","arguments":""}}]}}]}
data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": "{\\""}}]}}]}
data: {"ch|oices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": "query"}}]}}]}
data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": "\\":\\""}}]}}]}
data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": "Ad"}}]}}]}
data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": "a|b"}}]}}]}
data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": "as"}}]}}]}
data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": |"| "}}]}}]}
data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": "9"}}]}}]}
data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": "."}}]}}]}
data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"argume|nts": "1"}}]}}]}
data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": "\\"}"}}]}}]}
data: {"choices": [{"index": 0, "delta": {"tool_calls": []}}]}
data: [D|ONE]
TEXT
# Cut the payload at the `|` markers to get the list of chunks.
chunks = raw_data.split("|")
open_ai_mock.with_chunk_array_support do
# Serve the chunk array as the stubbed response body.
open_ai_mock.stub_raw(chunks)
partials = []
# NOTE(review): block params x and y are not used in this block.
endpoint.perform_completion!(compliance.dialect, user) do |partial, x, y|
partials << partial
end
# The whole tool call is expected to be yielded as exactly one partial.
expect(partials.length).to eq(1)
# Expected invocation markup; the query value must contain the space.
function_call = (<<~TXT).strip
<function_calls>
<invoke>
<tool_name>google</tool_name>
<tool_id>func_id</tool_id>
<parameters>
<query>Adabas 9.1</query>
</parameters>
</invoke>
</function_calls>
TXT
expect(partials[0].strip).to eq(function_call)
end
end
end
end
end