DEV: Stop using shared_examples for endpoint specs (#430)
This commit is contained in:
parent
8eb1e851fc
commit
5bdf3dc1f4
|
@ -205,12 +205,10 @@ module DiscourseAi
|
|||
tokenizer.size(extract_prompt_for_tokenizer(prompt))
|
||||
end
|
||||
|
||||
attr_reader :tokenizer
|
||||
attr_reader :tokenizer, :model
|
||||
|
||||
protected
|
||||
|
||||
attr_reader :model
|
||||
|
||||
# should normalize temperature, max_tokens, stop_words to endpoint specific values
|
||||
def normalize_model_params(model_params)
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -1,19 +1,8 @@
|
|||
# frozen_String_literal: true
|
||||
|
||||
require_relative "endpoint_examples"
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::AnthropicTokenizer) }
|
||||
|
||||
let(:model_name) { "claude-2" }
|
||||
let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) }
|
||||
let(:prompt) { dialect.translate }
|
||||
|
||||
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
|
||||
let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
|
||||
|
||||
let(:tool_id) { "get_weather" }
|
||||
require_relative "endpoint_compliance"
|
||||
|
||||
class AnthropicMock < EndpointMock
|
||||
def response(content)
|
||||
{
|
||||
completion: content,
|
||||
|
@ -21,7 +10,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
|||
stop_reason: "stop_sequence",
|
||||
truncated: false,
|
||||
log_id: "12dcc7feafbee4a394e0de9dffde3ac5",
|
||||
model: model_name,
|
||||
model: "claude-2",
|
||||
exception: nil,
|
||||
}
|
||||
end
|
||||
|
@ -29,7 +18,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
|||
def stub_response(prompt, response_text, tool_call: false)
|
||||
WebMock
|
||||
.stub_request(:post, "https://api.anthropic.com/v1/complete")
|
||||
.with(body: request_body)
|
||||
.with(body: model.default_options.merge(prompt: prompt).to_json)
|
||||
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
||||
end
|
||||
|
||||
|
@ -59,23 +48,49 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
|||
|
||||
WebMock
|
||||
.stub_request(:post, "https://api.anthropic.com/v1/complete")
|
||||
.with(body: stream_request_body)
|
||||
.with(body: model.default_options.merge(prompt: prompt, stream: true).to_json)
|
||||
.to_return(status: 200, body: chunks)
|
||||
end
|
||||
|
||||
let(:tool_deltas) { ["Let me use a tool for that<function", <<~REPLY] }
|
||||
_calls>
|
||||
<invoke>
|
||||
<tool_name>get_weather</tool_name>
|
||||
<parameters>
|
||||
<location>Sydney</location>
|
||||
<unit>c</unit>
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
REPLY
|
||||
|
||||
let(:tool_call) { invocation }
|
||||
|
||||
it_behaves_like "an endpoint that can communicate with a completion service"
|
||||
end
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||
subject(:endpoint) { described_class.new("claude-2", DiscourseAi::Tokenizer::AnthropicTokenizer) }
|
||||
|
||||
fab!(:user) { Fabricate(:user) }
|
||||
|
||||
let(:anthropic_mock) { AnthropicMock.new(endpoint) }
|
||||
|
||||
let(:compliance) do
|
||||
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Claude, user)
|
||||
end
|
||||
|
||||
describe "#perform_completion!" do
|
||||
context "when using regular mode" do
|
||||
context "with simple prompts" do
|
||||
it "completes a trivial prompt and logs the response" do
|
||||
compliance.regular_mode_simple_prompt(anthropic_mock)
|
||||
end
|
||||
end
|
||||
|
||||
context "with tools" do
|
||||
it "returns a function invocation" do
|
||||
compliance.regular_mode_tools(anthropic_mock)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "when using streaming mode" do
|
||||
context "with simple prompts" do
|
||||
it "completes a trivial prompt and logs the response" do
|
||||
compliance.streaming_mode_simple_prompt(anthropic_mock)
|
||||
end
|
||||
end
|
||||
|
||||
context "with tools" do
|
||||
it "returns a function invoncation" do
|
||||
compliance.streaming_mode_tools(anthropic_mock)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -1,28 +1,10 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require_relative "endpoint_examples"
|
||||
require_relative "endpoint_compliance"
|
||||
require "aws-eventstream"
|
||||
require "aws-sigv4"
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
||||
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::AnthropicTokenizer) }
|
||||
|
||||
let(:model_name) { "claude-2" }
|
||||
let(:bedrock_name) { "claude-v2:1" }
|
||||
let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) }
|
||||
let(:prompt) { dialect.translate }
|
||||
|
||||
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
|
||||
let(:stream_request_body) { request_body }
|
||||
|
||||
let(:tool_id) { "get_weather" }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_bedrock_access_key_id = "123456"
|
||||
SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd"
|
||||
SiteSetting.ai_bedrock_region = "us-east-1"
|
||||
end
|
||||
|
||||
class BedrockMock < EndpointMock
|
||||
def response(content)
|
||||
{
|
||||
completion: content,
|
||||
|
@ -30,19 +12,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
|||
stop_reason: "stop_sequence",
|
||||
truncated: false,
|
||||
log_id: "12dcc7feafbee4a394e0de9dffde3ac5",
|
||||
model: model_name,
|
||||
model: "claude",
|
||||
exception: nil,
|
||||
}
|
||||
end
|
||||
|
||||
def stub_response(prompt, response_text, tool_call: false)
|
||||
def stub_response(prompt, response_content, tool_call: false)
|
||||
WebMock
|
||||
.stub_request(
|
||||
:post,
|
||||
"https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.#{bedrock_name}/invoke",
|
||||
)
|
||||
.with(body: request_body)
|
||||
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
||||
.stub_request(:post, "#{base_url}/invoke")
|
||||
.with(body: model.default_options.merge(prompt: prompt).to_json)
|
||||
.to_return(status: 200, body: JSON.dump(response(response_content)))
|
||||
end
|
||||
|
||||
def stream_line(delta, finish_reason: nil)
|
||||
|
@ -83,27 +62,60 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
|||
end
|
||||
|
||||
WebMock
|
||||
.stub_request(
|
||||
:post,
|
||||
"https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.#{bedrock_name}/invoke-with-response-stream",
|
||||
)
|
||||
.with(body: stream_request_body)
|
||||
.stub_request(:post, "#{base_url}/invoke-with-response-stream")
|
||||
.with(body: model.default_options.merge(prompt: prompt).to_json)
|
||||
.to_return(status: 200, body: chunks)
|
||||
end
|
||||
|
||||
let(:tool_deltas) { ["<function", <<~REPLY] }
|
||||
_calls>
|
||||
<invoke>
|
||||
<tool_name>get_weather</tool_name>
|
||||
<parameters>
|
||||
<location>Sydney</location>
|
||||
<unit>c</unit>
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
REPLY
|
||||
|
||||
let(:tool_call) { invocation }
|
||||
|
||||
it_behaves_like "an endpoint that can communicate with a completion service"
|
||||
def base_url
|
||||
"https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.claude-v2:1"
|
||||
end
|
||||
end
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
||||
subject(:endpoint) { described_class.new("claude-2", DiscourseAi::Tokenizer::AnthropicTokenizer) }
|
||||
|
||||
fab!(:user) { Fabricate(:user) }
|
||||
|
||||
let(:bedrock_mock) { BedrockMock.new(endpoint) }
|
||||
|
||||
let(:compliance) do
|
||||
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Claude, user)
|
||||
end
|
||||
|
||||
before do
|
||||
SiteSetting.ai_bedrock_access_key_id = "123456"
|
||||
SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd"
|
||||
SiteSetting.ai_bedrock_region = "us-east-1"
|
||||
end
|
||||
|
||||
describe "#perform_completion!" do
|
||||
context "when using regular mode" do
|
||||
context "with simple prompts" do
|
||||
it "completes a trivial prompt and logs the response" do
|
||||
compliance.regular_mode_simple_prompt(bedrock_mock)
|
||||
end
|
||||
end
|
||||
|
||||
context "with tools" do
|
||||
it "returns a function invocation" do
|
||||
compliance.regular_mode_tools(bedrock_mock)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "when using streaming mode" do
|
||||
context "with simple prompts" do
|
||||
it "completes a trivial prompt and logs the response" do
|
||||
compliance.streaming_mode_simple_prompt(bedrock_mock)
|
||||
end
|
||||
end
|
||||
|
||||
context "with tools" do
|
||||
it "returns a function invoncation" do
|
||||
compliance.streaming_mode_tools(bedrock_mock)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -0,0 +1,229 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require "net/http"
|
||||
|
||||
class EndpointMock
|
||||
def initialize(model)
|
||||
@model = model
|
||||
end
|
||||
|
||||
attr_reader :model
|
||||
|
||||
def stub_simple_call(prompt)
|
||||
stub_response(prompt, simple_response)
|
||||
end
|
||||
|
||||
def stub_tool_call(prompt)
|
||||
stub_response(prompt, tool_response, tool_call: true)
|
||||
end
|
||||
|
||||
def stub_streamed_simple_call(prompt)
|
||||
with_chunk_array_support do
|
||||
stub_streamed_response(prompt, streamed_simple_deltas)
|
||||
yield
|
||||
end
|
||||
end
|
||||
|
||||
def stub_streamed_tool_call(prompt)
|
||||
with_chunk_array_support do
|
||||
stub_streamed_response(prompt, tool_deltas, tool_call: true)
|
||||
yield
|
||||
end
|
||||
end
|
||||
|
||||
def simple_response
|
||||
"1. Serenity\\n2. Laughter\\n3. Adventure"
|
||||
end
|
||||
|
||||
def streamed_simple_deltas
|
||||
["Mount", "ain", " ", "Tree ", "Frog"]
|
||||
end
|
||||
|
||||
def tool_deltas
|
||||
["Let me use a tool for that<function", <<~REPLY.strip, <<~REPLY.strip, <<~REPLY.strip]
|
||||
_calls>
|
||||
<invoke>
|
||||
<tool_name>get_weather</tool_name>
|
||||
<parameters>
|
||||
<location>Sydney</location>
|
||||
<unit>c</unit>
|
||||
</para
|
||||
REPLY
|
||||
meters>
|
||||
</invoke>
|
||||
</funct
|
||||
REPLY
|
||||
ion_calls>
|
||||
REPLY
|
||||
end
|
||||
|
||||
def tool_response
|
||||
tool_deltas.join
|
||||
end
|
||||
|
||||
def invocation_response
|
||||
<<~TEXT
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>get_weather</tool_name>
|
||||
<tool_id>#{tool_id}</tool_id>
|
||||
<parameters>
|
||||
<location>Sydney</location>
|
||||
<unit>c</unit>
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
TEXT
|
||||
end
|
||||
|
||||
def tool_id
|
||||
"get_weather"
|
||||
end
|
||||
|
||||
def tool
|
||||
{
|
||||
name: "get_weather",
|
||||
description: "Get the weather in a city",
|
||||
parameters: [
|
||||
{ name: "location", type: "string", description: "the city name", required: true },
|
||||
{
|
||||
name: "unit",
|
||||
type: "string",
|
||||
description: "the unit of measurement celcius c or fahrenheit f",
|
||||
enum: %w[c f],
|
||||
required: true,
|
||||
},
|
||||
],
|
||||
}
|
||||
end
|
||||
|
||||
def with_chunk_array_support
|
||||
mock = mocked_http
|
||||
@original_net_http = ::Net.send(:remove_const, :HTTP)
|
||||
::Net.send(:const_set, :HTTP, mock)
|
||||
|
||||
yield
|
||||
ensure
|
||||
::Net.send(:remove_const, :HTTP)
|
||||
::Net.send(:const_set, :HTTP, @original_net_http)
|
||||
end
|
||||
|
||||
protected
|
||||
|
||||
# Copied from https://github.com/bblimke/webmock/issues/629
|
||||
# Workaround for stubbing a streamed response
|
||||
def mocked_http
|
||||
Class.new(::Net::HTTP) do
|
||||
def request(*)
|
||||
super do |response|
|
||||
response.instance_eval do
|
||||
def read_body(*, &block)
|
||||
if block_given?
|
||||
@body.each(&block)
|
||||
else
|
||||
super
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
yield response if block_given?
|
||||
|
||||
response
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
class EndpointsCompliance
|
||||
def initialize(rspec, endpoint, dialect_klass, user)
|
||||
@rspec = rspec
|
||||
@endpoint = endpoint
|
||||
@dialect_klass = dialect_klass
|
||||
@user = user
|
||||
end
|
||||
|
||||
delegate :expect, :eq, :be_present, to: :rspec
|
||||
|
||||
def generic_prompt(tools: [])
|
||||
DiscourseAi::Completions::Prompt.new(
|
||||
"You write words",
|
||||
messages: [{ type: :user, content: "write 3 words" }],
|
||||
tools: tools,
|
||||
)
|
||||
end
|
||||
|
||||
def dialect(prompt: generic_prompt)
|
||||
dialect_klass.new(prompt, endpoint.model)
|
||||
end
|
||||
|
||||
def regular_mode_simple_prompt(mock)
|
||||
mock.stub_simple_call(dialect.translate)
|
||||
|
||||
completion_response = endpoint.perform_completion!(dialect, user)
|
||||
|
||||
expect(completion_response).to eq(mock.simple_response)
|
||||
|
||||
expect(AiApiAuditLog.count).to eq(1)
|
||||
log = AiApiAuditLog.first
|
||||
|
||||
expect(log.provider_id).to eq(endpoint.provider_id)
|
||||
expect(log.user_id).to eq(user.id)
|
||||
expect(log.raw_request_payload).to be_present
|
||||
expect(log.raw_response_payload).to eq(mock.response(completion_response).to_json)
|
||||
expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
|
||||
expect(log.response_tokens).to eq(endpoint.tokenizer.size(completion_response))
|
||||
end
|
||||
|
||||
def regular_mode_tools(mock)
|
||||
prompt = generic_prompt(tools: [mock.tool])
|
||||
a_dialect = dialect(prompt: prompt)
|
||||
|
||||
mock.stub_tool_call(a_dialect.translate)
|
||||
|
||||
completion_response = endpoint.perform_completion!(a_dialect, user)
|
||||
|
||||
expect(completion_response).to eq(mock.invocation_response)
|
||||
end
|
||||
|
||||
def streaming_mode_simple_prompt(mock)
|
||||
mock.stub_streamed_simple_call(dialect.translate) do
|
||||
completion_response = +""
|
||||
|
||||
endpoint.perform_completion!(dialect, user) do |partial, cancel|
|
||||
completion_response << partial
|
||||
cancel.call if completion_response.split(" ").length == 2
|
||||
end
|
||||
|
||||
expect(AiApiAuditLog.count).to eq(1)
|
||||
log = AiApiAuditLog.first
|
||||
|
||||
expect(log.provider_id).to eq(endpoint.provider_id)
|
||||
expect(log.user_id).to eq(user.id)
|
||||
expect(log.raw_request_payload).to be_present
|
||||
expect(log.raw_response_payload).to be_present
|
||||
expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
|
||||
expect(log.response_tokens).to eq(
|
||||
endpoint.tokenizer.size(mock.streamed_simple_deltas[0...-1].join),
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
def streaming_mode_tools(mock)
|
||||
prompt = generic_prompt(tools: [mock.tool])
|
||||
a_dialect = dialect(prompt: prompt)
|
||||
|
||||
mock.stub_streamed_tool_call(a_dialect.translate) do
|
||||
buffered_partial = +""
|
||||
|
||||
endpoint.perform_completion!(a_dialect, user) do |partial, cancel|
|
||||
buffered_partial << partial
|
||||
cancel.call if buffered_partial.include?("<function_calls>")
|
||||
end
|
||||
|
||||
expect(buffered_partial).to eq(mock.invocation_response)
|
||||
end
|
||||
end
|
||||
|
||||
attr_reader :rspec, :endpoint, :dialect_klass, :user
|
||||
end
|
|
@ -1,175 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.shared_examples "an endpoint that can communicate with a completion service" do
|
||||
# Copied from https://github.com/bblimke/webmock/issues/629
|
||||
# Workaround for stubbing a streamed response
|
||||
before do
|
||||
mocked_http =
|
||||
Class.new(Net::HTTP) do
|
||||
def request(*)
|
||||
super do |response|
|
||||
response.instance_eval do
|
||||
def read_body(*, &block)
|
||||
if block_given?
|
||||
@body.each(&block)
|
||||
else
|
||||
super
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
yield response if block_given?
|
||||
|
||||
response
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@original_net_http = Net.send(:remove_const, :HTTP)
|
||||
Net.send(:const_set, :HTTP, mocked_http)
|
||||
end
|
||||
|
||||
after do
|
||||
Net.send(:remove_const, :HTTP)
|
||||
Net.send(:const_set, :HTTP, @original_net_http)
|
||||
end
|
||||
|
||||
let(:generic_prompt) do
|
||||
DiscourseAi::Completions::Prompt.new(
|
||||
"You write words",
|
||||
messages: [{ type: :user, content: "write 3 words" }],
|
||||
)
|
||||
end
|
||||
|
||||
describe "#perform_completion!" do
|
||||
fab!(:user) { Fabricate(:user) }
|
||||
|
||||
let(:tool) do
|
||||
{
|
||||
name: "get_weather",
|
||||
description: "Get the weather in a city",
|
||||
parameters: [
|
||||
{ name: "location", type: "string", description: "the city name", required: true },
|
||||
{
|
||||
name: "unit",
|
||||
type: "string",
|
||||
description: "the unit of measurement celcius c or fahrenheit f",
|
||||
enum: %w[c f],
|
||||
required: true,
|
||||
},
|
||||
],
|
||||
}
|
||||
end
|
||||
|
||||
let(:invocation) { <<~TEXT }
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>get_weather</tool_name>
|
||||
<tool_id>#{tool_id || "get_weather"}</tool_id>
|
||||
<parameters>
|
||||
<location>Sydney</location>
|
||||
<unit>c</unit>
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
TEXT
|
||||
|
||||
context "when using regular mode" do
|
||||
context "with simple prompts" do
|
||||
let(:response_text) { "1. Serenity\\n2. Laughter\\n3. Adventure" }
|
||||
|
||||
before { stub_response(prompt, response_text) }
|
||||
|
||||
it "can complete a trivial prompt" do
|
||||
completion_response = model.perform_completion!(dialect, user)
|
||||
|
||||
expect(completion_response).to eq(response_text)
|
||||
end
|
||||
|
||||
it "creates an audit log for the request" do
|
||||
model.perform_completion!(dialect, user)
|
||||
|
||||
expect(AiApiAuditLog.count).to eq(1)
|
||||
log = AiApiAuditLog.first
|
||||
|
||||
response_body = response(response_text).to_json
|
||||
|
||||
expect(log.provider_id).to eq(model.provider_id)
|
||||
expect(log.user_id).to eq(user.id)
|
||||
expect(log.raw_request_payload).to eq(request_body)
|
||||
expect(log.raw_response_payload).to eq(response_body)
|
||||
expect(log.request_tokens).to eq(model.prompt_size(prompt))
|
||||
expect(log.response_tokens).to eq(model.tokenizer.size(response_text))
|
||||
end
|
||||
end
|
||||
|
||||
context "with functions" do
|
||||
before do
|
||||
generic_prompt.tools = [tool]
|
||||
stub_response(prompt, tool_call, tool_call: true)
|
||||
end
|
||||
|
||||
it "returns a function invocation" do
|
||||
completion_response = model.perform_completion!(dialect, user)
|
||||
|
||||
expect(completion_response).to eq(invocation)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
context "when using stream mode" do
|
||||
context "with simple prompts" do
|
||||
let(:deltas) { ["Mount", "ain", " ", "Tree ", "Frog"] }
|
||||
|
||||
before { stub_streamed_response(prompt, deltas) }
|
||||
|
||||
it "can complete a trivial prompt" do
|
||||
completion_response = +""
|
||||
|
||||
model.perform_completion!(dialect, user) do |partial, cancel|
|
||||
completion_response << partial
|
||||
cancel.call if completion_response.split(" ").length == 2
|
||||
end
|
||||
|
||||
expect(completion_response).to eq(deltas[0...-1].join)
|
||||
end
|
||||
|
||||
it "creates an audit log and updates is on each read." do
|
||||
completion_response = +""
|
||||
|
||||
model.perform_completion!(dialect, user) do |partial, cancel|
|
||||
completion_response << partial
|
||||
cancel.call if completion_response.split(" ").length == 2
|
||||
end
|
||||
|
||||
expect(AiApiAuditLog.count).to eq(1)
|
||||
log = AiApiAuditLog.first
|
||||
|
||||
expect(log.provider_id).to eq(model.provider_id)
|
||||
expect(log.user_id).to eq(user.id)
|
||||
expect(log.raw_request_payload).to eq(stream_request_body)
|
||||
expect(log.raw_response_payload).to be_present
|
||||
expect(log.request_tokens).to eq(model.prompt_size(prompt))
|
||||
expect(log.response_tokens).to eq(model.tokenizer.size(deltas[0...-1].join))
|
||||
end
|
||||
end
|
||||
|
||||
context "with functions" do
|
||||
before do
|
||||
generic_prompt.tools = [tool]
|
||||
stub_streamed_response(prompt, tool_deltas, tool_call: true)
|
||||
end
|
||||
|
||||
it "waits for the invocation to finish before calling the partial" do
|
||||
buffered_partial = ""
|
||||
|
||||
model.perform_completion!(dialect, user) do |partial, cancel|
|
||||
buffered_partial = partial if partial.include?("<function_calls>")
|
||||
end
|
||||
|
||||
expect(buffered_partial).to eq(invocation)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,69 +1,8 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require_relative "endpoint_examples"
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
|
||||
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
|
||||
|
||||
let(:model_name) { "gemini-pro" }
|
||||
let(:dialect) { DiscourseAi::Completions::Dialects::Gemini.new(generic_prompt, model_name) }
|
||||
let(:prompt) { dialect.translate }
|
||||
|
||||
let(:tool_id) { "get_weather" }
|
||||
|
||||
let(:tool_payload) do
|
||||
{
|
||||
name: "get_weather",
|
||||
description: "Get the weather in a city",
|
||||
parameters: {
|
||||
type: "object",
|
||||
required: %w[location unit],
|
||||
properties: {
|
||||
"location" => {
|
||||
type: "string",
|
||||
description: "the city name",
|
||||
},
|
||||
"unit" => {
|
||||
type: "string",
|
||||
description: "the unit of measurement celcius c or fahrenheit f",
|
||||
enum: %w[c f],
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
let(:request_body) do
|
||||
model
|
||||
.default_options
|
||||
.merge(contents: prompt)
|
||||
.tap do |b|
|
||||
b[:tools] = [{ function_declarations: [tool_payload] }] if generic_prompt.tools.present?
|
||||
end
|
||||
.to_json
|
||||
end
|
||||
let(:stream_request_body) do
|
||||
model
|
||||
.default_options
|
||||
.merge(contents: prompt)
|
||||
.tap do |b|
|
||||
b[:tools] = [{ function_declarations: [tool_payload] }] if generic_prompt.tools.present?
|
||||
end
|
||||
.to_json
|
||||
end
|
||||
|
||||
let(:tool_deltas) do
|
||||
[
|
||||
{ "functionCall" => { name: "get_weather", args: {} } },
|
||||
{ "functionCall" => { name: "get_weather", args: { location: "" } } },
|
||||
{ "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } },
|
||||
]
|
||||
end
|
||||
|
||||
let(:tool_call) do
|
||||
{ "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } }
|
||||
end
|
||||
require_relative "endpoint_compliance"
|
||||
|
||||
class GeminiMock < EndpointMock
|
||||
def response(content, tool_call: false)
|
||||
{
|
||||
candidates: [
|
||||
|
@ -97,9 +36,9 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
|
|||
WebMock
|
||||
.stub_request(
|
||||
:post,
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:generateContent?key=#{SiteSetting.ai_gemini_api_key}",
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key=#{SiteSetting.ai_gemini_api_key}",
|
||||
)
|
||||
.with(body: request_body)
|
||||
.with(body: request_body(prompt, tool_call))
|
||||
.to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
|
||||
end
|
||||
|
||||
|
@ -139,11 +78,93 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
|
|||
WebMock
|
||||
.stub_request(
|
||||
:post,
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:streamGenerateContent?key=#{SiteSetting.ai_gemini_api_key}",
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent?key=#{SiteSetting.ai_gemini_api_key}",
|
||||
)
|
||||
.with(body: stream_request_body)
|
||||
.with(body: request_body(prompt, tool_call))
|
||||
.to_return(status: 200, body: chunks)
|
||||
end
|
||||
|
||||
it_behaves_like "an endpoint that can communicate with a completion service"
|
||||
def tool_payload
|
||||
{
|
||||
name: "get_weather",
|
||||
description: "Get the weather in a city",
|
||||
parameters: {
|
||||
type: "object",
|
||||
required: %w[location unit],
|
||||
properties: {
|
||||
"location" => {
|
||||
type: "string",
|
||||
description: "the city name",
|
||||
},
|
||||
"unit" => {
|
||||
type: "string",
|
||||
description: "the unit of measurement celcius c or fahrenheit f",
|
||||
enum: %w[c f],
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
def request_body(prompt, tool_call)
|
||||
model
|
||||
.default_options
|
||||
.merge(contents: prompt)
|
||||
.tap { |b| b[:tools] = [{ function_declarations: [tool_payload] }] if tool_call }
|
||||
.to_json
|
||||
end
|
||||
|
||||
def tool_deltas
|
||||
[
|
||||
{ "functionCall" => { name: "get_weather", args: {} } },
|
||||
{ "functionCall" => { name: "get_weather", args: { location: "" } } },
|
||||
{ "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } },
|
||||
]
|
||||
end
|
||||
|
||||
def tool_response
|
||||
{ "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } }
|
||||
end
|
||||
end
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
|
||||
subject(:endpoint) { described_class.new("gemini-pro", DiscourseAi::Tokenizer::OpenAiTokenizer) }
|
||||
|
||||
fab!(:user) { Fabricate(:user) }
|
||||
|
||||
let(:bedrock_mock) { GeminiMock.new(endpoint) }
|
||||
|
||||
let(:compliance) do
|
||||
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Gemini, user)
|
||||
end
|
||||
|
||||
describe "#perform_completion!" do
|
||||
context "when using regular mode" do
|
||||
context "with simple prompts" do
|
||||
it "completes a trivial prompt and logs the response" do
|
||||
compliance.regular_mode_simple_prompt(bedrock_mock)
|
||||
end
|
||||
end
|
||||
|
||||
context "with tools" do
|
||||
it "returns a function invocation" do
|
||||
compliance.regular_mode_tools(bedrock_mock)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "when using streaming mode" do
|
||||
context "with simple prompts" do
|
||||
it "completes a trivial prompt and logs the response" do
|
||||
compliance.streaming_mode_simple_prompt(bedrock_mock)
|
||||
end
|
||||
end
|
||||
|
||||
context "with tools" do
|
||||
it "returns a function invoncation" do
|
||||
compliance.streaming_mode_tools(bedrock_mock)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -1,42 +1,8 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require_relative "endpoint_examples"
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
|
||||
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::Llama2Tokenizer) }
|
||||
|
||||
let(:model_name) { "Llama2-*-chat-hf" }
|
||||
let(:dialect) do
|
||||
DiscourseAi::Completions::Dialects::Llama2Classic.new(generic_prompt, model_name)
|
||||
end
|
||||
let(:prompt) { dialect.translate }
|
||||
|
||||
let(:tool_id) { "get_weather" }
|
||||
|
||||
let(:request_body) do
|
||||
model
|
||||
.default_options
|
||||
.merge(inputs: prompt)
|
||||
.tap do |payload|
|
||||
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
|
||||
model.prompt_size(prompt)
|
||||
end
|
||||
.to_json
|
||||
end
|
||||
let(:stream_request_body) do
|
||||
model
|
||||
.default_options
|
||||
.merge(inputs: prompt)
|
||||
.tap do |payload|
|
||||
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
|
||||
model.prompt_size(prompt)
|
||||
payload[:stream] = true
|
||||
end
|
||||
.to_json
|
||||
end
|
||||
|
||||
before { SiteSetting.ai_hugging_face_api_url = "https://test.dev" }
|
||||
require_relative "endpoint_compliance"
|
||||
|
||||
class HuggingFaceMock < EndpointMock
|
||||
def response(content)
|
||||
[{ generated_text: content }]
|
||||
end
|
||||
|
@ -44,7 +10,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
|
|||
def stub_response(prompt, response_text, tool_call: false)
|
||||
WebMock
|
||||
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
|
||||
.with(body: request_body)
|
||||
.with(body: request_body(prompt))
|
||||
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
||||
end
|
||||
|
||||
|
@ -75,33 +41,65 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
|
|||
|
||||
WebMock
|
||||
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
|
||||
.with(body: stream_request_body)
|
||||
.with(body: request_body(prompt, stream: true))
|
||||
.to_return(status: 200, body: chunks)
|
||||
end
|
||||
|
||||
let(:tool_deltas) { ["<function", <<~REPLY, <<~REPLY] }
|
||||
_calls>
|
||||
<invoke>
|
||||
<tool_name>get_weather</tool_name>
|
||||
<parameters>
|
||||
<location>Sydney</location>
|
||||
<unit>c</unit>
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
REPLY
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>get_weather</tool_name>
|
||||
<parameters>
|
||||
<location>Sydney</location>
|
||||
<unit>c</unit>
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
REPLY
|
||||
|
||||
let(:tool_call) { invocation }
|
||||
|
||||
it_behaves_like "an endpoint that can communicate with a completion service"
|
||||
def request_body(prompt, stream: false)
|
||||
model
|
||||
.default_options
|
||||
.merge(inputs: prompt)
|
||||
.tap do |payload|
|
||||
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
|
||||
model.prompt_size(prompt)
|
||||
payload[:stream] = true if stream
|
||||
end
|
||||
.to_json
|
||||
end
|
||||
end
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
|
||||
subject(:endpoint) do
|
||||
described_class.new("Llama2-*-chat-hf", DiscourseAi::Tokenizer::Llama2Tokenizer)
|
||||
end
|
||||
|
||||
before { SiteSetting.ai_hugging_face_api_url = "https://test.dev" }
|
||||
|
||||
fab!(:user) { Fabricate(:user) }
|
||||
|
||||
let(:hf_mock) { HuggingFaceMock.new(endpoint) }
|
||||
|
||||
let(:compliance) do
|
||||
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Llama2Classic, user)
|
||||
end
|
||||
|
||||
describe "#perform_completion!" do
|
||||
context "when using regular mode" do
|
||||
context "with simple prompts" do
|
||||
it "completes a trivial prompt and logs the response" do
|
||||
compliance.regular_mode_simple_prompt(hf_mock)
|
||||
end
|
||||
end
|
||||
|
||||
context "with tools" do
|
||||
it "returns a function invocation" do
|
||||
compliance.regular_mode_tools(hf_mock)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "when using streaming mode" do
|
||||
context "with simple prompts" do
|
||||
it "completes a trivial prompt and logs the response" do
|
||||
compliance.streaming_mode_simple_prompt(hf_mock)
|
||||
end
|
||||
end
|
||||
|
||||
context "with tools" do
|
||||
it "returns a function invoncation" do
|
||||
compliance.streaming_mode_tools(hf_mock)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -1,53 +1,8 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require_relative "endpoint_examples"
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
||||
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
|
||||
|
||||
let(:model_name) { "gpt-3.5-turbo" }
|
||||
let(:dialect) { DiscourseAi::Completions::Dialects::ChatGpt.new(generic_prompt, model_name) }
|
||||
let(:prompt) { dialect.translate }
|
||||
|
||||
let(:tool_id) { "eujbuebfe" }
|
||||
|
||||
let(:tool_deltas) do
|
||||
[
|
||||
{ id: tool_id, function: {} },
|
||||
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
|
||||
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
|
||||
{ id: tool_id, function: { name: "get_weather", arguments: "{" } },
|
||||
{ id: tool_id, function: { name: "get_weather", arguments: " \"location\": \"Sydney\"" } },
|
||||
{ id: tool_id, function: { name: "get_weather", arguments: " ,\"unit\": \"c\" }" } },
|
||||
]
|
||||
end
|
||||
|
||||
let(:tool_call) do
|
||||
{
|
||||
id: tool_id,
|
||||
function: {
|
||||
name: "get_weather",
|
||||
arguments: { location: "Sydney", unit: "c" }.to_json,
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
let(:request_body) do
|
||||
model
|
||||
.default_options
|
||||
.merge(messages: prompt)
|
||||
.tap { |b| b[:tools] = dialect.tools if generic_prompt.tools.present? }
|
||||
.to_json
|
||||
end
|
||||
|
||||
let(:stream_request_body) do
|
||||
model
|
||||
.default_options
|
||||
.merge(messages: prompt, stream: true)
|
||||
.tap { |b| b[:tools] = dialect.tools if generic_prompt.tools.present? }
|
||||
.to_json
|
||||
end
|
||||
require_relative "endpoint_compliance"
|
||||
|
||||
class OpenAiMock < EndpointMock
|
||||
def response(content, tool_call: false)
|
||||
message_content =
|
||||
if tool_call
|
||||
|
@ -75,7 +30,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
|||
def stub_response(prompt, response_text, tool_call: false)
|
||||
WebMock
|
||||
.stub_request(:post, "https://api.openai.com/v1/chat/completions")
|
||||
.with(body: request_body)
|
||||
.with(body: request_body(prompt, tool_call: tool_call))
|
||||
.to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
|
||||
end
|
||||
|
||||
|
@ -112,114 +67,144 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
|||
|
||||
WebMock
|
||||
.stub_request(:post, "https://api.openai.com/v1/chat/completions")
|
||||
.with(body: stream_request_body)
|
||||
.with(body: request_body(prompt, stream: true, tool_call: tool_call))
|
||||
.to_return(status: 200, body: chunks)
|
||||
end
|
||||
|
||||
it_behaves_like "an endpoint that can communicate with a completion service"
|
||||
def tool_deltas
|
||||
[
|
||||
{ id: tool_id, function: {} },
|
||||
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
|
||||
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
|
||||
{ id: tool_id, function: { name: "get_weather", arguments: "{" } },
|
||||
{ id: tool_id, function: { name: "get_weather", arguments: " \"location\": \"Sydney\"" } },
|
||||
{ id: tool_id, function: { name: "get_weather", arguments: " ,\"unit\": \"c\" }" } },
|
||||
]
|
||||
end
|
||||
|
||||
context "when chunked encoding returns partial chunks" do
|
||||
# See: https://github.com/bblimke/webmock/issues/629
|
||||
let(:mock_net_http) do
|
||||
Class.new(Net::HTTP) do
|
||||
def request(*)
|
||||
super do |response|
|
||||
response.instance_eval do
|
||||
def read_body(*, &)
|
||||
@body.each(&)
|
||||
end
|
||||
end
|
||||
def tool_response
|
||||
{
|
||||
id: tool_id,
|
||||
function: {
|
||||
name: "get_weather",
|
||||
arguments: { location: "Sydney", unit: "c" }.to_json,
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
yield response if block_given?
|
||||
def tool_id
|
||||
"eujbuebfe"
|
||||
end
|
||||
|
||||
response
|
||||
end
|
||||
def tool_payload
|
||||
{
|
||||
type: "function",
|
||||
function: {
|
||||
name: "get_weather",
|
||||
description: "Get the weather in a city",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
location: {
|
||||
type: "string",
|
||||
description: "the city name",
|
||||
},
|
||||
unit: {
|
||||
type: "string",
|
||||
description: "the unit of measurement celcius c or fahrenheit f",
|
||||
enum: %w[c f],
|
||||
},
|
||||
},
|
||||
required: %w[location unit],
|
||||
},
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
def request_body(prompt, stream: false, tool_call: false)
|
||||
model
|
||||
.default_options
|
||||
.merge(messages: prompt)
|
||||
.tap do |b|
|
||||
b[:stream] = true if stream
|
||||
b[:tools] = [tool_payload] if tool_call
|
||||
end
|
||||
.to_json
|
||||
end
|
||||
end
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
||||
subject(:endpoint) do
|
||||
described_class.new("gpt-3.5-turbo", DiscourseAi::Tokenizer::OpenAiTokenizer)
|
||||
end
|
||||
|
||||
fab!(:user) { Fabricate(:user) }
|
||||
|
||||
let(:open_ai_mock) { OpenAiMock.new(endpoint) }
|
||||
|
||||
let(:compliance) do
|
||||
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::ChatGpt, user)
|
||||
end
|
||||
|
||||
describe "#perform_completion!" do
|
||||
context "when using regular mode" do
|
||||
context "with simple prompts" do
|
||||
it "completes a trivial prompt and logs the response" do
|
||||
compliance.regular_mode_simple_prompt(open_ai_mock)
|
||||
end
|
||||
end
|
||||
|
||||
context "with tools" do
|
||||
it "returns a function invocation" do
|
||||
compliance.regular_mode_tools(open_ai_mock)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
let(:remove_original_net_http) { Net.send(:remove_const, :HTTP) }
|
||||
let(:original_http) { remove_original_net_http }
|
||||
let(:stub_net_http) { Net.send(:const_set, :HTTP, mock_net_http) }
|
||||
describe "when using streaming mode" do
|
||||
context "with simple prompts" do
|
||||
it "completes a trivial prompt and logs the response" do
|
||||
compliance.streaming_mode_simple_prompt(open_ai_mock)
|
||||
end
|
||||
|
||||
let(:remove_stubbed_net_http) { Net.send(:remove_const, :HTTP) }
|
||||
let(:restore_net_http) { Net.send(:const_set, :HTTP, original_http) }
|
||||
it "will automatically recover from a bad payload" do
|
||||
# this should not happen, but lets ensure nothing bad happens
|
||||
# the row with test1 is invalid json
|
||||
raw_data = <<~TEXT.strip
|
||||
d|a|t|a|:| |{|"choices":[{"delta":{"content":"test,"}}]}
|
||||
|
||||
before do
|
||||
mock_net_http
|
||||
remove_original_net_http
|
||||
stub_net_http
|
||||
end
|
||||
data: {"choices":[{"delta":{"content":"test1,"}}]
|
||||
|
||||
after do
|
||||
remove_stubbed_net_http
|
||||
restore_net_http
|
||||
end
|
||||
data: {"choices":[{"delta":|{"content":"test2,"}}]}
|
||||
|
||||
it "will automatically recover from a bad payload" do
|
||||
# this should not happen, but lets ensure nothing bad happens
|
||||
# the row with test1 is invalid json
|
||||
raw_data = <<~TEXT
|
||||
d|a|t|a|:| |{|"choices":[{"delta":{"content":"test,"}}]}
|
||||
data: {"choices":[{"delta":{"content":"test3,"}}]|}
|
||||
|
||||
data: {"choices":[{"delta":{"content":"test1,"}}]
|
||||
data: {"choices":[{|"|d|elta":{"content":"test4"}}]|}
|
||||
|
||||
data: {"choices":[{"delta":|{"content":"test2,"}}]}
|
||||
data: [D|ONE]
|
||||
TEXT
|
||||
|
||||
data: {"choices":[{"delta":{"content":"test3,"}}]|}
|
||||
chunks = raw_data.split("|")
|
||||
|
||||
data: {"choices":[{|"|d|elta":{"content":"test4"}}]|}
|
||||
open_ai_mock.with_chunk_array_support do
|
||||
open_ai_mock.stub_streamed_response(compliance.dialect.translate, chunks) do
|
||||
partials = []
|
||||
|
||||
data: [D|ONE]
|
||||
TEXT
|
||||
endpoint.perform_completion!(compliance.dialect, user) do |partial|
|
||||
partials << partial
|
||||
end
|
||||
|
||||
chunks = raw_data.split("|")
|
||||
expect(partials.join).to eq("test,test1,test2,test3,test4")
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
|
||||
status: 200,
|
||||
body: chunks,
|
||||
)
|
||||
|
||||
partials = []
|
||||
llm = DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo")
|
||||
llm.generate(
|
||||
DiscourseAi::Completions::Prompt.new("test"),
|
||||
user: Discourse.system_user,
|
||||
) { |partial| partials << partial }
|
||||
|
||||
expect(partials.join).to eq("test,test2,test3,test4")
|
||||
end
|
||||
|
||||
it "supports chunked encoding properly" do
|
||||
raw_data = <<~TEXT
|
||||
da|ta: {"choices":[{"delta":{"content":"test,"}}]}
|
||||
|
||||
data: {"choices":[{"delta":{"content":"test1,"}}]}
|
||||
|
||||
data: {"choices":[{"delta":|{"content":"test2,"}}]}
|
||||
|
||||
data: {"choices":[{"delta":{"content":"test3,"}}]|}
|
||||
|
||||
data: {"choices":[{|"|d|elta":{"content":"test4"}}]|}
|
||||
|
||||
data: [D|ONE]
|
||||
TEXT
|
||||
|
||||
chunks = raw_data.split("|")
|
||||
|
||||
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
|
||||
status: 200,
|
||||
body: chunks,
|
||||
)
|
||||
|
||||
partials = []
|
||||
llm = DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo")
|
||||
llm.generate(
|
||||
DiscourseAi::Completions::Prompt.new("test"),
|
||||
user: Discourse.system_user,
|
||||
) { |partial| partials << partial }
|
||||
|
||||
expect(partials.join).to eq("test,test1,test2,test3,test4")
|
||||
context "with tools" do
|
||||
it "returns a function invoncation" do
|
||||
compliance.streaming_mode_tools(open_ai_mock)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -1,27 +1,14 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require_relative "endpoint_examples"
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
|
||||
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::MixtralTokenizer) }
|
||||
|
||||
let(:model_name) { "mistralai/Mixtral-8x7B-Instruct-v0.1" }
|
||||
let(:dialect) { DiscourseAi::Completions::Dialects::Mixtral.new(generic_prompt, model_name) }
|
||||
let(:prompt) { dialect.translate }
|
||||
|
||||
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
|
||||
let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
|
||||
|
||||
before { SiteSetting.ai_vllm_endpoint = "https://test.dev" }
|
||||
|
||||
let(:tool_id) { "get_weather" }
|
||||
require_relative "endpoint_compliance"
|
||||
|
||||
class VllmMock < EndpointMock
|
||||
def response(content)
|
||||
{
|
||||
id: "cmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
|
||||
object: "text_completion",
|
||||
created: 1_678_464_820,
|
||||
model: model_name,
|
||||
model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
usage: {
|
||||
prompt_tokens: 337,
|
||||
completion_tokens: 162,
|
||||
|
@ -34,7 +21,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
|
|||
def stub_response(prompt, response_text, tool_call: false)
|
||||
WebMock
|
||||
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions")
|
||||
.with(body: request_body)
|
||||
.with(body: model.default_options.merge(prompt: prompt).to_json)
|
||||
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
||||
end
|
||||
|
||||
|
@ -42,7 +29,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
|
|||
+"data: " << {
|
||||
id: "cmpl-#{SecureRandom.hex}",
|
||||
created: 1_681_283_881,
|
||||
model: model_name,
|
||||
model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
choices: [{ text: delta, finish_reason: finish_reason, index: 0 }],
|
||||
index: 0,
|
||||
}.to_json
|
||||
|
@ -62,33 +49,62 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
|
|||
|
||||
WebMock
|
||||
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions")
|
||||
.with(body: stream_request_body)
|
||||
.with(body: model.default_options.merge(prompt: prompt, stream: true).to_json)
|
||||
.to_return(status: 200, body: chunks)
|
||||
end
|
||||
|
||||
let(:tool_deltas) { ["<function", <<~REPLY, <<~REPLY] }
|
||||
_calls>
|
||||
<invoke>
|
||||
<tool_name>get_weather</tool_name>
|
||||
<parameters>
|
||||
<location>Sydney</location>
|
||||
<unit>c</unit>
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
REPLY
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>get_weather</tool_name>
|
||||
<parameters>
|
||||
<location>Sydney</location>
|
||||
<unit>c</unit>
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
REPLY
|
||||
|
||||
let(:tool_call) { invocation }
|
||||
|
||||
it_behaves_like "an endpoint that can communicate with a completion service"
|
||||
end
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
|
||||
subject(:endpoint) do
|
||||
described_class.new(
|
||||
"mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
DiscourseAi::Tokenizer::MixtralTokenizer,
|
||||
)
|
||||
end
|
||||
|
||||
fab!(:user) { Fabricate(:user) }
|
||||
|
||||
let(:anthropic_mock) { VllmMock.new(endpoint) }
|
||||
|
||||
let(:compliance) do
|
||||
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Mixtral, user)
|
||||
end
|
||||
|
||||
let(:dialect) { DiscourseAi::Completions::Dialects::Mixtral.new(generic_prompt, model_name) }
|
||||
let(:prompt) { dialect.translate }
|
||||
|
||||
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
|
||||
let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
|
||||
|
||||
before { SiteSetting.ai_vllm_endpoint = "https://test.dev" }
|
||||
|
||||
describe "#perform_completion!" do
|
||||
context "when using regular mode" do
|
||||
context "with simple prompts" do
|
||||
it "completes a trivial prompt and logs the response" do
|
||||
compliance.regular_mode_simple_prompt(anthropic_mock)
|
||||
end
|
||||
end
|
||||
|
||||
context "with tools" do
|
||||
it "returns a function invocation" do
|
||||
compliance.regular_mode_tools(anthropic_mock)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "when using streaming mode" do
|
||||
context "with simple prompts" do
|
||||
it "completes a trivial prompt and logs the response" do
|
||||
compliance.streaming_mode_simple_prompt(anthropic_mock)
|
||||
end
|
||||
end
|
||||
|
||||
context "with tools" do
|
||||
it "returns a function invoncation" do
|
||||
compliance.streaming_mode_tools(anthropic_mock)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue