DEV: Stop using shared_examples for endpoint specs (#430)

This commit is contained in:
Roman Rizzi 2024-01-17 15:08:49 -03:00 committed by GitHub
parent 8eb1e851fc
commit 5bdf3dc1f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 673 additions and 574 deletions

View File

@ -205,12 +205,10 @@ module DiscourseAi
tokenizer.size(extract_prompt_for_tokenizer(prompt)) tokenizer.size(extract_prompt_for_tokenizer(prompt))
end end
attr_reader :tokenizer attr_reader :tokenizer, :model
protected protected
attr_reader :model
# should normalize temperature, max_tokens, stop_words to endpoint specific values # should normalize temperature, max_tokens, stop_words to endpoint specific values
def normalize_model_params(model_params) def normalize_model_params(model_params)
raise NotImplementedError raise NotImplementedError

View File

@ -1,19 +1,8 @@
# frozen_String_literal: true # frozen_String_literal: true
require_relative "endpoint_examples" require_relative "endpoint_compliance"
RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::AnthropicTokenizer) }
let(:model_name) { "claude-2" }
let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
let(:tool_id) { "get_weather" }
class AnthropicMock < EndpointMock
def response(content) def response(content)
{ {
completion: content, completion: content,
@ -21,7 +10,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
stop_reason: "stop_sequence", stop_reason: "stop_sequence",
truncated: false, truncated: false,
log_id: "12dcc7feafbee4a394e0de9dffde3ac5", log_id: "12dcc7feafbee4a394e0de9dffde3ac5",
model: model_name, model: "claude-2",
exception: nil, exception: nil,
} }
end end
@ -29,7 +18,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
def stub_response(prompt, response_text, tool_call: false) def stub_response(prompt, response_text, tool_call: false)
WebMock WebMock
.stub_request(:post, "https://api.anthropic.com/v1/complete") .stub_request(:post, "https://api.anthropic.com/v1/complete")
.with(body: request_body) .with(body: model.default_options.merge(prompt: prompt).to_json)
.to_return(status: 200, body: JSON.dump(response(response_text))) .to_return(status: 200, body: JSON.dump(response(response_text)))
end end
@ -59,23 +48,49 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
WebMock WebMock
.stub_request(:post, "https://api.anthropic.com/v1/complete") .stub_request(:post, "https://api.anthropic.com/v1/complete")
.with(body: stream_request_body) .with(body: model.default_options.merge(prompt: prompt, stream: true).to_json)
.to_return(status: 200, body: chunks) .to_return(status: 200, body: chunks)
end end
end
let(:tool_deltas) { ["Let me use a tool for that<function", <<~REPLY] }
_calls> RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
<invoke> subject(:endpoint) { described_class.new("claude-2", DiscourseAi::Tokenizer::AnthropicTokenizer) }
<tool_name>get_weather</tool_name>
<parameters> fab!(:user) { Fabricate(:user) }
<location>Sydney</location>
<unit>c</unit> let(:anthropic_mock) { AnthropicMock.new(endpoint) }
</parameters>
</invoke> let(:compliance) do
</function_calls> EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Claude, user)
REPLY end
let(:tool_call) { invocation } describe "#perform_completion!" do
context "when using regular mode" do
it_behaves_like "an endpoint that can communicate with a completion service" context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.regular_mode_simple_prompt(anthropic_mock)
end
end
context "with tools" do
it "returns a function invocation" do
compliance.regular_mode_tools(anthropic_mock)
end
end
end
describe "when using streaming mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.streaming_mode_simple_prompt(anthropic_mock)
end
end
context "with tools" do
it "returns a function invoncation" do
compliance.streaming_mode_tools(anthropic_mock)
end
end
end
end
end end

View File

@ -1,28 +1,10 @@
# frozen_string_literal: true # frozen_string_literal: true
require_relative "endpoint_examples" require_relative "endpoint_compliance"
require "aws-eventstream" require "aws-eventstream"
require "aws-sigv4" require "aws-sigv4"
RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do class BedrockMock < EndpointMock
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::AnthropicTokenizer) }
let(:model_name) { "claude-2" }
let(:bedrock_name) { "claude-v2:1" }
let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
let(:stream_request_body) { request_body }
let(:tool_id) { "get_weather" }
before do
SiteSetting.ai_bedrock_access_key_id = "123456"
SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd"
SiteSetting.ai_bedrock_region = "us-east-1"
end
def response(content) def response(content)
{ {
completion: content, completion: content,
@ -30,19 +12,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
stop_reason: "stop_sequence", stop_reason: "stop_sequence",
truncated: false, truncated: false,
log_id: "12dcc7feafbee4a394e0de9dffde3ac5", log_id: "12dcc7feafbee4a394e0de9dffde3ac5",
model: model_name, model: "claude",
exception: nil, exception: nil,
} }
end end
def stub_response(prompt, response_text, tool_call: false) def stub_response(prompt, response_content, tool_call: false)
WebMock WebMock
.stub_request( .stub_request(:post, "#{base_url}/invoke")
:post, .with(body: model.default_options.merge(prompt: prompt).to_json)
"https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.#{bedrock_name}/invoke", .to_return(status: 200, body: JSON.dump(response(response_content)))
)
.with(body: request_body)
.to_return(status: 200, body: JSON.dump(response(response_text)))
end end
def stream_line(delta, finish_reason: nil) def stream_line(delta, finish_reason: nil)
@ -83,27 +62,60 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
end end
WebMock WebMock
.stub_request( .stub_request(:post, "#{base_url}/invoke-with-response-stream")
:post, .with(body: model.default_options.merge(prompt: prompt).to_json)
"https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.#{bedrock_name}/invoke-with-response-stream",
)
.with(body: stream_request_body)
.to_return(status: 200, body: chunks) .to_return(status: 200, body: chunks)
end end
let(:tool_deltas) { ["<function", <<~REPLY] } def base_url
_calls> "https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.claude-v2:1"
<invoke> end
<tool_name>get_weather</tool_name> end
<parameters>
<location>Sydney</location> RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
<unit>c</unit> subject(:endpoint) { described_class.new("claude-2", DiscourseAi::Tokenizer::AnthropicTokenizer) }
</parameters>
</invoke> fab!(:user) { Fabricate(:user) }
</function_calls>
REPLY let(:bedrock_mock) { BedrockMock.new(endpoint) }
let(:tool_call) { invocation } let(:compliance) do
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Claude, user)
it_behaves_like "an endpoint that can communicate with a completion service" end
before do
SiteSetting.ai_bedrock_access_key_id = "123456"
SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd"
SiteSetting.ai_bedrock_region = "us-east-1"
end
describe "#perform_completion!" do
context "when using regular mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.regular_mode_simple_prompt(bedrock_mock)
end
end
context "with tools" do
it "returns a function invocation" do
compliance.regular_mode_tools(bedrock_mock)
end
end
end
describe "when using streaming mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.streaming_mode_simple_prompt(bedrock_mock)
end
end
context "with tools" do
it "returns a function invoncation" do
compliance.streaming_mode_tools(bedrock_mock)
end
end
end
end
end end

View File

@ -0,0 +1,229 @@
# frozen_string_literal: true
require "net/http"
class EndpointMock
def initialize(model)
@model = model
end
attr_reader :model
def stub_simple_call(prompt)
stub_response(prompt, simple_response)
end
def stub_tool_call(prompt)
stub_response(prompt, tool_response, tool_call: true)
end
def stub_streamed_simple_call(prompt)
with_chunk_array_support do
stub_streamed_response(prompt, streamed_simple_deltas)
yield
end
end
def stub_streamed_tool_call(prompt)
with_chunk_array_support do
stub_streamed_response(prompt, tool_deltas, tool_call: true)
yield
end
end
def simple_response
"1. Serenity\\n2. Laughter\\n3. Adventure"
end
def streamed_simple_deltas
["Mount", "ain", " ", "Tree ", "Frog"]
end
def tool_deltas
["Let me use a tool for that<function", <<~REPLY.strip, <<~REPLY.strip, <<~REPLY.strip]
_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</para
REPLY
meters>
</invoke>
</funct
REPLY
ion_calls>
REPLY
end
def tool_response
tool_deltas.join
end
def invocation_response
<<~TEXT
<function_calls>
<invoke>
<tool_name>get_weather</tool_name>
<tool_id>#{tool_id}</tool_id>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
TEXT
end
def tool_id
"get_weather"
end
def tool
{
name: "get_weather",
description: "Get the weather in a city",
parameters: [
{ name: "location", type: "string", description: "the city name", required: true },
{
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
required: true,
},
],
}
end
def with_chunk_array_support
mock = mocked_http
@original_net_http = ::Net.send(:remove_const, :HTTP)
::Net.send(:const_set, :HTTP, mock)
yield
ensure
::Net.send(:remove_const, :HTTP)
::Net.send(:const_set, :HTTP, @original_net_http)
end
protected
# Copied from https://github.com/bblimke/webmock/issues/629
# Workaround for stubbing a streamed response
def mocked_http
Class.new(::Net::HTTP) do
def request(*)
super do |response|
response.instance_eval do
def read_body(*, &block)
if block_given?
@body.each(&block)
else
super
end
end
end
yield response if block_given?
response
end
end
end
end
end
class EndpointsCompliance
def initialize(rspec, endpoint, dialect_klass, user)
@rspec = rspec
@endpoint = endpoint
@dialect_klass = dialect_klass
@user = user
end
delegate :expect, :eq, :be_present, to: :rspec
def generic_prompt(tools: [])
DiscourseAi::Completions::Prompt.new(
"You write words",
messages: [{ type: :user, content: "write 3 words" }],
tools: tools,
)
end
def dialect(prompt: generic_prompt)
dialect_klass.new(prompt, endpoint.model)
end
def regular_mode_simple_prompt(mock)
mock.stub_simple_call(dialect.translate)
completion_response = endpoint.perform_completion!(dialect, user)
expect(completion_response).to eq(mock.simple_response)
expect(AiApiAuditLog.count).to eq(1)
log = AiApiAuditLog.first
expect(log.provider_id).to eq(endpoint.provider_id)
expect(log.user_id).to eq(user.id)
expect(log.raw_request_payload).to be_present
expect(log.raw_response_payload).to eq(mock.response(completion_response).to_json)
expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
expect(log.response_tokens).to eq(endpoint.tokenizer.size(completion_response))
end
def regular_mode_tools(mock)
prompt = generic_prompt(tools: [mock.tool])
a_dialect = dialect(prompt: prompt)
mock.stub_tool_call(a_dialect.translate)
completion_response = endpoint.perform_completion!(a_dialect, user)
expect(completion_response).to eq(mock.invocation_response)
end
def streaming_mode_simple_prompt(mock)
mock.stub_streamed_simple_call(dialect.translate) do
completion_response = +""
endpoint.perform_completion!(dialect, user) do |partial, cancel|
completion_response << partial
cancel.call if completion_response.split(" ").length == 2
end
expect(AiApiAuditLog.count).to eq(1)
log = AiApiAuditLog.first
expect(log.provider_id).to eq(endpoint.provider_id)
expect(log.user_id).to eq(user.id)
expect(log.raw_request_payload).to be_present
expect(log.raw_response_payload).to be_present
expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
expect(log.response_tokens).to eq(
endpoint.tokenizer.size(mock.streamed_simple_deltas[0...-1].join),
)
end
end
def streaming_mode_tools(mock)
prompt = generic_prompt(tools: [mock.tool])
a_dialect = dialect(prompt: prompt)
mock.stub_streamed_tool_call(a_dialect.translate) do
buffered_partial = +""
endpoint.perform_completion!(a_dialect, user) do |partial, cancel|
buffered_partial << partial
cancel.call if buffered_partial.include?("<function_calls>")
end
expect(buffered_partial).to eq(mock.invocation_response)
end
end
attr_reader :rspec, :endpoint, :dialect_klass, :user
end

View File

@ -1,175 +0,0 @@
# frozen_string_literal: true
RSpec.shared_examples "an endpoint that can communicate with a completion service" do
# Copied from https://github.com/bblimke/webmock/issues/629
# Workaround for stubbing a streamed response
before do
mocked_http =
Class.new(Net::HTTP) do
def request(*)
super do |response|
response.instance_eval do
def read_body(*, &block)
if block_given?
@body.each(&block)
else
super
end
end
end
yield response if block_given?
response
end
end
end
@original_net_http = Net.send(:remove_const, :HTTP)
Net.send(:const_set, :HTTP, mocked_http)
end
after do
Net.send(:remove_const, :HTTP)
Net.send(:const_set, :HTTP, @original_net_http)
end
let(:generic_prompt) do
DiscourseAi::Completions::Prompt.new(
"You write words",
messages: [{ type: :user, content: "write 3 words" }],
)
end
describe "#perform_completion!" do
fab!(:user) { Fabricate(:user) }
let(:tool) do
{
name: "get_weather",
description: "Get the weather in a city",
parameters: [
{ name: "location", type: "string", description: "the city name", required: true },
{
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
required: true,
},
],
}
end
let(:invocation) { <<~TEXT }
<function_calls>
<invoke>
<tool_name>get_weather</tool_name>
<tool_id>#{tool_id || "get_weather"}</tool_id>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
TEXT
context "when using regular mode" do
context "with simple prompts" do
let(:response_text) { "1. Serenity\\n2. Laughter\\n3. Adventure" }
before { stub_response(prompt, response_text) }
it "can complete a trivial prompt" do
completion_response = model.perform_completion!(dialect, user)
expect(completion_response).to eq(response_text)
end
it "creates an audit log for the request" do
model.perform_completion!(dialect, user)
expect(AiApiAuditLog.count).to eq(1)
log = AiApiAuditLog.first
response_body = response(response_text).to_json
expect(log.provider_id).to eq(model.provider_id)
expect(log.user_id).to eq(user.id)
expect(log.raw_request_payload).to eq(request_body)
expect(log.raw_response_payload).to eq(response_body)
expect(log.request_tokens).to eq(model.prompt_size(prompt))
expect(log.response_tokens).to eq(model.tokenizer.size(response_text))
end
end
context "with functions" do
before do
generic_prompt.tools = [tool]
stub_response(prompt, tool_call, tool_call: true)
end
it "returns a function invocation" do
completion_response = model.perform_completion!(dialect, user)
expect(completion_response).to eq(invocation)
end
end
end
context "when using stream mode" do
context "with simple prompts" do
let(:deltas) { ["Mount", "ain", " ", "Tree ", "Frog"] }
before { stub_streamed_response(prompt, deltas) }
it "can complete a trivial prompt" do
completion_response = +""
model.perform_completion!(dialect, user) do |partial, cancel|
completion_response << partial
cancel.call if completion_response.split(" ").length == 2
end
expect(completion_response).to eq(deltas[0...-1].join)
end
it "creates an audit log and updates is on each read." do
completion_response = +""
model.perform_completion!(dialect, user) do |partial, cancel|
completion_response << partial
cancel.call if completion_response.split(" ").length == 2
end
expect(AiApiAuditLog.count).to eq(1)
log = AiApiAuditLog.first
expect(log.provider_id).to eq(model.provider_id)
expect(log.user_id).to eq(user.id)
expect(log.raw_request_payload).to eq(stream_request_body)
expect(log.raw_response_payload).to be_present
expect(log.request_tokens).to eq(model.prompt_size(prompt))
expect(log.response_tokens).to eq(model.tokenizer.size(deltas[0...-1].join))
end
end
context "with functions" do
before do
generic_prompt.tools = [tool]
stub_streamed_response(prompt, tool_deltas, tool_call: true)
end
it "waits for the invocation to finish before calling the partial" do
buffered_partial = ""
model.perform_completion!(dialect, user) do |partial, cancel|
buffered_partial = partial if partial.include?("<function_calls>")
end
expect(buffered_partial).to eq(invocation)
end
end
end
end
end

View File

@ -1,69 +1,8 @@
# frozen_string_literal: true # frozen_string_literal: true
require_relative "endpoint_examples" require_relative "endpoint_compliance"
RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
let(:model_name) { "gemini-pro" }
let(:dialect) { DiscourseAi::Completions::Dialects::Gemini.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }
let(:tool_id) { "get_weather" }
let(:tool_payload) do
{
name: "get_weather",
description: "Get the weather in a city",
parameters: {
type: "object",
required: %w[location unit],
properties: {
"location" => {
type: "string",
description: "the city name",
},
"unit" => {
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
},
},
},
}
end
let(:request_body) do
model
.default_options
.merge(contents: prompt)
.tap do |b|
b[:tools] = [{ function_declarations: [tool_payload] }] if generic_prompt.tools.present?
end
.to_json
end
let(:stream_request_body) do
model
.default_options
.merge(contents: prompt)
.tap do |b|
b[:tools] = [{ function_declarations: [tool_payload] }] if generic_prompt.tools.present?
end
.to_json
end
let(:tool_deltas) do
[
{ "functionCall" => { name: "get_weather", args: {} } },
{ "functionCall" => { name: "get_weather", args: { location: "" } } },
{ "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } },
]
end
let(:tool_call) do
{ "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } }
end
class GeminiMock < EndpointMock
def response(content, tool_call: false) def response(content, tool_call: false)
{ {
candidates: [ candidates: [
@ -97,9 +36,9 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
WebMock WebMock
.stub_request( .stub_request(
:post, :post,
"https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:generateContent?key=#{SiteSetting.ai_gemini_api_key}", "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key=#{SiteSetting.ai_gemini_api_key}",
) )
.with(body: request_body) .with(body: request_body(prompt, tool_call))
.to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call))) .to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
end end
@ -139,11 +78,93 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
WebMock WebMock
.stub_request( .stub_request(
:post, :post,
"https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:streamGenerateContent?key=#{SiteSetting.ai_gemini_api_key}", "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent?key=#{SiteSetting.ai_gemini_api_key}",
) )
.with(body: stream_request_body) .with(body: request_body(prompt, tool_call))
.to_return(status: 200, body: chunks) .to_return(status: 200, body: chunks)
end end
it_behaves_like "an endpoint that can communicate with a completion service" def tool_payload
{
name: "get_weather",
description: "Get the weather in a city",
parameters: {
type: "object",
required: %w[location unit],
properties: {
"location" => {
type: "string",
description: "the city name",
},
"unit" => {
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
},
},
},
}
end
def request_body(prompt, tool_call)
model
.default_options
.merge(contents: prompt)
.tap { |b| b[:tools] = [{ function_declarations: [tool_payload] }] if tool_call }
.to_json
end
def tool_deltas
[
{ "functionCall" => { name: "get_weather", args: {} } },
{ "functionCall" => { name: "get_weather", args: { location: "" } } },
{ "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } },
]
end
def tool_response
{ "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } }
end
end
RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
subject(:endpoint) { described_class.new("gemini-pro", DiscourseAi::Tokenizer::OpenAiTokenizer) }
fab!(:user) { Fabricate(:user) }
let(:bedrock_mock) { GeminiMock.new(endpoint) }
let(:compliance) do
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Gemini, user)
end
describe "#perform_completion!" do
context "when using regular mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.regular_mode_simple_prompt(bedrock_mock)
end
end
context "with tools" do
it "returns a function invocation" do
compliance.regular_mode_tools(bedrock_mock)
end
end
end
describe "when using streaming mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.streaming_mode_simple_prompt(bedrock_mock)
end
end
context "with tools" do
it "returns a function invoncation" do
compliance.streaming_mode_tools(bedrock_mock)
end
end
end
end
end end

View File

@ -1,42 +1,8 @@
# frozen_string_literal: true # frozen_string_literal: true
require_relative "endpoint_examples" require_relative "endpoint_compliance"
RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::Llama2Tokenizer) }
let(:model_name) { "Llama2-*-chat-hf" }
let(:dialect) do
DiscourseAi::Completions::Dialects::Llama2Classic.new(generic_prompt, model_name)
end
let(:prompt) { dialect.translate }
let(:tool_id) { "get_weather" }
let(:request_body) do
model
.default_options
.merge(inputs: prompt)
.tap do |payload|
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
model.prompt_size(prompt)
end
.to_json
end
let(:stream_request_body) do
model
.default_options
.merge(inputs: prompt)
.tap do |payload|
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
model.prompt_size(prompt)
payload[:stream] = true
end
.to_json
end
before { SiteSetting.ai_hugging_face_api_url = "https://test.dev" }
class HuggingFaceMock < EndpointMock
def response(content) def response(content)
[{ generated_text: content }] [{ generated_text: content }]
end end
@ -44,7 +10,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
def stub_response(prompt, response_text, tool_call: false) def stub_response(prompt, response_text, tool_call: false)
WebMock WebMock
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}") .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
.with(body: request_body) .with(body: request_body(prompt))
.to_return(status: 200, body: JSON.dump(response(response_text))) .to_return(status: 200, body: JSON.dump(response(response_text)))
end end
@ -75,33 +41,65 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
WebMock WebMock
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}") .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
.with(body: stream_request_body) .with(body: request_body(prompt, stream: true))
.to_return(status: 200, body: chunks) .to_return(status: 200, body: chunks)
end end
let(:tool_deltas) { ["<function", <<~REPLY, <<~REPLY] } def request_body(prompt, stream: false)
_calls> model
<invoke> .default_options
<tool_name>get_weather</tool_name> .merge(inputs: prompt)
<parameters> .tap do |payload|
<location>Sydney</location> payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
<unit>c</unit> model.prompt_size(prompt)
</parameters> payload[:stream] = true if stream
</invoke> end
</function_calls> .to_json
REPLY end
<function_calls> end
<invoke>
<tool_name>get_weather</tool_name> RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
<parameters> subject(:endpoint) do
<location>Sydney</location> described_class.new("Llama2-*-chat-hf", DiscourseAi::Tokenizer::Llama2Tokenizer)
<unit>c</unit> end
</parameters>
</invoke> before { SiteSetting.ai_hugging_face_api_url = "https://test.dev" }
</function_calls>
REPLY fab!(:user) { Fabricate(:user) }
let(:tool_call) { invocation } let(:hf_mock) { HuggingFaceMock.new(endpoint) }
it_behaves_like "an endpoint that can communicate with a completion service" let(:compliance) do
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Llama2Classic, user)
end
describe "#perform_completion!" do
context "when using regular mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.regular_mode_simple_prompt(hf_mock)
end
end
context "with tools" do
it "returns a function invocation" do
compliance.regular_mode_tools(hf_mock)
end
end
end
describe "when using streaming mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.streaming_mode_simple_prompt(hf_mock)
end
end
context "with tools" do
it "returns a function invoncation" do
compliance.streaming_mode_tools(hf_mock)
end
end
end
end
end end

View File

@ -1,53 +1,8 @@
# frozen_string_literal: true # frozen_string_literal: true
require_relative "endpoint_examples" require_relative "endpoint_compliance"
RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
let(:model_name) { "gpt-3.5-turbo" }
let(:dialect) { DiscourseAi::Completions::Dialects::ChatGpt.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }
let(:tool_id) { "eujbuebfe" }
let(:tool_deltas) do
[
{ id: tool_id, function: {} },
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
{ id: tool_id, function: { name: "get_weather", arguments: "{" } },
{ id: tool_id, function: { name: "get_weather", arguments: " \"location\": \"Sydney\"" } },
{ id: tool_id, function: { name: "get_weather", arguments: " ,\"unit\": \"c\" }" } },
]
end
let(:tool_call) do
{
id: tool_id,
function: {
name: "get_weather",
arguments: { location: "Sydney", unit: "c" }.to_json,
},
}
end
let(:request_body) do
model
.default_options
.merge(messages: prompt)
.tap { |b| b[:tools] = dialect.tools if generic_prompt.tools.present? }
.to_json
end
let(:stream_request_body) do
model
.default_options
.merge(messages: prompt, stream: true)
.tap { |b| b[:tools] = dialect.tools if generic_prompt.tools.present? }
.to_json
end
class OpenAiMock < EndpointMock
def response(content, tool_call: false) def response(content, tool_call: false)
message_content = message_content =
if tool_call if tool_call
@ -75,7 +30,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
def stub_response(prompt, response_text, tool_call: false) def stub_response(prompt, response_text, tool_call: false)
WebMock WebMock
.stub_request(:post, "https://api.openai.com/v1/chat/completions") .stub_request(:post, "https://api.openai.com/v1/chat/completions")
.with(body: request_body) .with(body: request_body(prompt, tool_call: tool_call))
.to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call))) .to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
end end
@ -112,114 +67,144 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
WebMock WebMock
.stub_request(:post, "https://api.openai.com/v1/chat/completions") .stub_request(:post, "https://api.openai.com/v1/chat/completions")
.with(body: stream_request_body) .with(body: request_body(prompt, stream: true, tool_call: tool_call))
.to_return(status: 200, body: chunks) .to_return(status: 200, body: chunks)
end end
it_behaves_like "an endpoint that can communicate with a completion service" def tool_deltas
[
{ id: tool_id, function: {} },
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
{ id: tool_id, function: { name: "get_weather", arguments: "{" } },
{ id: tool_id, function: { name: "get_weather", arguments: " \"location\": \"Sydney\"" } },
{ id: tool_id, function: { name: "get_weather", arguments: " ,\"unit\": \"c\" }" } },
]
end
context "when chunked encoding returns partial chunks" do def tool_response
# See: https://github.com/bblimke/webmock/issues/629 {
let(:mock_net_http) do id: tool_id,
Class.new(Net::HTTP) do function: {
def request(*) name: "get_weather",
super do |response| arguments: { location: "Sydney", unit: "c" }.to_json,
response.instance_eval do },
def read_body(*, &) }
@body.each(&) end
end
end
yield response if block_given? def tool_id
"eujbuebfe"
end
response def tool_payload
end {
type: "function",
function: {
name: "get_weather",
description: "Get the weather in a city",
parameters: {
type: "object",
properties: {
location: {
type: "string",
description: "the city name",
},
unit: {
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
},
},
required: %w[location unit],
},
},
}
end
def request_body(prompt, stream: false, tool_call: false)
model
.default_options
.merge(messages: prompt)
.tap do |b|
b[:stream] = true if stream
b[:tools] = [tool_payload] if tool_call
end
.to_json
end
end
RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
subject(:endpoint) do
described_class.new("gpt-3.5-turbo", DiscourseAi::Tokenizer::OpenAiTokenizer)
end
fab!(:user) { Fabricate(:user) }
let(:open_ai_mock) { OpenAiMock.new(endpoint) }
let(:compliance) do
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::ChatGpt, user)
end
describe "#perform_completion!" do
context "when using regular mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.regular_mode_simple_prompt(open_ai_mock)
end
end
context "with tools" do
it "returns a function invocation" do
compliance.regular_mode_tools(open_ai_mock)
end end
end end
end end
let(:remove_original_net_http) { Net.send(:remove_const, :HTTP) } describe "when using streaming mode" do
let(:original_http) { remove_original_net_http } context "with simple prompts" do
let(:stub_net_http) { Net.send(:const_set, :HTTP, mock_net_http) } it "completes a trivial prompt and logs the response" do
compliance.streaming_mode_simple_prompt(open_ai_mock)
end
let(:remove_stubbed_net_http) { Net.send(:remove_const, :HTTP) } it "will automatically recover from a bad payload" do
let(:restore_net_http) { Net.send(:const_set, :HTTP, original_http) } # this should not happen, but lets ensure nothing bad happens
# the row with test1 is invalid json
raw_data = <<~TEXT.strip
d|a|t|a|:| |{|"choices":[{"delta":{"content":"test,"}}]}
data: {"choices":[{"delta":{"content":"test1,"}}]
data: {"choices":[{"delta":|{"content":"test2,"}}]}
data: {"choices":[{"delta":{"content":"test3,"}}]|}
data: {"choices":[{|"|d|elta":{"content":"test4"}}]|}
data: [D|ONE]
TEXT
before do chunks = raw_data.split("|")
mock_net_http
remove_original_net_http
stub_net_http
end
after do open_ai_mock.with_chunk_array_support do
remove_stubbed_net_http open_ai_mock.stub_streamed_response(compliance.dialect.translate, chunks) do
restore_net_http partials = []
end
it "will automatically recover from a bad payload" do endpoint.perform_completion!(compliance.dialect, user) do |partial|
# this should not happen, but lets ensure nothing bad happens partials << partial
# the row with test1 is invalid json end
raw_data = <<~TEXT
d|a|t|a|:| |{|"choices":[{"delta":{"content":"test,"}}]}
data: {"choices":[{"delta":{"content":"test1,"}}] expect(partials.join).to eq("test,test1,test2,test3,test4")
end
end
end
end
data: {"choices":[{"delta":|{"content":"test2,"}}]} context "with tools" do
it "returns a function invoncation" do
data: {"choices":[{"delta":{"content":"test3,"}}]|} compliance.streaming_mode_tools(open_ai_mock)
end
data: {"choices":[{|"|d|elta":{"content":"test4"}}]|} end
data: [D|ONE]
TEXT
chunks = raw_data.split("|")
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
status: 200,
body: chunks,
)
partials = []
llm = DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo")
llm.generate(
DiscourseAi::Completions::Prompt.new("test"),
user: Discourse.system_user,
) { |partial| partials << partial }
expect(partials.join).to eq("test,test2,test3,test4")
end
it "supports chunked encoding properly" do
raw_data = <<~TEXT
da|ta: {"choices":[{"delta":{"content":"test,"}}]}
data: {"choices":[{"delta":{"content":"test1,"}}]}
data: {"choices":[{"delta":|{"content":"test2,"}}]}
data: {"choices":[{"delta":{"content":"test3,"}}]|}
data: {"choices":[{|"|d|elta":{"content":"test4"}}]|}
data: [D|ONE]
TEXT
chunks = raw_data.split("|")
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
status: 200,
body: chunks,
)
partials = []
llm = DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo")
llm.generate(
DiscourseAi::Completions::Prompt.new("test"),
user: Discourse.system_user,
) { |partial| partials << partial }
expect(partials.join).to eq("test,test1,test2,test3,test4")
end end
end end
end end

View File

@ -1,27 +1,14 @@
# frozen_string_literal: true # frozen_string_literal: true
require_relative "endpoint_examples" require_relative "endpoint_compliance"
RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::MixtralTokenizer) }
let(:model_name) { "mistralai/Mixtral-8x7B-Instruct-v0.1" }
let(:dialect) { DiscourseAi::Completions::Dialects::Mixtral.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
before { SiteSetting.ai_vllm_endpoint = "https://test.dev" }
let(:tool_id) { "get_weather" }
class VllmMock < EndpointMock
def response(content) def response(content)
{ {
id: "cmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S", id: "cmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
object: "text_completion", object: "text_completion",
created: 1_678_464_820, created: 1_678_464_820,
model: model_name, model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
usage: { usage: {
prompt_tokens: 337, prompt_tokens: 337,
completion_tokens: 162, completion_tokens: 162,
@ -34,7 +21,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
def stub_response(prompt, response_text, tool_call: false) def stub_response(prompt, response_text, tool_call: false)
WebMock WebMock
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions") .stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions")
.with(body: request_body) .with(body: model.default_options.merge(prompt: prompt).to_json)
.to_return(status: 200, body: JSON.dump(response(response_text))) .to_return(status: 200, body: JSON.dump(response(response_text)))
end end
@ -42,7 +29,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
+"data: " << { +"data: " << {
id: "cmpl-#{SecureRandom.hex}", id: "cmpl-#{SecureRandom.hex}",
created: 1_681_283_881, created: 1_681_283_881,
model: model_name, model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
choices: [{ text: delta, finish_reason: finish_reason, index: 0 }], choices: [{ text: delta, finish_reason: finish_reason, index: 0 }],
index: 0, index: 0,
}.to_json }.to_json
@ -62,33 +49,62 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
WebMock WebMock
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions") .stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions")
.with(body: stream_request_body) .with(body: model.default_options.merge(prompt: prompt, stream: true).to_json)
.to_return(status: 200, body: chunks) .to_return(status: 200, body: chunks)
end end
end
let(:tool_deltas) { ["<function", <<~REPLY, <<~REPLY] }
_calls> RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
<invoke> subject(:endpoint) do
<tool_name>get_weather</tool_name> described_class.new(
<parameters> "mistralai/Mixtral-8x7B-Instruct-v0.1",
<location>Sydney</location> DiscourseAi::Tokenizer::MixtralTokenizer,
<unit>c</unit> )
</parameters> end
</invoke>
</function_calls> fab!(:user) { Fabricate(:user) }
REPLY
<function_calls> let(:anthropic_mock) { VllmMock.new(endpoint) }
<invoke>
<tool_name>get_weather</tool_name> let(:compliance) do
<parameters> EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Mixtral, user)
<location>Sydney</location> end
<unit>c</unit>
</parameters> let(:dialect) { DiscourseAi::Completions::Dialects::Mixtral.new(generic_prompt, model_name) }
</invoke> let(:prompt) { dialect.translate }
</function_calls>
REPLY let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
let(:tool_call) { invocation }
before { SiteSetting.ai_vllm_endpoint = "https://test.dev" }
it_behaves_like "an endpoint that can communicate with a completion service"
describe "#perform_completion!" do
context "when using regular mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.regular_mode_simple_prompt(anthropic_mock)
end
end
context "with tools" do
it "returns a function invocation" do
compliance.regular_mode_tools(anthropic_mock)
end
end
end
describe "when using streaming mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.streaming_mode_simple_prompt(anthropic_mock)
end
end
context "with tools" do
it "returns a function invoncation" do
compliance.streaming_mode_tools(anthropic_mock)
end
end
end
end
end end