DEV: Stop using shared_examples for endpoint specs (#430)
This commit is contained in:
parent
8eb1e851fc
commit
5bdf3dc1f4
|
@ -205,12 +205,10 @@ module DiscourseAi
|
||||||
tokenizer.size(extract_prompt_for_tokenizer(prompt))
|
tokenizer.size(extract_prompt_for_tokenizer(prompt))
|
||||||
end
|
end
|
||||||
|
|
||||||
attr_reader :tokenizer
|
attr_reader :tokenizer, :model
|
||||||
|
|
||||||
protected
|
protected
|
||||||
|
|
||||||
attr_reader :model
|
|
||||||
|
|
||||||
# should normalize temperature, max_tokens, stop_words to endpoint specific values
|
# should normalize temperature, max_tokens, stop_words to endpoint specific values
|
||||||
def normalize_model_params(model_params)
|
def normalize_model_params(model_params)
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
|
@ -1,19 +1,8 @@
|
||||||
# frozen_String_literal: true
|
# frozen_String_literal: true
|
||||||
|
|
||||||
require_relative "endpoint_examples"
|
require_relative "endpoint_compliance"
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
|
||||||
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::AnthropicTokenizer) }
|
|
||||||
|
|
||||||
let(:model_name) { "claude-2" }
|
|
||||||
let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) }
|
|
||||||
let(:prompt) { dialect.translate }
|
|
||||||
|
|
||||||
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
|
|
||||||
let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
|
|
||||||
|
|
||||||
let(:tool_id) { "get_weather" }
|
|
||||||
|
|
||||||
|
class AnthropicMock < EndpointMock
|
||||||
def response(content)
|
def response(content)
|
||||||
{
|
{
|
||||||
completion: content,
|
completion: content,
|
||||||
|
@ -21,7 +10,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||||
stop_reason: "stop_sequence",
|
stop_reason: "stop_sequence",
|
||||||
truncated: false,
|
truncated: false,
|
||||||
log_id: "12dcc7feafbee4a394e0de9dffde3ac5",
|
log_id: "12dcc7feafbee4a394e0de9dffde3ac5",
|
||||||
model: model_name,
|
model: "claude-2",
|
||||||
exception: nil,
|
exception: nil,
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
@ -29,7 +18,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||||
def stub_response(prompt, response_text, tool_call: false)
|
def stub_response(prompt, response_text, tool_call: false)
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, "https://api.anthropic.com/v1/complete")
|
.stub_request(:post, "https://api.anthropic.com/v1/complete")
|
||||||
.with(body: request_body)
|
.with(body: model.default_options.merge(prompt: prompt).to_json)
|
||||||
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -59,23 +48,49 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||||
|
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, "https://api.anthropic.com/v1/complete")
|
.stub_request(:post, "https://api.anthropic.com/v1/complete")
|
||||||
.with(body: stream_request_body)
|
.with(body: model.default_options.merge(prompt: prompt, stream: true).to_json)
|
||||||
.to_return(status: 200, body: chunks)
|
.to_return(status: 200, body: chunks)
|
||||||
end
|
end
|
||||||
|
end
|
||||||
let(:tool_deltas) { ["Let me use a tool for that<function", <<~REPLY] }
|
|
||||||
_calls>
|
RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||||
<invoke>
|
subject(:endpoint) { described_class.new("claude-2", DiscourseAi::Tokenizer::AnthropicTokenizer) }
|
||||||
<tool_name>get_weather</tool_name>
|
|
||||||
<parameters>
|
fab!(:user) { Fabricate(:user) }
|
||||||
<location>Sydney</location>
|
|
||||||
<unit>c</unit>
|
let(:anthropic_mock) { AnthropicMock.new(endpoint) }
|
||||||
</parameters>
|
|
||||||
</invoke>
|
let(:compliance) do
|
||||||
</function_calls>
|
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Claude, user)
|
||||||
REPLY
|
end
|
||||||
|
|
||||||
let(:tool_call) { invocation }
|
describe "#perform_completion!" do
|
||||||
|
context "when using regular mode" do
|
||||||
it_behaves_like "an endpoint that can communicate with a completion service"
|
context "with simple prompts" do
|
||||||
|
it "completes a trivial prompt and logs the response" do
|
||||||
|
compliance.regular_mode_simple_prompt(anthropic_mock)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "with tools" do
|
||||||
|
it "returns a function invocation" do
|
||||||
|
compliance.regular_mode_tools(anthropic_mock)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "when using streaming mode" do
|
||||||
|
context "with simple prompts" do
|
||||||
|
it "completes a trivial prompt and logs the response" do
|
||||||
|
compliance.streaming_mode_simple_prompt(anthropic_mock)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "with tools" do
|
||||||
|
it "returns a function invoncation" do
|
||||||
|
compliance.streaming_mode_tools(anthropic_mock)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -1,28 +1,10 @@
|
||||||
# frozen_string_literal: true
|
# frozen_string_literal: true
|
||||||
|
|
||||||
require_relative "endpoint_examples"
|
require_relative "endpoint_compliance"
|
||||||
require "aws-eventstream"
|
require "aws-eventstream"
|
||||||
require "aws-sigv4"
|
require "aws-sigv4"
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
class BedrockMock < EndpointMock
|
||||||
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::AnthropicTokenizer) }
|
|
||||||
|
|
||||||
let(:model_name) { "claude-2" }
|
|
||||||
let(:bedrock_name) { "claude-v2:1" }
|
|
||||||
let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) }
|
|
||||||
let(:prompt) { dialect.translate }
|
|
||||||
|
|
||||||
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
|
|
||||||
let(:stream_request_body) { request_body }
|
|
||||||
|
|
||||||
let(:tool_id) { "get_weather" }
|
|
||||||
|
|
||||||
before do
|
|
||||||
SiteSetting.ai_bedrock_access_key_id = "123456"
|
|
||||||
SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd"
|
|
||||||
SiteSetting.ai_bedrock_region = "us-east-1"
|
|
||||||
end
|
|
||||||
|
|
||||||
def response(content)
|
def response(content)
|
||||||
{
|
{
|
||||||
completion: content,
|
completion: content,
|
||||||
|
@ -30,19 +12,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
||||||
stop_reason: "stop_sequence",
|
stop_reason: "stop_sequence",
|
||||||
truncated: false,
|
truncated: false,
|
||||||
log_id: "12dcc7feafbee4a394e0de9dffde3ac5",
|
log_id: "12dcc7feafbee4a394e0de9dffde3ac5",
|
||||||
model: model_name,
|
model: "claude",
|
||||||
exception: nil,
|
exception: nil,
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
def stub_response(prompt, response_text, tool_call: false)
|
def stub_response(prompt, response_content, tool_call: false)
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(
|
.stub_request(:post, "#{base_url}/invoke")
|
||||||
:post,
|
.with(body: model.default_options.merge(prompt: prompt).to_json)
|
||||||
"https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.#{bedrock_name}/invoke",
|
.to_return(status: 200, body: JSON.dump(response(response_content)))
|
||||||
)
|
|
||||||
.with(body: request_body)
|
|
||||||
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
|
||||||
end
|
end
|
||||||
|
|
||||||
def stream_line(delta, finish_reason: nil)
|
def stream_line(delta, finish_reason: nil)
|
||||||
|
@ -83,27 +62,60 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
||||||
end
|
end
|
||||||
|
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(
|
.stub_request(:post, "#{base_url}/invoke-with-response-stream")
|
||||||
:post,
|
.with(body: model.default_options.merge(prompt: prompt).to_json)
|
||||||
"https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.#{bedrock_name}/invoke-with-response-stream",
|
|
||||||
)
|
|
||||||
.with(body: stream_request_body)
|
|
||||||
.to_return(status: 200, body: chunks)
|
.to_return(status: 200, body: chunks)
|
||||||
end
|
end
|
||||||
|
|
||||||
let(:tool_deltas) { ["<function", <<~REPLY] }
|
def base_url
|
||||||
_calls>
|
"https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.claude-v2:1"
|
||||||
<invoke>
|
end
|
||||||
<tool_name>get_weather</tool_name>
|
end
|
||||||
<parameters>
|
|
||||||
<location>Sydney</location>
|
RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
||||||
<unit>c</unit>
|
subject(:endpoint) { described_class.new("claude-2", DiscourseAi::Tokenizer::AnthropicTokenizer) }
|
||||||
</parameters>
|
|
||||||
</invoke>
|
fab!(:user) { Fabricate(:user) }
|
||||||
</function_calls>
|
|
||||||
REPLY
|
let(:bedrock_mock) { BedrockMock.new(endpoint) }
|
||||||
|
|
||||||
let(:tool_call) { invocation }
|
let(:compliance) do
|
||||||
|
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Claude, user)
|
||||||
it_behaves_like "an endpoint that can communicate with a completion service"
|
end
|
||||||
|
|
||||||
|
before do
|
||||||
|
SiteSetting.ai_bedrock_access_key_id = "123456"
|
||||||
|
SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd"
|
||||||
|
SiteSetting.ai_bedrock_region = "us-east-1"
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "#perform_completion!" do
|
||||||
|
context "when using regular mode" do
|
||||||
|
context "with simple prompts" do
|
||||||
|
it "completes a trivial prompt and logs the response" do
|
||||||
|
compliance.regular_mode_simple_prompt(bedrock_mock)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "with tools" do
|
||||||
|
it "returns a function invocation" do
|
||||||
|
compliance.regular_mode_tools(bedrock_mock)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "when using streaming mode" do
|
||||||
|
context "with simple prompts" do
|
||||||
|
it "completes a trivial prompt and logs the response" do
|
||||||
|
compliance.streaming_mode_simple_prompt(bedrock_mock)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "with tools" do
|
||||||
|
it "returns a function invoncation" do
|
||||||
|
compliance.streaming_mode_tools(bedrock_mock)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -0,0 +1,229 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
require "net/http"
|
||||||
|
|
||||||
|
class EndpointMock
|
||||||
|
def initialize(model)
|
||||||
|
@model = model
|
||||||
|
end
|
||||||
|
|
||||||
|
attr_reader :model
|
||||||
|
|
||||||
|
def stub_simple_call(prompt)
|
||||||
|
stub_response(prompt, simple_response)
|
||||||
|
end
|
||||||
|
|
||||||
|
def stub_tool_call(prompt)
|
||||||
|
stub_response(prompt, tool_response, tool_call: true)
|
||||||
|
end
|
||||||
|
|
||||||
|
def stub_streamed_simple_call(prompt)
|
||||||
|
with_chunk_array_support do
|
||||||
|
stub_streamed_response(prompt, streamed_simple_deltas)
|
||||||
|
yield
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def stub_streamed_tool_call(prompt)
|
||||||
|
with_chunk_array_support do
|
||||||
|
stub_streamed_response(prompt, tool_deltas, tool_call: true)
|
||||||
|
yield
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def simple_response
|
||||||
|
"1. Serenity\\n2. Laughter\\n3. Adventure"
|
||||||
|
end
|
||||||
|
|
||||||
|
def streamed_simple_deltas
|
||||||
|
["Mount", "ain", " ", "Tree ", "Frog"]
|
||||||
|
end
|
||||||
|
|
||||||
|
def tool_deltas
|
||||||
|
["Let me use a tool for that<function", <<~REPLY.strip, <<~REPLY.strip, <<~REPLY.strip]
|
||||||
|
_calls>
|
||||||
|
<invoke>
|
||||||
|
<tool_name>get_weather</tool_name>
|
||||||
|
<parameters>
|
||||||
|
<location>Sydney</location>
|
||||||
|
<unit>c</unit>
|
||||||
|
</para
|
||||||
|
REPLY
|
||||||
|
meters>
|
||||||
|
</invoke>
|
||||||
|
</funct
|
||||||
|
REPLY
|
||||||
|
ion_calls>
|
||||||
|
REPLY
|
||||||
|
end
|
||||||
|
|
||||||
|
def tool_response
|
||||||
|
tool_deltas.join
|
||||||
|
end
|
||||||
|
|
||||||
|
def invocation_response
|
||||||
|
<<~TEXT
|
||||||
|
<function_calls>
|
||||||
|
<invoke>
|
||||||
|
<tool_name>get_weather</tool_name>
|
||||||
|
<tool_id>#{tool_id}</tool_id>
|
||||||
|
<parameters>
|
||||||
|
<location>Sydney</location>
|
||||||
|
<unit>c</unit>
|
||||||
|
</parameters>
|
||||||
|
</invoke>
|
||||||
|
</function_calls>
|
||||||
|
TEXT
|
||||||
|
end
|
||||||
|
|
||||||
|
def tool_id
|
||||||
|
"get_weather"
|
||||||
|
end
|
||||||
|
|
||||||
|
def tool
|
||||||
|
{
|
||||||
|
name: "get_weather",
|
||||||
|
description: "Get the weather in a city",
|
||||||
|
parameters: [
|
||||||
|
{ name: "location", type: "string", description: "the city name", required: true },
|
||||||
|
{
|
||||||
|
name: "unit",
|
||||||
|
type: "string",
|
||||||
|
description: "the unit of measurement celcius c or fahrenheit f",
|
||||||
|
enum: %w[c f],
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
|
def with_chunk_array_support
|
||||||
|
mock = mocked_http
|
||||||
|
@original_net_http = ::Net.send(:remove_const, :HTTP)
|
||||||
|
::Net.send(:const_set, :HTTP, mock)
|
||||||
|
|
||||||
|
yield
|
||||||
|
ensure
|
||||||
|
::Net.send(:remove_const, :HTTP)
|
||||||
|
::Net.send(:const_set, :HTTP, @original_net_http)
|
||||||
|
end
|
||||||
|
|
||||||
|
protected
|
||||||
|
|
||||||
|
# Copied from https://github.com/bblimke/webmock/issues/629
|
||||||
|
# Workaround for stubbing a streamed response
|
||||||
|
def mocked_http
|
||||||
|
Class.new(::Net::HTTP) do
|
||||||
|
def request(*)
|
||||||
|
super do |response|
|
||||||
|
response.instance_eval do
|
||||||
|
def read_body(*, &block)
|
||||||
|
if block_given?
|
||||||
|
@body.each(&block)
|
||||||
|
else
|
||||||
|
super
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
yield response if block_given?
|
||||||
|
|
||||||
|
response
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
class EndpointsCompliance
|
||||||
|
def initialize(rspec, endpoint, dialect_klass, user)
|
||||||
|
@rspec = rspec
|
||||||
|
@endpoint = endpoint
|
||||||
|
@dialect_klass = dialect_klass
|
||||||
|
@user = user
|
||||||
|
end
|
||||||
|
|
||||||
|
delegate :expect, :eq, :be_present, to: :rspec
|
||||||
|
|
||||||
|
def generic_prompt(tools: [])
|
||||||
|
DiscourseAi::Completions::Prompt.new(
|
||||||
|
"You write words",
|
||||||
|
messages: [{ type: :user, content: "write 3 words" }],
|
||||||
|
tools: tools,
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
|
def dialect(prompt: generic_prompt)
|
||||||
|
dialect_klass.new(prompt, endpoint.model)
|
||||||
|
end
|
||||||
|
|
||||||
|
def regular_mode_simple_prompt(mock)
|
||||||
|
mock.stub_simple_call(dialect.translate)
|
||||||
|
|
||||||
|
completion_response = endpoint.perform_completion!(dialect, user)
|
||||||
|
|
||||||
|
expect(completion_response).to eq(mock.simple_response)
|
||||||
|
|
||||||
|
expect(AiApiAuditLog.count).to eq(1)
|
||||||
|
log = AiApiAuditLog.first
|
||||||
|
|
||||||
|
expect(log.provider_id).to eq(endpoint.provider_id)
|
||||||
|
expect(log.user_id).to eq(user.id)
|
||||||
|
expect(log.raw_request_payload).to be_present
|
||||||
|
expect(log.raw_response_payload).to eq(mock.response(completion_response).to_json)
|
||||||
|
expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
|
||||||
|
expect(log.response_tokens).to eq(endpoint.tokenizer.size(completion_response))
|
||||||
|
end
|
||||||
|
|
||||||
|
def regular_mode_tools(mock)
|
||||||
|
prompt = generic_prompt(tools: [mock.tool])
|
||||||
|
a_dialect = dialect(prompt: prompt)
|
||||||
|
|
||||||
|
mock.stub_tool_call(a_dialect.translate)
|
||||||
|
|
||||||
|
completion_response = endpoint.perform_completion!(a_dialect, user)
|
||||||
|
|
||||||
|
expect(completion_response).to eq(mock.invocation_response)
|
||||||
|
end
|
||||||
|
|
||||||
|
def streaming_mode_simple_prompt(mock)
|
||||||
|
mock.stub_streamed_simple_call(dialect.translate) do
|
||||||
|
completion_response = +""
|
||||||
|
|
||||||
|
endpoint.perform_completion!(dialect, user) do |partial, cancel|
|
||||||
|
completion_response << partial
|
||||||
|
cancel.call if completion_response.split(" ").length == 2
|
||||||
|
end
|
||||||
|
|
||||||
|
expect(AiApiAuditLog.count).to eq(1)
|
||||||
|
log = AiApiAuditLog.first
|
||||||
|
|
||||||
|
expect(log.provider_id).to eq(endpoint.provider_id)
|
||||||
|
expect(log.user_id).to eq(user.id)
|
||||||
|
expect(log.raw_request_payload).to be_present
|
||||||
|
expect(log.raw_response_payload).to be_present
|
||||||
|
expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
|
||||||
|
expect(log.response_tokens).to eq(
|
||||||
|
endpoint.tokenizer.size(mock.streamed_simple_deltas[0...-1].join),
|
||||||
|
)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def streaming_mode_tools(mock)
|
||||||
|
prompt = generic_prompt(tools: [mock.tool])
|
||||||
|
a_dialect = dialect(prompt: prompt)
|
||||||
|
|
||||||
|
mock.stub_streamed_tool_call(a_dialect.translate) do
|
||||||
|
buffered_partial = +""
|
||||||
|
|
||||||
|
endpoint.perform_completion!(a_dialect, user) do |partial, cancel|
|
||||||
|
buffered_partial << partial
|
||||||
|
cancel.call if buffered_partial.include?("<function_calls>")
|
||||||
|
end
|
||||||
|
|
||||||
|
expect(buffered_partial).to eq(mock.invocation_response)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
attr_reader :rspec, :endpoint, :dialect_klass, :user
|
||||||
|
end
|
|
@ -1,175 +0,0 @@
|
||||||
# frozen_string_literal: true
|
|
||||||
|
|
||||||
RSpec.shared_examples "an endpoint that can communicate with a completion service" do
|
|
||||||
# Copied from https://github.com/bblimke/webmock/issues/629
|
|
||||||
# Workaround for stubbing a streamed response
|
|
||||||
before do
|
|
||||||
mocked_http =
|
|
||||||
Class.new(Net::HTTP) do
|
|
||||||
def request(*)
|
|
||||||
super do |response|
|
|
||||||
response.instance_eval do
|
|
||||||
def read_body(*, &block)
|
|
||||||
if block_given?
|
|
||||||
@body.each(&block)
|
|
||||||
else
|
|
||||||
super
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
yield response if block_given?
|
|
||||||
|
|
||||||
response
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
@original_net_http = Net.send(:remove_const, :HTTP)
|
|
||||||
Net.send(:const_set, :HTTP, mocked_http)
|
|
||||||
end
|
|
||||||
|
|
||||||
after do
|
|
||||||
Net.send(:remove_const, :HTTP)
|
|
||||||
Net.send(:const_set, :HTTP, @original_net_http)
|
|
||||||
end
|
|
||||||
|
|
||||||
let(:generic_prompt) do
|
|
||||||
DiscourseAi::Completions::Prompt.new(
|
|
||||||
"You write words",
|
|
||||||
messages: [{ type: :user, content: "write 3 words" }],
|
|
||||||
)
|
|
||||||
end
|
|
||||||
|
|
||||||
describe "#perform_completion!" do
|
|
||||||
fab!(:user) { Fabricate(:user) }
|
|
||||||
|
|
||||||
let(:tool) do
|
|
||||||
{
|
|
||||||
name: "get_weather",
|
|
||||||
description: "Get the weather in a city",
|
|
||||||
parameters: [
|
|
||||||
{ name: "location", type: "string", description: "the city name", required: true },
|
|
||||||
{
|
|
||||||
name: "unit",
|
|
||||||
type: "string",
|
|
||||||
description: "the unit of measurement celcius c or fahrenheit f",
|
|
||||||
enum: %w[c f],
|
|
||||||
required: true,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
end
|
|
||||||
|
|
||||||
let(:invocation) { <<~TEXT }
|
|
||||||
<function_calls>
|
|
||||||
<invoke>
|
|
||||||
<tool_name>get_weather</tool_name>
|
|
||||||
<tool_id>#{tool_id || "get_weather"}</tool_id>
|
|
||||||
<parameters>
|
|
||||||
<location>Sydney</location>
|
|
||||||
<unit>c</unit>
|
|
||||||
</parameters>
|
|
||||||
</invoke>
|
|
||||||
</function_calls>
|
|
||||||
TEXT
|
|
||||||
|
|
||||||
context "when using regular mode" do
|
|
||||||
context "with simple prompts" do
|
|
||||||
let(:response_text) { "1. Serenity\\n2. Laughter\\n3. Adventure" }
|
|
||||||
|
|
||||||
before { stub_response(prompt, response_text) }
|
|
||||||
|
|
||||||
it "can complete a trivial prompt" do
|
|
||||||
completion_response = model.perform_completion!(dialect, user)
|
|
||||||
|
|
||||||
expect(completion_response).to eq(response_text)
|
|
||||||
end
|
|
||||||
|
|
||||||
it "creates an audit log for the request" do
|
|
||||||
model.perform_completion!(dialect, user)
|
|
||||||
|
|
||||||
expect(AiApiAuditLog.count).to eq(1)
|
|
||||||
log = AiApiAuditLog.first
|
|
||||||
|
|
||||||
response_body = response(response_text).to_json
|
|
||||||
|
|
||||||
expect(log.provider_id).to eq(model.provider_id)
|
|
||||||
expect(log.user_id).to eq(user.id)
|
|
||||||
expect(log.raw_request_payload).to eq(request_body)
|
|
||||||
expect(log.raw_response_payload).to eq(response_body)
|
|
||||||
expect(log.request_tokens).to eq(model.prompt_size(prompt))
|
|
||||||
expect(log.response_tokens).to eq(model.tokenizer.size(response_text))
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
context "with functions" do
|
|
||||||
before do
|
|
||||||
generic_prompt.tools = [tool]
|
|
||||||
stub_response(prompt, tool_call, tool_call: true)
|
|
||||||
end
|
|
||||||
|
|
||||||
it "returns a function invocation" do
|
|
||||||
completion_response = model.perform_completion!(dialect, user)
|
|
||||||
|
|
||||||
expect(completion_response).to eq(invocation)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
context "when using stream mode" do
|
|
||||||
context "with simple prompts" do
|
|
||||||
let(:deltas) { ["Mount", "ain", " ", "Tree ", "Frog"] }
|
|
||||||
|
|
||||||
before { stub_streamed_response(prompt, deltas) }
|
|
||||||
|
|
||||||
it "can complete a trivial prompt" do
|
|
||||||
completion_response = +""
|
|
||||||
|
|
||||||
model.perform_completion!(dialect, user) do |partial, cancel|
|
|
||||||
completion_response << partial
|
|
||||||
cancel.call if completion_response.split(" ").length == 2
|
|
||||||
end
|
|
||||||
|
|
||||||
expect(completion_response).to eq(deltas[0...-1].join)
|
|
||||||
end
|
|
||||||
|
|
||||||
it "creates an audit log and updates is on each read." do
|
|
||||||
completion_response = +""
|
|
||||||
|
|
||||||
model.perform_completion!(dialect, user) do |partial, cancel|
|
|
||||||
completion_response << partial
|
|
||||||
cancel.call if completion_response.split(" ").length == 2
|
|
||||||
end
|
|
||||||
|
|
||||||
expect(AiApiAuditLog.count).to eq(1)
|
|
||||||
log = AiApiAuditLog.first
|
|
||||||
|
|
||||||
expect(log.provider_id).to eq(model.provider_id)
|
|
||||||
expect(log.user_id).to eq(user.id)
|
|
||||||
expect(log.raw_request_payload).to eq(stream_request_body)
|
|
||||||
expect(log.raw_response_payload).to be_present
|
|
||||||
expect(log.request_tokens).to eq(model.prompt_size(prompt))
|
|
||||||
expect(log.response_tokens).to eq(model.tokenizer.size(deltas[0...-1].join))
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
context "with functions" do
|
|
||||||
before do
|
|
||||||
generic_prompt.tools = [tool]
|
|
||||||
stub_streamed_response(prompt, tool_deltas, tool_call: true)
|
|
||||||
end
|
|
||||||
|
|
||||||
it "waits for the invocation to finish before calling the partial" do
|
|
||||||
buffered_partial = ""
|
|
||||||
|
|
||||||
model.perform_completion!(dialect, user) do |partial, cancel|
|
|
||||||
buffered_partial = partial if partial.include?("<function_calls>")
|
|
||||||
end
|
|
||||||
|
|
||||||
expect(buffered_partial).to eq(invocation)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
|
@ -1,69 +1,8 @@
|
||||||
# frozen_string_literal: true
|
# frozen_string_literal: true
|
||||||
|
|
||||||
require_relative "endpoint_examples"
|
require_relative "endpoint_compliance"
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
|
|
||||||
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
|
|
||||||
|
|
||||||
let(:model_name) { "gemini-pro" }
|
|
||||||
let(:dialect) { DiscourseAi::Completions::Dialects::Gemini.new(generic_prompt, model_name) }
|
|
||||||
let(:prompt) { dialect.translate }
|
|
||||||
|
|
||||||
let(:tool_id) { "get_weather" }
|
|
||||||
|
|
||||||
let(:tool_payload) do
|
|
||||||
{
|
|
||||||
name: "get_weather",
|
|
||||||
description: "Get the weather in a city",
|
|
||||||
parameters: {
|
|
||||||
type: "object",
|
|
||||||
required: %w[location unit],
|
|
||||||
properties: {
|
|
||||||
"location" => {
|
|
||||||
type: "string",
|
|
||||||
description: "the city name",
|
|
||||||
},
|
|
||||||
"unit" => {
|
|
||||||
type: "string",
|
|
||||||
description: "the unit of measurement celcius c or fahrenheit f",
|
|
||||||
enum: %w[c f],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
end
|
|
||||||
|
|
||||||
let(:request_body) do
|
|
||||||
model
|
|
||||||
.default_options
|
|
||||||
.merge(contents: prompt)
|
|
||||||
.tap do |b|
|
|
||||||
b[:tools] = [{ function_declarations: [tool_payload] }] if generic_prompt.tools.present?
|
|
||||||
end
|
|
||||||
.to_json
|
|
||||||
end
|
|
||||||
let(:stream_request_body) do
|
|
||||||
model
|
|
||||||
.default_options
|
|
||||||
.merge(contents: prompt)
|
|
||||||
.tap do |b|
|
|
||||||
b[:tools] = [{ function_declarations: [tool_payload] }] if generic_prompt.tools.present?
|
|
||||||
end
|
|
||||||
.to_json
|
|
||||||
end
|
|
||||||
|
|
||||||
let(:tool_deltas) do
|
|
||||||
[
|
|
||||||
{ "functionCall" => { name: "get_weather", args: {} } },
|
|
||||||
{ "functionCall" => { name: "get_weather", args: { location: "" } } },
|
|
||||||
{ "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } },
|
|
||||||
]
|
|
||||||
end
|
|
||||||
|
|
||||||
let(:tool_call) do
|
|
||||||
{ "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } }
|
|
||||||
end
|
|
||||||
|
|
||||||
|
class GeminiMock < EndpointMock
|
||||||
def response(content, tool_call: false)
|
def response(content, tool_call: false)
|
||||||
{
|
{
|
||||||
candidates: [
|
candidates: [
|
||||||
|
@ -97,9 +36,9 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(
|
.stub_request(
|
||||||
:post,
|
:post,
|
||||||
"https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:generateContent?key=#{SiteSetting.ai_gemini_api_key}",
|
"https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key=#{SiteSetting.ai_gemini_api_key}",
|
||||||
)
|
)
|
||||||
.with(body: request_body)
|
.with(body: request_body(prompt, tool_call))
|
||||||
.to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
|
.to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -139,11 +78,93 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(
|
.stub_request(
|
||||||
:post,
|
:post,
|
||||||
"https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:streamGenerateContent?key=#{SiteSetting.ai_gemini_api_key}",
|
"https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent?key=#{SiteSetting.ai_gemini_api_key}",
|
||||||
)
|
)
|
||||||
.with(body: stream_request_body)
|
.with(body: request_body(prompt, tool_call))
|
||||||
.to_return(status: 200, body: chunks)
|
.to_return(status: 200, body: chunks)
|
||||||
end
|
end
|
||||||
|
|
||||||
it_behaves_like "an endpoint that can communicate with a completion service"
|
def tool_payload
|
||||||
|
{
|
||||||
|
name: "get_weather",
|
||||||
|
description: "Get the weather in a city",
|
||||||
|
parameters: {
|
||||||
|
type: "object",
|
||||||
|
required: %w[location unit],
|
||||||
|
properties: {
|
||||||
|
"location" => {
|
||||||
|
type: "string",
|
||||||
|
description: "the city name",
|
||||||
|
},
|
||||||
|
"unit" => {
|
||||||
|
type: "string",
|
||||||
|
description: "the unit of measurement celcius c or fahrenheit f",
|
||||||
|
enum: %w[c f],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
|
def request_body(prompt, tool_call)
|
||||||
|
model
|
||||||
|
.default_options
|
||||||
|
.merge(contents: prompt)
|
||||||
|
.tap { |b| b[:tools] = [{ function_declarations: [tool_payload] }] if tool_call }
|
||||||
|
.to_json
|
||||||
|
end
|
||||||
|
|
||||||
|
def tool_deltas
|
||||||
|
[
|
||||||
|
{ "functionCall" => { name: "get_weather", args: {} } },
|
||||||
|
{ "functionCall" => { name: "get_weather", args: { location: "" } } },
|
||||||
|
{ "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } },
|
||||||
|
]
|
||||||
|
end
|
||||||
|
|
||||||
|
def tool_response
|
||||||
|
{ "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } }
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
|
||||||
|
subject(:endpoint) { described_class.new("gemini-pro", DiscourseAi::Tokenizer::OpenAiTokenizer) }
|
||||||
|
|
||||||
|
fab!(:user) { Fabricate(:user) }
|
||||||
|
|
||||||
|
let(:bedrock_mock) { GeminiMock.new(endpoint) }
|
||||||
|
|
||||||
|
let(:compliance) do
|
||||||
|
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Gemini, user)
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "#perform_completion!" do
|
||||||
|
context "when using regular mode" do
|
||||||
|
context "with simple prompts" do
|
||||||
|
it "completes a trivial prompt and logs the response" do
|
||||||
|
compliance.regular_mode_simple_prompt(bedrock_mock)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "with tools" do
|
||||||
|
it "returns a function invocation" do
|
||||||
|
compliance.regular_mode_tools(bedrock_mock)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "when using streaming mode" do
|
||||||
|
context "with simple prompts" do
|
||||||
|
it "completes a trivial prompt and logs the response" do
|
||||||
|
compliance.streaming_mode_simple_prompt(bedrock_mock)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "with tools" do
|
||||||
|
it "returns a function invoncation" do
|
||||||
|
compliance.streaming_mode_tools(bedrock_mock)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -1,42 +1,8 @@
|
||||||
# frozen_string_literal: true
|
# frozen_string_literal: true
|
||||||
|
|
||||||
require_relative "endpoint_examples"
|
require_relative "endpoint_compliance"
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
|
|
||||||
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::Llama2Tokenizer) }
|
|
||||||
|
|
||||||
let(:model_name) { "Llama2-*-chat-hf" }
|
|
||||||
let(:dialect) do
|
|
||||||
DiscourseAi::Completions::Dialects::Llama2Classic.new(generic_prompt, model_name)
|
|
||||||
end
|
|
||||||
let(:prompt) { dialect.translate }
|
|
||||||
|
|
||||||
let(:tool_id) { "get_weather" }
|
|
||||||
|
|
||||||
let(:request_body) do
|
|
||||||
model
|
|
||||||
.default_options
|
|
||||||
.merge(inputs: prompt)
|
|
||||||
.tap do |payload|
|
|
||||||
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
|
|
||||||
model.prompt_size(prompt)
|
|
||||||
end
|
|
||||||
.to_json
|
|
||||||
end
|
|
||||||
let(:stream_request_body) do
|
|
||||||
model
|
|
||||||
.default_options
|
|
||||||
.merge(inputs: prompt)
|
|
||||||
.tap do |payload|
|
|
||||||
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
|
|
||||||
model.prompt_size(prompt)
|
|
||||||
payload[:stream] = true
|
|
||||||
end
|
|
||||||
.to_json
|
|
||||||
end
|
|
||||||
|
|
||||||
before { SiteSetting.ai_hugging_face_api_url = "https://test.dev" }
|
|
||||||
|
|
||||||
|
class HuggingFaceMock < EndpointMock
|
||||||
def response(content)
|
def response(content)
|
||||||
[{ generated_text: content }]
|
[{ generated_text: content }]
|
||||||
end
|
end
|
||||||
|
@ -44,7 +10,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
|
||||||
def stub_response(prompt, response_text, tool_call: false)
|
def stub_response(prompt, response_text, tool_call: false)
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
|
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
|
||||||
.with(body: request_body)
|
.with(body: request_body(prompt))
|
||||||
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -75,33 +41,65 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
|
||||||
|
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
|
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
|
||||||
.with(body: stream_request_body)
|
.with(body: request_body(prompt, stream: true))
|
||||||
.to_return(status: 200, body: chunks)
|
.to_return(status: 200, body: chunks)
|
||||||
end
|
end
|
||||||
|
|
||||||
let(:tool_deltas) { ["<function", <<~REPLY, <<~REPLY] }
|
def request_body(prompt, stream: false)
|
||||||
_calls>
|
model
|
||||||
<invoke>
|
.default_options
|
||||||
<tool_name>get_weather</tool_name>
|
.merge(inputs: prompt)
|
||||||
<parameters>
|
.tap do |payload|
|
||||||
<location>Sydney</location>
|
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
|
||||||
<unit>c</unit>
|
model.prompt_size(prompt)
|
||||||
</parameters>
|
payload[:stream] = true if stream
|
||||||
</invoke>
|
end
|
||||||
</function_calls>
|
.to_json
|
||||||
REPLY
|
end
|
||||||
<function_calls>
|
end
|
||||||
<invoke>
|
|
||||||
<tool_name>get_weather</tool_name>
|
RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
|
||||||
<parameters>
|
subject(:endpoint) do
|
||||||
<location>Sydney</location>
|
described_class.new("Llama2-*-chat-hf", DiscourseAi::Tokenizer::Llama2Tokenizer)
|
||||||
<unit>c</unit>
|
end
|
||||||
</parameters>
|
|
||||||
</invoke>
|
before { SiteSetting.ai_hugging_face_api_url = "https://test.dev" }
|
||||||
</function_calls>
|
|
||||||
REPLY
|
fab!(:user) { Fabricate(:user) }
|
||||||
|
|
||||||
let(:tool_call) { invocation }
|
let(:hf_mock) { HuggingFaceMock.new(endpoint) }
|
||||||
|
|
||||||
it_behaves_like "an endpoint that can communicate with a completion service"
|
let(:compliance) do
|
||||||
|
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Llama2Classic, user)
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "#perform_completion!" do
|
||||||
|
context "when using regular mode" do
|
||||||
|
context "with simple prompts" do
|
||||||
|
it "completes a trivial prompt and logs the response" do
|
||||||
|
compliance.regular_mode_simple_prompt(hf_mock)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "with tools" do
|
||||||
|
it "returns a function invocation" do
|
||||||
|
compliance.regular_mode_tools(hf_mock)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "when using streaming mode" do
|
||||||
|
context "with simple prompts" do
|
||||||
|
it "completes a trivial prompt and logs the response" do
|
||||||
|
compliance.streaming_mode_simple_prompt(hf_mock)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "with tools" do
|
||||||
|
it "returns a function invoncation" do
|
||||||
|
compliance.streaming_mode_tools(hf_mock)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -1,53 +1,8 @@
|
||||||
# frozen_string_literal: true
|
# frozen_string_literal: true
|
||||||
|
|
||||||
require_relative "endpoint_examples"
|
require_relative "endpoint_compliance"
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
|
||||||
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
|
|
||||||
|
|
||||||
let(:model_name) { "gpt-3.5-turbo" }
|
|
||||||
let(:dialect) { DiscourseAi::Completions::Dialects::ChatGpt.new(generic_prompt, model_name) }
|
|
||||||
let(:prompt) { dialect.translate }
|
|
||||||
|
|
||||||
let(:tool_id) { "eujbuebfe" }
|
|
||||||
|
|
||||||
let(:tool_deltas) do
|
|
||||||
[
|
|
||||||
{ id: tool_id, function: {} },
|
|
||||||
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
|
|
||||||
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
|
|
||||||
{ id: tool_id, function: { name: "get_weather", arguments: "{" } },
|
|
||||||
{ id: tool_id, function: { name: "get_weather", arguments: " \"location\": \"Sydney\"" } },
|
|
||||||
{ id: tool_id, function: { name: "get_weather", arguments: " ,\"unit\": \"c\" }" } },
|
|
||||||
]
|
|
||||||
end
|
|
||||||
|
|
||||||
let(:tool_call) do
|
|
||||||
{
|
|
||||||
id: tool_id,
|
|
||||||
function: {
|
|
||||||
name: "get_weather",
|
|
||||||
arguments: { location: "Sydney", unit: "c" }.to_json,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
end
|
|
||||||
|
|
||||||
let(:request_body) do
|
|
||||||
model
|
|
||||||
.default_options
|
|
||||||
.merge(messages: prompt)
|
|
||||||
.tap { |b| b[:tools] = dialect.tools if generic_prompt.tools.present? }
|
|
||||||
.to_json
|
|
||||||
end
|
|
||||||
|
|
||||||
let(:stream_request_body) do
|
|
||||||
model
|
|
||||||
.default_options
|
|
||||||
.merge(messages: prompt, stream: true)
|
|
||||||
.tap { |b| b[:tools] = dialect.tools if generic_prompt.tools.present? }
|
|
||||||
.to_json
|
|
||||||
end
|
|
||||||
|
|
||||||
|
class OpenAiMock < EndpointMock
|
||||||
def response(content, tool_call: false)
|
def response(content, tool_call: false)
|
||||||
message_content =
|
message_content =
|
||||||
if tool_call
|
if tool_call
|
||||||
|
@ -75,7 +30,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
||||||
def stub_response(prompt, response_text, tool_call: false)
|
def stub_response(prompt, response_text, tool_call: false)
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, "https://api.openai.com/v1/chat/completions")
|
.stub_request(:post, "https://api.openai.com/v1/chat/completions")
|
||||||
.with(body: request_body)
|
.with(body: request_body(prompt, tool_call: tool_call))
|
||||||
.to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
|
.to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -112,54 +67,110 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
||||||
|
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, "https://api.openai.com/v1/chat/completions")
|
.stub_request(:post, "https://api.openai.com/v1/chat/completions")
|
||||||
.with(body: stream_request_body)
|
.with(body: request_body(prompt, stream: true, tool_call: tool_call))
|
||||||
.to_return(status: 200, body: chunks)
|
.to_return(status: 200, body: chunks)
|
||||||
end
|
end
|
||||||
|
|
||||||
it_behaves_like "an endpoint that can communicate with a completion service"
|
def tool_deltas
|
||||||
|
[
|
||||||
|
{ id: tool_id, function: {} },
|
||||||
|
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
|
||||||
|
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
|
||||||
|
{ id: tool_id, function: { name: "get_weather", arguments: "{" } },
|
||||||
|
{ id: tool_id, function: { name: "get_weather", arguments: " \"location\": \"Sydney\"" } },
|
||||||
|
{ id: tool_id, function: { name: "get_weather", arguments: " ,\"unit\": \"c\" }" } },
|
||||||
|
]
|
||||||
|
end
|
||||||
|
|
||||||
context "when chunked encoding returns partial chunks" do
|
def tool_response
|
||||||
# See: https://github.com/bblimke/webmock/issues/629
|
{
|
||||||
let(:mock_net_http) do
|
id: tool_id,
|
||||||
Class.new(Net::HTTP) do
|
function: {
|
||||||
def request(*)
|
name: "get_weather",
|
||||||
super do |response|
|
arguments: { location: "Sydney", unit: "c" }.to_json,
|
||||||
response.instance_eval do
|
},
|
||||||
def read_body(*, &)
|
}
|
||||||
@body.each(&)
|
end
|
||||||
|
|
||||||
|
def tool_id
|
||||||
|
"eujbuebfe"
|
||||||
|
end
|
||||||
|
|
||||||
|
def tool_payload
|
||||||
|
{
|
||||||
|
type: "function",
|
||||||
|
function: {
|
||||||
|
name: "get_weather",
|
||||||
|
description: "Get the weather in a city",
|
||||||
|
parameters: {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
location: {
|
||||||
|
type: "string",
|
||||||
|
description: "the city name",
|
||||||
|
},
|
||||||
|
unit: {
|
||||||
|
type: "string",
|
||||||
|
description: "the unit of measurement celcius c or fahrenheit f",
|
||||||
|
enum: %w[c f],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
required: %w[location unit],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
|
def request_body(prompt, stream: false, tool_call: false)
|
||||||
|
model
|
||||||
|
.default_options
|
||||||
|
.merge(messages: prompt)
|
||||||
|
.tap do |b|
|
||||||
|
b[:stream] = true if stream
|
||||||
|
b[:tools] = [tool_payload] if tool_call
|
||||||
|
end
|
||||||
|
.to_json
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
yield response if block_given?
|
RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
||||||
|
subject(:endpoint) do
|
||||||
response
|
described_class.new("gpt-3.5-turbo", DiscourseAi::Tokenizer::OpenAiTokenizer)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
fab!(:user) { Fabricate(:user) }
|
||||||
|
|
||||||
|
let(:open_ai_mock) { OpenAiMock.new(endpoint) }
|
||||||
|
|
||||||
|
let(:compliance) do
|
||||||
|
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::ChatGpt, user)
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "#perform_completion!" do
|
||||||
|
context "when using regular mode" do
|
||||||
|
context "with simple prompts" do
|
||||||
|
it "completes a trivial prompt and logs the response" do
|
||||||
|
compliance.regular_mode_simple_prompt(open_ai_mock)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "with tools" do
|
||||||
|
it "returns a function invocation" do
|
||||||
|
compliance.regular_mode_tools(open_ai_mock)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
let(:remove_original_net_http) { Net.send(:remove_const, :HTTP) }
|
describe "when using streaming mode" do
|
||||||
let(:original_http) { remove_original_net_http }
|
context "with simple prompts" do
|
||||||
let(:stub_net_http) { Net.send(:const_set, :HTTP, mock_net_http) }
|
it "completes a trivial prompt and logs the response" do
|
||||||
|
compliance.streaming_mode_simple_prompt(open_ai_mock)
|
||||||
let(:remove_stubbed_net_http) { Net.send(:remove_const, :HTTP) }
|
|
||||||
let(:restore_net_http) { Net.send(:const_set, :HTTP, original_http) }
|
|
||||||
|
|
||||||
before do
|
|
||||||
mock_net_http
|
|
||||||
remove_original_net_http
|
|
||||||
stub_net_http
|
|
||||||
end
|
|
||||||
|
|
||||||
after do
|
|
||||||
remove_stubbed_net_http
|
|
||||||
restore_net_http
|
|
||||||
end
|
end
|
||||||
|
|
||||||
it "will automatically recover from a bad payload" do
|
it "will automatically recover from a bad payload" do
|
||||||
# this should not happen, but lets ensure nothing bad happens
|
# this should not happen, but lets ensure nothing bad happens
|
||||||
# the row with test1 is invalid json
|
# the row with test1 is invalid json
|
||||||
raw_data = <<~TEXT
|
raw_data = <<~TEXT.strip
|
||||||
d|a|t|a|:| |{|"choices":[{"delta":{"content":"test,"}}]}
|
d|a|t|a|:| |{|"choices":[{"delta":{"content":"test,"}}]}
|
||||||
|
|
||||||
data: {"choices":[{"delta":{"content":"test1,"}}]
|
data: {"choices":[{"delta":{"content":"test1,"}}]
|
||||||
|
@ -175,51 +186,25 @@ data: [D|ONE]
|
||||||
|
|
||||||
chunks = raw_data.split("|")
|
chunks = raw_data.split("|")
|
||||||
|
|
||||||
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
|
open_ai_mock.with_chunk_array_support do
|
||||||
status: 200,
|
open_ai_mock.stub_streamed_response(compliance.dialect.translate, chunks) do
|
||||||
body: chunks,
|
|
||||||
)
|
|
||||||
|
|
||||||
partials = []
|
partials = []
|
||||||
llm = DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo")
|
|
||||||
llm.generate(
|
|
||||||
DiscourseAi::Completions::Prompt.new("test"),
|
|
||||||
user: Discourse.system_user,
|
|
||||||
) { |partial| partials << partial }
|
|
||||||
|
|
||||||
expect(partials.join).to eq("test,test2,test3,test4")
|
endpoint.perform_completion!(compliance.dialect, user) do |partial|
|
||||||
|
partials << partial
|
||||||
end
|
end
|
||||||
|
|
||||||
it "supports chunked encoding properly" do
|
|
||||||
raw_data = <<~TEXT
|
|
||||||
da|ta: {"choices":[{"delta":{"content":"test,"}}]}
|
|
||||||
|
|
||||||
data: {"choices":[{"delta":{"content":"test1,"}}]}
|
|
||||||
|
|
||||||
data: {"choices":[{"delta":|{"content":"test2,"}}]}
|
|
||||||
|
|
||||||
data: {"choices":[{"delta":{"content":"test3,"}}]|}
|
|
||||||
|
|
||||||
data: {"choices":[{|"|d|elta":{"content":"test4"}}]|}
|
|
||||||
|
|
||||||
data: [D|ONE]
|
|
||||||
TEXT
|
|
||||||
|
|
||||||
chunks = raw_data.split("|")
|
|
||||||
|
|
||||||
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
|
|
||||||
status: 200,
|
|
||||||
body: chunks,
|
|
||||||
)
|
|
||||||
|
|
||||||
partials = []
|
|
||||||
llm = DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo")
|
|
||||||
llm.generate(
|
|
||||||
DiscourseAi::Completions::Prompt.new("test"),
|
|
||||||
user: Discourse.system_user,
|
|
||||||
) { |partial| partials << partial }
|
|
||||||
|
|
||||||
expect(partials.join).to eq("test,test1,test2,test3,test4")
|
expect(partials.join).to eq("test,test1,test2,test3,test4")
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "with tools" do
|
||||||
|
it "returns a function invoncation" do
|
||||||
|
compliance.streaming_mode_tools(open_ai_mock)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
|
@ -1,27 +1,14 @@
|
||||||
# frozen_string_literal: true
|
# frozen_string_literal: true
|
||||||
|
|
||||||
require_relative "endpoint_examples"
|
require_relative "endpoint_compliance"
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
|
|
||||||
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::MixtralTokenizer) }
|
|
||||||
|
|
||||||
let(:model_name) { "mistralai/Mixtral-8x7B-Instruct-v0.1" }
|
|
||||||
let(:dialect) { DiscourseAi::Completions::Dialects::Mixtral.new(generic_prompt, model_name) }
|
|
||||||
let(:prompt) { dialect.translate }
|
|
||||||
|
|
||||||
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
|
|
||||||
let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
|
|
||||||
|
|
||||||
before { SiteSetting.ai_vllm_endpoint = "https://test.dev" }
|
|
||||||
|
|
||||||
let(:tool_id) { "get_weather" }
|
|
||||||
|
|
||||||
|
class VllmMock < EndpointMock
|
||||||
def response(content)
|
def response(content)
|
||||||
{
|
{
|
||||||
id: "cmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
|
id: "cmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
|
||||||
object: "text_completion",
|
object: "text_completion",
|
||||||
created: 1_678_464_820,
|
created: 1_678_464_820,
|
||||||
model: model_name,
|
model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||||
usage: {
|
usage: {
|
||||||
prompt_tokens: 337,
|
prompt_tokens: 337,
|
||||||
completion_tokens: 162,
|
completion_tokens: 162,
|
||||||
|
@ -34,7 +21,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
|
||||||
def stub_response(prompt, response_text, tool_call: false)
|
def stub_response(prompt, response_text, tool_call: false)
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions")
|
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions")
|
||||||
.with(body: request_body)
|
.with(body: model.default_options.merge(prompt: prompt).to_json)
|
||||||
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -42,7 +29,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
|
||||||
+"data: " << {
|
+"data: " << {
|
||||||
id: "cmpl-#{SecureRandom.hex}",
|
id: "cmpl-#{SecureRandom.hex}",
|
||||||
created: 1_681_283_881,
|
created: 1_681_283_881,
|
||||||
model: model_name,
|
model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||||
choices: [{ text: delta, finish_reason: finish_reason, index: 0 }],
|
choices: [{ text: delta, finish_reason: finish_reason, index: 0 }],
|
||||||
index: 0,
|
index: 0,
|
||||||
}.to_json
|
}.to_json
|
||||||
|
@ -62,33 +49,62 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
|
||||||
|
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions")
|
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions")
|
||||||
.with(body: stream_request_body)
|
.with(body: model.default_options.merge(prompt: prompt, stream: true).to_json)
|
||||||
.to_return(status: 200, body: chunks)
|
.to_return(status: 200, body: chunks)
|
||||||
end
|
end
|
||||||
|
end
|
||||||
let(:tool_deltas) { ["<function", <<~REPLY, <<~REPLY] }
|
|
||||||
_calls>
|
RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
|
||||||
<invoke>
|
subject(:endpoint) do
|
||||||
<tool_name>get_weather</tool_name>
|
described_class.new(
|
||||||
<parameters>
|
"mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||||
<location>Sydney</location>
|
DiscourseAi::Tokenizer::MixtralTokenizer,
|
||||||
<unit>c</unit>
|
)
|
||||||
</parameters>
|
end
|
||||||
</invoke>
|
|
||||||
</function_calls>
|
fab!(:user) { Fabricate(:user) }
|
||||||
REPLY
|
|
||||||
<function_calls>
|
let(:anthropic_mock) { VllmMock.new(endpoint) }
|
||||||
<invoke>
|
|
||||||
<tool_name>get_weather</tool_name>
|
let(:compliance) do
|
||||||
<parameters>
|
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Mixtral, user)
|
||||||
<location>Sydney</location>
|
end
|
||||||
<unit>c</unit>
|
|
||||||
</parameters>
|
let(:dialect) { DiscourseAi::Completions::Dialects::Mixtral.new(generic_prompt, model_name) }
|
||||||
</invoke>
|
let(:prompt) { dialect.translate }
|
||||||
</function_calls>
|
|
||||||
REPLY
|
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
|
||||||
|
let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
|
||||||
let(:tool_call) { invocation }
|
|
||||||
|
before { SiteSetting.ai_vllm_endpoint = "https://test.dev" }
|
||||||
it_behaves_like "an endpoint that can communicate with a completion service"
|
|
||||||
|
describe "#perform_completion!" do
|
||||||
|
context "when using regular mode" do
|
||||||
|
context "with simple prompts" do
|
||||||
|
it "completes a trivial prompt and logs the response" do
|
||||||
|
compliance.regular_mode_simple_prompt(anthropic_mock)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "with tools" do
|
||||||
|
it "returns a function invocation" do
|
||||||
|
compliance.regular_mode_tools(anthropic_mock)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "when using streaming mode" do
|
||||||
|
context "with simple prompts" do
|
||||||
|
it "completes a trivial prompt and logs the response" do
|
||||||
|
compliance.streaming_mode_simple_prompt(anthropic_mock)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "with tools" do
|
||||||
|
it "returns a function invoncation" do
|
||||||
|
compliance.streaming_mode_tools(anthropic_mock)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in New Issue