DEV: Stop using shared_examples for endpoint specs (#430)

This commit is contained in:
Roman Rizzi 2024-01-17 15:08:49 -03:00 committed by GitHub
parent 8eb1e851fc
commit 5bdf3dc1f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 673 additions and 574 deletions

View File

@ -205,12 +205,10 @@ module DiscourseAi
tokenizer.size(extract_prompt_for_tokenizer(prompt))
end
attr_reader :tokenizer
attr_reader :tokenizer, :model
protected
attr_reader :model
# should normalize temperature, max_tokens, stop_words to endpoint specific values
def normalize_model_params(model_params)
raise NotImplementedError

View File

@ -1,19 +1,8 @@
# frozen_String_literal: true
require_relative "endpoint_examples"
RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::AnthropicTokenizer) }
let(:model_name) { "claude-2" }
let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
let(:tool_id) { "get_weather" }
require_relative "endpoint_compliance"
class AnthropicMock < EndpointMock
def response(content)
{
completion: content,
@ -21,7 +10,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
stop_reason: "stop_sequence",
truncated: false,
log_id: "12dcc7feafbee4a394e0de9dffde3ac5",
model: model_name,
model: "claude-2",
exception: nil,
}
end
@ -29,7 +18,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(:post, "https://api.anthropic.com/v1/complete")
.with(body: request_body)
.with(body: model.default_options.merge(prompt: prompt).to_json)
.to_return(status: 200, body: JSON.dump(response(response_text)))
end
@ -59,23 +48,49 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
WebMock
.stub_request(:post, "https://api.anthropic.com/v1/complete")
.with(body: stream_request_body)
.with(body: model.default_options.merge(prompt: prompt, stream: true).to_json)
.to_return(status: 200, body: chunks)
end
let(:tool_deltas) { ["Let me use a tool for that<function", <<~REPLY] }
_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
REPLY
let(:tool_call) { invocation }
it_behaves_like "an endpoint that can communicate with a completion service"
end
RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
subject(:endpoint) { described_class.new("claude-2", DiscourseAi::Tokenizer::AnthropicTokenizer) }
fab!(:user) { Fabricate(:user) }
let(:anthropic_mock) { AnthropicMock.new(endpoint) }
let(:compliance) do
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Claude, user)
end
describe "#perform_completion!" do
context "when using regular mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.regular_mode_simple_prompt(anthropic_mock)
end
end
context "with tools" do
it "returns a function invocation" do
compliance.regular_mode_tools(anthropic_mock)
end
end
end
describe "when using streaming mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.streaming_mode_simple_prompt(anthropic_mock)
end
end
context "with tools" do
it "returns a function invoncation" do
compliance.streaming_mode_tools(anthropic_mock)
end
end
end
end
end

View File

@ -1,28 +1,10 @@
# frozen_string_literal: true
require_relative "endpoint_examples"
require_relative "endpoint_compliance"
require "aws-eventstream"
require "aws-sigv4"
RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::AnthropicTokenizer) }
let(:model_name) { "claude-2" }
let(:bedrock_name) { "claude-v2:1" }
let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
let(:stream_request_body) { request_body }
let(:tool_id) { "get_weather" }
before do
SiteSetting.ai_bedrock_access_key_id = "123456"
SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd"
SiteSetting.ai_bedrock_region = "us-east-1"
end
class BedrockMock < EndpointMock
def response(content)
{
completion: content,
@ -30,19 +12,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
stop_reason: "stop_sequence",
truncated: false,
log_id: "12dcc7feafbee4a394e0de9dffde3ac5",
model: model_name,
model: "claude",
exception: nil,
}
end
def stub_response(prompt, response_text, tool_call: false)
def stub_response(prompt, response_content, tool_call: false)
WebMock
.stub_request(
:post,
"https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.#{bedrock_name}/invoke",
)
.with(body: request_body)
.to_return(status: 200, body: JSON.dump(response(response_text)))
.stub_request(:post, "#{base_url}/invoke")
.with(body: model.default_options.merge(prompt: prompt).to_json)
.to_return(status: 200, body: JSON.dump(response(response_content)))
end
def stream_line(delta, finish_reason: nil)
@ -83,27 +62,60 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
end
WebMock
.stub_request(
:post,
"https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.#{bedrock_name}/invoke-with-response-stream",
)
.with(body: stream_request_body)
.stub_request(:post, "#{base_url}/invoke-with-response-stream")
.with(body: model.default_options.merge(prompt: prompt).to_json)
.to_return(status: 200, body: chunks)
end
let(:tool_deltas) { ["<function", <<~REPLY] }
_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
REPLY
let(:tool_call) { invocation }
it_behaves_like "an endpoint that can communicate with a completion service"
def base_url
"https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.claude-v2:1"
end
end
RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
subject(:endpoint) { described_class.new("claude-2", DiscourseAi::Tokenizer::AnthropicTokenizer) }
fab!(:user) { Fabricate(:user) }
let(:bedrock_mock) { BedrockMock.new(endpoint) }
let(:compliance) do
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Claude, user)
end
before do
SiteSetting.ai_bedrock_access_key_id = "123456"
SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd"
SiteSetting.ai_bedrock_region = "us-east-1"
end
describe "#perform_completion!" do
context "when using regular mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.regular_mode_simple_prompt(bedrock_mock)
end
end
context "with tools" do
it "returns a function invocation" do
compliance.regular_mode_tools(bedrock_mock)
end
end
end
describe "when using streaming mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.streaming_mode_simple_prompt(bedrock_mock)
end
end
context "with tools" do
it "returns a function invoncation" do
compliance.streaming_mode_tools(bedrock_mock)
end
end
end
end
end

View File

@ -0,0 +1,229 @@
# frozen_string_literal: true
require "net/http"
class EndpointMock
def initialize(model)
@model = model
end
attr_reader :model
def stub_simple_call(prompt)
stub_response(prompt, simple_response)
end
def stub_tool_call(prompt)
stub_response(prompt, tool_response, tool_call: true)
end
def stub_streamed_simple_call(prompt)
with_chunk_array_support do
stub_streamed_response(prompt, streamed_simple_deltas)
yield
end
end
def stub_streamed_tool_call(prompt)
with_chunk_array_support do
stub_streamed_response(prompt, tool_deltas, tool_call: true)
yield
end
end
def simple_response
"1. Serenity\\n2. Laughter\\n3. Adventure"
end
def streamed_simple_deltas
["Mount", "ain", " ", "Tree ", "Frog"]
end
def tool_deltas
["Let me use a tool for that<function", <<~REPLY.strip, <<~REPLY.strip, <<~REPLY.strip]
_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</para
REPLY
meters>
</invoke>
</funct
REPLY
ion_calls>
REPLY
end
def tool_response
tool_deltas.join
end
def invocation_response
<<~TEXT
<function_calls>
<invoke>
<tool_name>get_weather</tool_name>
<tool_id>#{tool_id}</tool_id>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
TEXT
end
def tool_id
"get_weather"
end
def tool
{
name: "get_weather",
description: "Get the weather in a city",
parameters: [
{ name: "location", type: "string", description: "the city name", required: true },
{
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
required: true,
},
],
}
end
def with_chunk_array_support
mock = mocked_http
@original_net_http = ::Net.send(:remove_const, :HTTP)
::Net.send(:const_set, :HTTP, mock)
yield
ensure
::Net.send(:remove_const, :HTTP)
::Net.send(:const_set, :HTTP, @original_net_http)
end
protected
# Copied from https://github.com/bblimke/webmock/issues/629
# Workaround for stubbing a streamed response
def mocked_http
Class.new(::Net::HTTP) do
def request(*)
super do |response|
response.instance_eval do
def read_body(*, &block)
if block_given?
@body.each(&block)
else
super
end
end
end
yield response if block_given?
response
end
end
end
end
end
class EndpointsCompliance
def initialize(rspec, endpoint, dialect_klass, user)
@rspec = rspec
@endpoint = endpoint
@dialect_klass = dialect_klass
@user = user
end
delegate :expect, :eq, :be_present, to: :rspec
def generic_prompt(tools: [])
DiscourseAi::Completions::Prompt.new(
"You write words",
messages: [{ type: :user, content: "write 3 words" }],
tools: tools,
)
end
def dialect(prompt: generic_prompt)
dialect_klass.new(prompt, endpoint.model)
end
def regular_mode_simple_prompt(mock)
mock.stub_simple_call(dialect.translate)
completion_response = endpoint.perform_completion!(dialect, user)
expect(completion_response).to eq(mock.simple_response)
expect(AiApiAuditLog.count).to eq(1)
log = AiApiAuditLog.first
expect(log.provider_id).to eq(endpoint.provider_id)
expect(log.user_id).to eq(user.id)
expect(log.raw_request_payload).to be_present
expect(log.raw_response_payload).to eq(mock.response(completion_response).to_json)
expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
expect(log.response_tokens).to eq(endpoint.tokenizer.size(completion_response))
end
def regular_mode_tools(mock)
prompt = generic_prompt(tools: [mock.tool])
a_dialect = dialect(prompt: prompt)
mock.stub_tool_call(a_dialect.translate)
completion_response = endpoint.perform_completion!(a_dialect, user)
expect(completion_response).to eq(mock.invocation_response)
end
def streaming_mode_simple_prompt(mock)
mock.stub_streamed_simple_call(dialect.translate) do
completion_response = +""
endpoint.perform_completion!(dialect, user) do |partial, cancel|
completion_response << partial
cancel.call if completion_response.split(" ").length == 2
end
expect(AiApiAuditLog.count).to eq(1)
log = AiApiAuditLog.first
expect(log.provider_id).to eq(endpoint.provider_id)
expect(log.user_id).to eq(user.id)
expect(log.raw_request_payload).to be_present
expect(log.raw_response_payload).to be_present
expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
expect(log.response_tokens).to eq(
endpoint.tokenizer.size(mock.streamed_simple_deltas[0...-1].join),
)
end
end
def streaming_mode_tools(mock)
prompt = generic_prompt(tools: [mock.tool])
a_dialect = dialect(prompt: prompt)
mock.stub_streamed_tool_call(a_dialect.translate) do
buffered_partial = +""
endpoint.perform_completion!(a_dialect, user) do |partial, cancel|
buffered_partial << partial
cancel.call if buffered_partial.include?("<function_calls>")
end
expect(buffered_partial).to eq(mock.invocation_response)
end
end
attr_reader :rspec, :endpoint, :dialect_klass, :user
end

View File

@ -1,175 +0,0 @@
# frozen_string_literal: true
RSpec.shared_examples "an endpoint that can communicate with a completion service" do
# Copied from https://github.com/bblimke/webmock/issues/629
# Workaround for stubbing a streamed response
before do
mocked_http =
Class.new(Net::HTTP) do
def request(*)
super do |response|
response.instance_eval do
def read_body(*, &block)
if block_given?
@body.each(&block)
else
super
end
end
end
yield response if block_given?
response
end
end
end
@original_net_http = Net.send(:remove_const, :HTTP)
Net.send(:const_set, :HTTP, mocked_http)
end
after do
Net.send(:remove_const, :HTTP)
Net.send(:const_set, :HTTP, @original_net_http)
end
let(:generic_prompt) do
DiscourseAi::Completions::Prompt.new(
"You write words",
messages: [{ type: :user, content: "write 3 words" }],
)
end
describe "#perform_completion!" do
fab!(:user) { Fabricate(:user) }
let(:tool) do
{
name: "get_weather",
description: "Get the weather in a city",
parameters: [
{ name: "location", type: "string", description: "the city name", required: true },
{
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
required: true,
},
],
}
end
let(:invocation) { <<~TEXT }
<function_calls>
<invoke>
<tool_name>get_weather</tool_name>
<tool_id>#{tool_id || "get_weather"}</tool_id>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
TEXT
context "when using regular mode" do
context "with simple prompts" do
let(:response_text) { "1. Serenity\\n2. Laughter\\n3. Adventure" }
before { stub_response(prompt, response_text) }
it "can complete a trivial prompt" do
completion_response = model.perform_completion!(dialect, user)
expect(completion_response).to eq(response_text)
end
it "creates an audit log for the request" do
model.perform_completion!(dialect, user)
expect(AiApiAuditLog.count).to eq(1)
log = AiApiAuditLog.first
response_body = response(response_text).to_json
expect(log.provider_id).to eq(model.provider_id)
expect(log.user_id).to eq(user.id)
expect(log.raw_request_payload).to eq(request_body)
expect(log.raw_response_payload).to eq(response_body)
expect(log.request_tokens).to eq(model.prompt_size(prompt))
expect(log.response_tokens).to eq(model.tokenizer.size(response_text))
end
end
context "with functions" do
before do
generic_prompt.tools = [tool]
stub_response(prompt, tool_call, tool_call: true)
end
it "returns a function invocation" do
completion_response = model.perform_completion!(dialect, user)
expect(completion_response).to eq(invocation)
end
end
end
context "when using stream mode" do
context "with simple prompts" do
let(:deltas) { ["Mount", "ain", " ", "Tree ", "Frog"] }
before { stub_streamed_response(prompt, deltas) }
it "can complete a trivial prompt" do
completion_response = +""
model.perform_completion!(dialect, user) do |partial, cancel|
completion_response << partial
cancel.call if completion_response.split(" ").length == 2
end
expect(completion_response).to eq(deltas[0...-1].join)
end
it "creates an audit log and updates is on each read." do
completion_response = +""
model.perform_completion!(dialect, user) do |partial, cancel|
completion_response << partial
cancel.call if completion_response.split(" ").length == 2
end
expect(AiApiAuditLog.count).to eq(1)
log = AiApiAuditLog.first
expect(log.provider_id).to eq(model.provider_id)
expect(log.user_id).to eq(user.id)
expect(log.raw_request_payload).to eq(stream_request_body)
expect(log.raw_response_payload).to be_present
expect(log.request_tokens).to eq(model.prompt_size(prompt))
expect(log.response_tokens).to eq(model.tokenizer.size(deltas[0...-1].join))
end
end
context "with functions" do
before do
generic_prompt.tools = [tool]
stub_streamed_response(prompt, tool_deltas, tool_call: true)
end
it "waits for the invocation to finish before calling the partial" do
buffered_partial = ""
model.perform_completion!(dialect, user) do |partial, cancel|
buffered_partial = partial if partial.include?("<function_calls>")
end
expect(buffered_partial).to eq(invocation)
end
end
end
end
end

View File

@ -1,69 +1,8 @@
# frozen_string_literal: true
require_relative "endpoint_examples"
RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
let(:model_name) { "gemini-pro" }
let(:dialect) { DiscourseAi::Completions::Dialects::Gemini.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }
let(:tool_id) { "get_weather" }
let(:tool_payload) do
{
name: "get_weather",
description: "Get the weather in a city",
parameters: {
type: "object",
required: %w[location unit],
properties: {
"location" => {
type: "string",
description: "the city name",
},
"unit" => {
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
},
},
},
}
end
let(:request_body) do
model
.default_options
.merge(contents: prompt)
.tap do |b|
b[:tools] = [{ function_declarations: [tool_payload] }] if generic_prompt.tools.present?
end
.to_json
end
let(:stream_request_body) do
model
.default_options
.merge(contents: prompt)
.tap do |b|
b[:tools] = [{ function_declarations: [tool_payload] }] if generic_prompt.tools.present?
end
.to_json
end
let(:tool_deltas) do
[
{ "functionCall" => { name: "get_weather", args: {} } },
{ "functionCall" => { name: "get_weather", args: { location: "" } } },
{ "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } },
]
end
let(:tool_call) do
{ "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } }
end
require_relative "endpoint_compliance"
class GeminiMock < EndpointMock
def response(content, tool_call: false)
{
candidates: [
@ -97,9 +36,9 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
WebMock
.stub_request(
:post,
"https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:generateContent?key=#{SiteSetting.ai_gemini_api_key}",
"https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key=#{SiteSetting.ai_gemini_api_key}",
)
.with(body: request_body)
.with(body: request_body(prompt, tool_call))
.to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
end
@ -139,11 +78,93 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
WebMock
.stub_request(
:post,
"https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:streamGenerateContent?key=#{SiteSetting.ai_gemini_api_key}",
"https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent?key=#{SiteSetting.ai_gemini_api_key}",
)
.with(body: stream_request_body)
.with(body: request_body(prompt, tool_call))
.to_return(status: 200, body: chunks)
end
it_behaves_like "an endpoint that can communicate with a completion service"
def tool_payload
{
name: "get_weather",
description: "Get the weather in a city",
parameters: {
type: "object",
required: %w[location unit],
properties: {
"location" => {
type: "string",
description: "the city name",
},
"unit" => {
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
},
},
},
}
end
def request_body(prompt, tool_call)
model
.default_options
.merge(contents: prompt)
.tap { |b| b[:tools] = [{ function_declarations: [tool_payload] }] if tool_call }
.to_json
end
def tool_deltas
[
{ "functionCall" => { name: "get_weather", args: {} } },
{ "functionCall" => { name: "get_weather", args: { location: "" } } },
{ "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } },
]
end
def tool_response
{ "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } }
end
end
RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
subject(:endpoint) { described_class.new("gemini-pro", DiscourseAi::Tokenizer::OpenAiTokenizer) }
fab!(:user) { Fabricate(:user) }
let(:bedrock_mock) { GeminiMock.new(endpoint) }
let(:compliance) do
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Gemini, user)
end
describe "#perform_completion!" do
context "when using regular mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.regular_mode_simple_prompt(bedrock_mock)
end
end
context "with tools" do
it "returns a function invocation" do
compliance.regular_mode_tools(bedrock_mock)
end
end
end
describe "when using streaming mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.streaming_mode_simple_prompt(bedrock_mock)
end
end
context "with tools" do
it "returns a function invoncation" do
compliance.streaming_mode_tools(bedrock_mock)
end
end
end
end
end

View File

@ -1,42 +1,8 @@
# frozen_string_literal: true
require_relative "endpoint_examples"
RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::Llama2Tokenizer) }
let(:model_name) { "Llama2-*-chat-hf" }
let(:dialect) do
DiscourseAi::Completions::Dialects::Llama2Classic.new(generic_prompt, model_name)
end
let(:prompt) { dialect.translate }
let(:tool_id) { "get_weather" }
let(:request_body) do
model
.default_options
.merge(inputs: prompt)
.tap do |payload|
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
model.prompt_size(prompt)
end
.to_json
end
let(:stream_request_body) do
model
.default_options
.merge(inputs: prompt)
.tap do |payload|
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
model.prompt_size(prompt)
payload[:stream] = true
end
.to_json
end
before { SiteSetting.ai_hugging_face_api_url = "https://test.dev" }
require_relative "endpoint_compliance"
class HuggingFaceMock < EndpointMock
def response(content)
[{ generated_text: content }]
end
@ -44,7 +10,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
.with(body: request_body)
.with(body: request_body(prompt))
.to_return(status: 200, body: JSON.dump(response(response_text)))
end
@ -75,33 +41,65 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
WebMock
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
.with(body: stream_request_body)
.with(body: request_body(prompt, stream: true))
.to_return(status: 200, body: chunks)
end
let(:tool_deltas) { ["<function", <<~REPLY, <<~REPLY] }
_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
REPLY
<function_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
REPLY
let(:tool_call) { invocation }
it_behaves_like "an endpoint that can communicate with a completion service"
def request_body(prompt, stream: false)
model
.default_options
.merge(inputs: prompt)
.tap do |payload|
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
model.prompt_size(prompt)
payload[:stream] = true if stream
end
.to_json
end
end
RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
subject(:endpoint) do
described_class.new("Llama2-*-chat-hf", DiscourseAi::Tokenizer::Llama2Tokenizer)
end
before { SiteSetting.ai_hugging_face_api_url = "https://test.dev" }
fab!(:user) { Fabricate(:user) }
let(:hf_mock) { HuggingFaceMock.new(endpoint) }
let(:compliance) do
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Llama2Classic, user)
end
describe "#perform_completion!" do
context "when using regular mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.regular_mode_simple_prompt(hf_mock)
end
end
context "with tools" do
it "returns a function invocation" do
compliance.regular_mode_tools(hf_mock)
end
end
end
describe "when using streaming mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.streaming_mode_simple_prompt(hf_mock)
end
end
context "with tools" do
it "returns a function invoncation" do
compliance.streaming_mode_tools(hf_mock)
end
end
end
end
end

View File

@ -1,53 +1,8 @@
# frozen_string_literal: true
require_relative "endpoint_examples"
RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
let(:model_name) { "gpt-3.5-turbo" }
let(:dialect) { DiscourseAi::Completions::Dialects::ChatGpt.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }
let(:tool_id) { "eujbuebfe" }
let(:tool_deltas) do
[
{ id: tool_id, function: {} },
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
{ id: tool_id, function: { name: "get_weather", arguments: "{" } },
{ id: tool_id, function: { name: "get_weather", arguments: " \"location\": \"Sydney\"" } },
{ id: tool_id, function: { name: "get_weather", arguments: " ,\"unit\": \"c\" }" } },
]
end
let(:tool_call) do
{
id: tool_id,
function: {
name: "get_weather",
arguments: { location: "Sydney", unit: "c" }.to_json,
},
}
end
let(:request_body) do
model
.default_options
.merge(messages: prompt)
.tap { |b| b[:tools] = dialect.tools if generic_prompt.tools.present? }
.to_json
end
let(:stream_request_body) do
model
.default_options
.merge(messages: prompt, stream: true)
.tap { |b| b[:tools] = dialect.tools if generic_prompt.tools.present? }
.to_json
end
require_relative "endpoint_compliance"
class OpenAiMock < EndpointMock
def response(content, tool_call: false)
message_content =
if tool_call
@ -75,7 +30,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(:post, "https://api.openai.com/v1/chat/completions")
.with(body: request_body)
.with(body: request_body(prompt, tool_call: tool_call))
.to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
end
@ -112,114 +67,144 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
WebMock
.stub_request(:post, "https://api.openai.com/v1/chat/completions")
.with(body: stream_request_body)
.with(body: request_body(prompt, stream: true, tool_call: tool_call))
.to_return(status: 200, body: chunks)
end
it_behaves_like "an endpoint that can communicate with a completion service"
def tool_deltas
[
{ id: tool_id, function: {} },
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
{ id: tool_id, function: { name: "get_weather", arguments: "{" } },
{ id: tool_id, function: { name: "get_weather", arguments: " \"location\": \"Sydney\"" } },
{ id: tool_id, function: { name: "get_weather", arguments: " ,\"unit\": \"c\" }" } },
]
end
context "when chunked encoding returns partial chunks" do
# See: https://github.com/bblimke/webmock/issues/629
let(:mock_net_http) do
Class.new(Net::HTTP) do
def request(*)
super do |response|
response.instance_eval do
def read_body(*, &)
@body.each(&)
end
end
def tool_response
{
id: tool_id,
function: {
name: "get_weather",
arguments: { location: "Sydney", unit: "c" }.to_json,
},
}
end
yield response if block_given?
def tool_id
"eujbuebfe"
end
response
end
def tool_payload
{
type: "function",
function: {
name: "get_weather",
description: "Get the weather in a city",
parameters: {
type: "object",
properties: {
location: {
type: "string",
description: "the city name",
},
unit: {
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
},
},
required: %w[location unit],
},
},
}
end
def request_body(prompt, stream: false, tool_call: false)
model
.default_options
.merge(messages: prompt)
.tap do |b|
b[:stream] = true if stream
b[:tools] = [tool_payload] if tool_call
end
.to_json
end
end
RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
subject(:endpoint) do
described_class.new("gpt-3.5-turbo", DiscourseAi::Tokenizer::OpenAiTokenizer)
end
fab!(:user) { Fabricate(:user) }
let(:open_ai_mock) { OpenAiMock.new(endpoint) }
let(:compliance) do
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::ChatGpt, user)
end
describe "#perform_completion!" do
context "when using regular mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.regular_mode_simple_prompt(open_ai_mock)
end
end
context "with tools" do
it "returns a function invocation" do
compliance.regular_mode_tools(open_ai_mock)
end
end
end
let(:remove_original_net_http) { Net.send(:remove_const, :HTTP) }
let(:original_http) { remove_original_net_http }
let(:stub_net_http) { Net.send(:const_set, :HTTP, mock_net_http) }
describe "when using streaming mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.streaming_mode_simple_prompt(open_ai_mock)
end
let(:remove_stubbed_net_http) { Net.send(:remove_const, :HTTP) }
let(:restore_net_http) { Net.send(:const_set, :HTTP, original_http) }
it "will automatically recover from a bad payload" do
# this should not happen, but lets ensure nothing bad happens
# the row with test1 is invalid json
raw_data = <<~TEXT.strip
d|a|t|a|:| |{|"choices":[{"delta":{"content":"test,"}}]}
before do
mock_net_http
remove_original_net_http
stub_net_http
end
data: {"choices":[{"delta":{"content":"test1,"}}]
after do
remove_stubbed_net_http
restore_net_http
end
data: {"choices":[{"delta":|{"content":"test2,"}}]}
it "will automatically recover from a bad payload" do
# this should not happen, but lets ensure nothing bad happens
# the row with test1 is invalid json
raw_data = <<~TEXT
d|a|t|a|:| |{|"choices":[{"delta":{"content":"test,"}}]}
data: {"choices":[{"delta":{"content":"test3,"}}]|}
data: {"choices":[{"delta":{"content":"test1,"}}]
data: {"choices":[{|"|d|elta":{"content":"test4"}}]|}
data: {"choices":[{"delta":|{"content":"test2,"}}]}
data: [D|ONE]
TEXT
data: {"choices":[{"delta":{"content":"test3,"}}]|}
chunks = raw_data.split("|")
data: {"choices":[{|"|d|elta":{"content":"test4"}}]|}
open_ai_mock.with_chunk_array_support do
open_ai_mock.stub_streamed_response(compliance.dialect.translate, chunks) do
partials = []
data: [D|ONE]
TEXT
endpoint.perform_completion!(compliance.dialect, user) do |partial|
partials << partial
end
chunks = raw_data.split("|")
expect(partials.join).to eq("test,test1,test2,test3,test4")
end
end
end
end
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
status: 200,
body: chunks,
)
partials = []
llm = DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo")
llm.generate(
DiscourseAi::Completions::Prompt.new("test"),
user: Discourse.system_user,
) { |partial| partials << partial }
expect(partials.join).to eq("test,test2,test3,test4")
end
it "supports chunked encoding properly" do
raw_data = <<~TEXT
da|ta: {"choices":[{"delta":{"content":"test,"}}]}
data: {"choices":[{"delta":{"content":"test1,"}}]}
data: {"choices":[{"delta":|{"content":"test2,"}}]}
data: {"choices":[{"delta":{"content":"test3,"}}]|}
data: {"choices":[{|"|d|elta":{"content":"test4"}}]|}
data: [D|ONE]
TEXT
chunks = raw_data.split("|")
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
status: 200,
body: chunks,
)
partials = []
llm = DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo")
llm.generate(
DiscourseAi::Completions::Prompt.new("test"),
user: Discourse.system_user,
) { |partial| partials << partial }
expect(partials.join).to eq("test,test1,test2,test3,test4")
context "with tools" do
it "returns a function invoncation" do
compliance.streaming_mode_tools(open_ai_mock)
end
end
end
end
end

View File

@ -1,27 +1,14 @@
# frozen_string_literal: true
require_relative "endpoint_examples"
RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::MixtralTokenizer) }
let(:model_name) { "mistralai/Mixtral-8x7B-Instruct-v0.1" }
let(:dialect) { DiscourseAi::Completions::Dialects::Mixtral.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
before { SiteSetting.ai_vllm_endpoint = "https://test.dev" }
let(:tool_id) { "get_weather" }
require_relative "endpoint_compliance"
class VllmMock < EndpointMock
def response(content)
{
id: "cmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
object: "text_completion",
created: 1_678_464_820,
model: model_name,
model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
usage: {
prompt_tokens: 337,
completion_tokens: 162,
@ -34,7 +21,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions")
.with(body: request_body)
.with(body: model.default_options.merge(prompt: prompt).to_json)
.to_return(status: 200, body: JSON.dump(response(response_text)))
end
@ -42,7 +29,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
+"data: " << {
id: "cmpl-#{SecureRandom.hex}",
created: 1_681_283_881,
model: model_name,
model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
choices: [{ text: delta, finish_reason: finish_reason, index: 0 }],
index: 0,
}.to_json
@ -62,33 +49,62 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
WebMock
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions")
.with(body: stream_request_body)
.with(body: model.default_options.merge(prompt: prompt, stream: true).to_json)
.to_return(status: 200, body: chunks)
end
let(:tool_deltas) { ["<function", <<~REPLY, <<~REPLY] }
_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
REPLY
<function_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
REPLY
let(:tool_call) { invocation }
it_behaves_like "an endpoint that can communicate with a completion service"
end
RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
subject(:endpoint) do
described_class.new(
"mistralai/Mixtral-8x7B-Instruct-v0.1",
DiscourseAi::Tokenizer::MixtralTokenizer,
)
end
fab!(:user) { Fabricate(:user) }
let(:anthropic_mock) { VllmMock.new(endpoint) }
let(:compliance) do
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Mixtral, user)
end
let(:dialect) { DiscourseAi::Completions::Dialects::Mixtral.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
before { SiteSetting.ai_vllm_endpoint = "https://test.dev" }
describe "#perform_completion!" do
context "when using regular mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.regular_mode_simple_prompt(anthropic_mock)
end
end
context "with tools" do
it "returns a function invocation" do
compliance.regular_mode_tools(anthropic_mock)
end
end
end
describe "when using streaming mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.streaming_mode_simple_prompt(anthropic_mock)
end
end
context "with tools" do
it "returns a function invoncation" do
compliance.streaming_mode_tools(anthropic_mock)
end
end
end
end
end