diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index 405fe82c..df1712a7 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -205,12 +205,10 @@ module DiscourseAi
         tokenizer.size(extract_prompt_for_tokenizer(prompt))
       end
 
-      attr_reader :tokenizer
+      attr_reader :tokenizer, :model
 
       protected
 
-      attr_reader :model
-
       # should normalize temperature, max_tokens, stop_words to endpoint specific values
       def normalize_model_params(model_params)
         raise NotImplementedError
diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb
index 61b23e69..915547af 100644
--- a/spec/lib/completions/endpoints/anthropic_spec.rb
+++ b/spec/lib/completions/endpoints/anthropic_spec.rb
@@ -1,19 +1,8 @@
 # frozen_string_literal: true
 
-require_relative "endpoint_examples"
-
-RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
-  subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::AnthropicTokenizer) }
-
-  let(:model_name) { "claude-2" }
-  let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) }
-  let(:prompt) { dialect.translate }
-
-  let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
-  let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
-
-  let(:tool_id) { "get_weather" }
+require_relative "endpoint_compliance"
 
+class AnthropicMock < EndpointMock
   def response(content)
     {
       completion: content,
@@ -21,7 +10,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
       stop_reason: "stop_sequence",
       truncated: false,
       log_id: "12dcc7feafbee4a394e0de9dffde3ac5",
-      model: model_name,
+      model: "claude-2",
       exception: nil,
     }
   end
@@ -29,7 +18,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
   def stub_response(prompt, response_text, tool_call: false)
     WebMock
       .stub_request(:post, "https://api.anthropic.com/v1/complete")
-      .with(body: request_body)
+      .with(body: model.default_options.merge(prompt: prompt).to_json)
       .to_return(status: 200, body: JSON.dump(response(response_text)))
   end
 
@@ -59,23 +48,49 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
 
     WebMock
       .stub_request(:post, "https://api.anthropic.com/v1/complete")
-      .with(body: stream_request_body)
+      .with(body: model.default_options.merge(prompt: prompt, stream: true).to_json)
       .to_return(status: 200, body: chunks)
   end
-
-  let(:tool_deltas) { ["Let me use a tool for that", <<~REPLY] }
-    <function_calls>
-    <invoke>
-    <tool_name>get_weather</tool_name>
-    <parameters>
-    <location>Sydney</location>
-    <unit>c</unit>
-    </parameters>
-    </invoke>
-    </function_calls>
-    REPLY
-
-  let(:tool_call) { invocation }
-
-  it_behaves_like "an endpoint that can communicate with a completion service"
+end
+
+RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
+  subject(:endpoint) { described_class.new("claude-2", DiscourseAi::Tokenizer::AnthropicTokenizer) }
+
+  fab!(:user) { Fabricate(:user) }
+
+  let(:anthropic_mock) { AnthropicMock.new(endpoint) }
+
+  let(:compliance) do
+    EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Claude, user)
+  end
+
+  describe "#perform_completion!" do
+    context "when using regular mode" do
+      context "with simple prompts" do
+        it "completes a trivial prompt and logs the response" do
+          compliance.regular_mode_simple_prompt(anthropic_mock)
+        end
+      end
+
+      context "with tools" do
+        it "returns a function invocation" do
+          compliance.regular_mode_tools(anthropic_mock)
+        end
+      end
+    end
+
+    describe "when using streaming mode" do
+      context "with simple prompts" do
+        it "completes a trivial prompt and logs the response" do
+          compliance.streaming_mode_simple_prompt(anthropic_mock)
+        end
+      end
+
+      context "with tools" do
+        it "returns a function invocation" do
+          compliance.streaming_mode_tools(anthropic_mock)
+        end
+      end
+    end
+  end
 end
diff --git a/spec/lib/completions/endpoints/aws_bedrock_spec.rb b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
index 33c0d272..34271094 100644
--- a/spec/lib/completions/endpoints/aws_bedrock_spec.rb
+++ b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
@@ -1,28 +1,10 @@
 # frozen_string_literal: true
 
-require_relative "endpoint_examples"
+require_relative "endpoint_compliance"
 require "aws-eventstream"
 require "aws-sigv4"
 
-RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
-  subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::AnthropicTokenizer) }
-
-  let(:model_name) { "claude-2" }
-  let(:bedrock_name) { "claude-v2:1" }
-  let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) }
-  let(:prompt) { dialect.translate }
-
-  let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
-  let(:stream_request_body) { request_body }
-
-  let(:tool_id) { "get_weather" }
-
-  before do
-    SiteSetting.ai_bedrock_access_key_id = "123456"
-    SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd"
-    SiteSetting.ai_bedrock_region = "us-east-1"
-  end
-
+class BedrockMock < EndpointMock
   def response(content)
     {
       completion: content,
@@ -30,19 +12,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
       stop_reason: "stop_sequence",
       truncated: false,
       log_id: "12dcc7feafbee4a394e0de9dffde3ac5",
-      model: model_name,
+      model: "claude",
       exception: nil,
     }
   end
 
-  def stub_response(prompt, response_text, tool_call: false)
+  def stub_response(prompt, response_content, tool_call: false)
     WebMock
-      .stub_request(
-        :post,
-        "https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.#{bedrock_name}/invoke",
-      )
-      .with(body: request_body)
-      .to_return(status: 200, body: JSON.dump(response(response_text)))
+      .stub_request(:post, "#{base_url}/invoke")
+      .with(body: model.default_options.merge(prompt: prompt).to_json)
+      .to_return(status: 200, body: JSON.dump(response(response_content)))
   end
 
   def stream_line(delta, finish_reason: nil)
@@ -83,27 +62,60 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
     end
 
     WebMock
-      .stub_request(
-        :post,
-        "https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.#{bedrock_name}/invoke-with-response-stream",
-      )
-      .with(body: stream_request_body)
+      .stub_request(:post, "#{base_url}/invoke-with-response-stream")
+      .with(body: model.default_options.merge(prompt: prompt).to_json)
       .to_return(status: 200, body: chunks)
   end
-
-  let(:tool_deltas) { ["<function_calls>", <<~REPLY] }
-    <invoke>
-    <tool_name>get_weather</tool_name>
-    <parameters>
-    <location>Sydney</location>
-    <unit>c</unit>
-    </parameters>
-    </invoke>
-    </function_calls>
-    REPLY
-
-  let(:tool_call) { invocation }
-
-  it_behaves_like "an endpoint that can communicate with a completion service"
+
+  def base_url
+    "https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.claude-v2:1"
+  end
+end
+
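+# Each endpoint spec wires an EndpointMock subclass (here BedrockMock) into the
+# shared EndpointsCompliance helper from endpoint_compliance.rb, which drives
+# the same regular-mode and streaming-mode scenarios against every provider.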
+RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
+  subject(:endpoint) { described_class.new("claude-2", DiscourseAi::Tokenizer::AnthropicTokenizer) }
+
+  fab!(:user) { Fabricate(:user) }
+
+  let(:bedrock_mock) { BedrockMock.new(endpoint) }
+
+  let(:compliance) do
+    EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Claude, user)
+  end
+
+  before do
+    SiteSetting.ai_bedrock_access_key_id = "123456"
+    SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd"
+    SiteSetting.ai_bedrock_region = "us-east-1"
+  end
+
+  describe "#perform_completion!" do
+    context "when using regular mode" do
+      context "with simple prompts" do
+        it "completes a trivial prompt and logs the response" do
+          compliance.regular_mode_simple_prompt(bedrock_mock)
+        end
+      end
+
+      context "with tools" do
+        it "returns a function invocation" do
+          compliance.regular_mode_tools(bedrock_mock)
+        end
+      end
+    end
+
+    describe "when using streaming mode" do
+      context "with simple prompts" do
+        it "completes a trivial prompt and logs the response" do
+          compliance.streaming_mode_simple_prompt(bedrock_mock)
+        end
+      end
+
+      context "with tools" do
+        it "returns a function invocation" do
+          compliance.streaming_mode_tools(bedrock_mock)
+        end
+      end
+    end
+  end
 end
diff --git a/spec/lib/completions/endpoints/endpoint_compliance.rb b/spec/lib/completions/endpoints/endpoint_compliance.rb
new file mode 100644
index 00000000..0fc3671c
--- /dev/null
+++ b/spec/lib/completions/endpoints/endpoint_compliance.rb
@@ -0,0 +1,229 @@
+# frozen_string_literal: true
+
+require "net/http"
+
+class EndpointMock
+  def initialize(model)
+    @model = model
+  end
+
+  attr_reader :model
+
+  def stub_simple_call(prompt)
+    stub_response(prompt, simple_response)
+  end
+
+  def stub_tool_call(prompt)
+    stub_response(prompt, tool_response, tool_call: true)
+  end
+
+  def stub_streamed_simple_call(prompt)
+    with_chunk_array_support do
+      stub_streamed_response(prompt, streamed_simple_deltas)
+      yield
+    end
+  end
+
+  def stub_streamed_tool_call(prompt)
+    with_chunk_array_support do
+      stub_streamed_response(prompt, tool_deltas, tool_call: true)
+      yield
+    end
+  end
+
+  def simple_response
+    "1. Serenity\\n2. Laughter\\n3. Adventure"
+  end
+
+  def streamed_simple_deltas
+    ["Mount", "ain", " ", "Tree ", "Frog"]
+  end
+
+  def tool_deltas
+    ["Let me use a tool for that", <<~REPLY]
+      <function_calls>
+      <invoke>
+      <tool_name>get_weather</tool_name>
+      <parameters>
+      <location>Sydney</location>
+      <unit>c</unit>
+      </parameters>
+      </invoke>
+      </function_calls>
+    REPLY
+  end
+
+  def tool_response
+    tool_deltas.join
+  end
+
+  def invocation_response
+    <<~TEXT
+      <function_calls>
+      <invoke>
+      <tool_name>get_weather</tool_name>
+      <tool_id>#{tool_id}</tool_id>
+      <parameters>
+      <location>Sydney</location>
+      <unit>c</unit>
+      </parameters>
+      </invoke>
+      </function_calls>
+    TEXT
+  end
+
+  def tool_id
+    "get_weather"
+  end
+
+  def tool
+    {
+      name: "get_weather",
+      description: "Get the weather in a city",
+      parameters: [
+        { name: "location", type: "string", description: "the city name", required: true },
+        {
+          name: "unit",
+          type: "string",
+          description: "the unit of measurement celsius c or fahrenheit f",
+          enum: %w[c f],
+          required: true,
+        },
+      ],
+    }
+  end
+
+  def with_chunk_array_support
+    mock = mocked_http
+    @original_net_http = ::Net.send(:remove_const, :HTTP)
+    ::Net.send(:const_set, :HTTP, mock)
+
+    yield
+  ensure
+    ::Net.send(:remove_const, :HTTP)
+    ::Net.send(:const_set, :HTTP, @original_net_http)
+  end
+
+  protected
+
+  # Copied from https://github.com/bblimke/webmock/issues/629
+  # Workaround for stubbing a streamed response
+  def mocked_http
+    Class.new(::Net::HTTP) do
+      def request(*)
+        super do |response|
+          response.instance_eval do
+            def read_body(*, &block)
+              if block_given?
+                @body.each(&block)
+              else
+                super
+              end
+            end
+          end
+
+          yield response if block_given?
+
+          response
+        end
+      end
+    end
+  end
+end
+
+class EndpointsCompliance
+  def initialize(rspec, endpoint, dialect_klass, user)
+    @rspec = rspec
+    @endpoint = endpoint
+    @dialect_klass = dialect_klass
+    @user = user
+  end
+
+  delegate :expect, :eq, :be_present, to: :rspec
+
+  def generic_prompt(tools: [])
+    DiscourseAi::Completions::Prompt.new(
+      "You write words",
+      messages: [{ type: :user, content: "write 3 words" }],
+      tools: tools,
+    )
+  end
+
+  def dialect(prompt: generic_prompt)
+    dialect_klass.new(prompt, endpoint.model)
+  end
+
+  def regular_mode_simple_prompt(mock)
+    mock.stub_simple_call(dialect.translate)
+
+    completion_response = endpoint.perform_completion!(dialect, user)
+
+    expect(completion_response).to eq(mock.simple_response)
+
+    expect(AiApiAuditLog.count).to eq(1)
+    log = AiApiAuditLog.first
+
+    expect(log.provider_id).to eq(endpoint.provider_id)
+    expect(log.user_id).to eq(user.id)
+    expect(log.raw_request_payload).to be_present
+    expect(log.raw_response_payload).to eq(mock.response(completion_response).to_json)
+    expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
+    expect(log.response_tokens).to eq(endpoint.tokenizer.size(completion_response))
+  end
+
+  def regular_mode_tools(mock)
+    prompt = generic_prompt(tools: [mock.tool])
+    a_dialect = dialect(prompt: prompt)
+
+    mock.stub_tool_call(a_dialect.translate)
+
+    completion_response = endpoint.perform_completion!(a_dialect, user)
+
+    expect(completion_response).to eq(mock.invocation_response)
+  end
+
+  def streaming_mode_simple_prompt(mock)
+    mock.stub_streamed_simple_call(dialect.translate) do
+      completion_response = +""
+
+      endpoint.perform_completion!(dialect, user) do |partial, cancel|
+        completion_response << partial
+        cancel.call if completion_response.split(" ").length == 2
+      end
+
+      expect(AiApiAuditLog.count).to eq(1)
+      log = AiApiAuditLog.first
+
+      expect(log.provider_id).to eq(endpoint.provider_id)
+      expect(log.user_id).to eq(user.id)
+      expect(log.raw_request_payload).to be_present
+      expect(log.raw_response_payload).to be_present
+      expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
+      expect(log.response_tokens).to eq(
+        endpoint.tokenizer.size(mock.streamed_simple_deltas[0...-1].join),
+      )
+    end
+  end
+
+  def streaming_mode_tools(mock)
+    prompt = generic_prompt(tools: [mock.tool])
+    a_dialect = dialect(prompt: prompt)
+
+    mock.stub_streamed_tool_call(a_dialect.translate) do
+      buffered_partial = +""
+
+      endpoint.perform_completion!(a_dialect, user) do |partial, cancel|
+        buffered_partial << partial
+        cancel.call if buffered_partial.include?("</function_calls>")
+      end
+
+      expect(buffered_partial).to eq(mock.invocation_response)
+    end
+  end
+
+  attr_reader :rspec, :endpoint, :dialect_klass, :user
+end
diff --git a/spec/lib/completions/endpoints/endpoint_examples.rb b/spec/lib/completions/endpoints/endpoint_examples.rb
deleted file mode 100644
index 60eeac83..00000000
--- a/spec/lib/completions/endpoints/endpoint_examples.rb
+++ /dev/null
@@ -1,175 +0,0 @@
-# frozen_string_literal: true
-
-RSpec.shared_examples "an endpoint that can communicate with a completion service" do
-  # Copied from https://github.com/bblimke/webmock/issues/629
-  # Workaround for stubbing a streamed response
-  before do
-    mocked_http =
-      Class.new(Net::HTTP) do
-        def request(*)
-          super do |response|
-            response.instance_eval do
-              def read_body(*, &block)
-                if block_given?
-                  @body.each(&block)
-                else
-                  super
-                end
-              end
-            end
-
-            yield response if block_given?
-
-            response
-          end
-        end
-      end
-
-    @original_net_http = Net.send(:remove_const, :HTTP)
-    Net.send(:const_set, :HTTP, mocked_http)
-  end
-
-  after do
-    Net.send(:remove_const, :HTTP)
-    Net.send(:const_set, :HTTP, @original_net_http)
-  end
-
-  let(:generic_prompt) do
-    DiscourseAi::Completions::Prompt.new(
-      "You write words",
-      messages: [{ type: :user, content: "write 3 words" }],
-    )
-  end
-
-  describe "#perform_completion!" do
-    fab!(:user) { Fabricate(:user) }
-
-    let(:tool) do
-      {
-        name: "get_weather",
-        description: "Get the weather in a city",
-        parameters: [
-          { name: "location", type: "string", description: "the city name", required: true },
-          {
-            name: "unit",
-            type: "string",
-            description: "the unit of measurement celcius c or fahrenheit f",
-            enum: %w[c f],
-            required: true,
-          },
-        ],
-      }
-    end
-
-    let(:invocation) { <<~TEXT }
-      <function_calls>
-      <invoke>
-      <tool_name>get_weather</tool_name>
-      <tool_id>#{tool_id || "get_weather"}</tool_id>
-      <parameters>
-      <location>Sydney</location>
-      <unit>c</unit>
-      </parameters>
-      </invoke>
-      </function_calls>
-    TEXT
-
-    context "when using regular mode" do
-      context "with simple prompts" do
-        let(:response_text) { "1. Serenity\\n2. Laughter\\n3. Adventure" }
-
-        before { stub_response(prompt, response_text) }
-
-        it "can complete a trivial prompt" do
-          completion_response = model.perform_completion!(dialect, user)
-
-          expect(completion_response).to eq(response_text)
-        end
-
-        it "creates an audit log for the request" do
-          model.perform_completion!(dialect, user)
-
-          expect(AiApiAuditLog.count).to eq(1)
-          log = AiApiAuditLog.first
-
-          response_body = response(response_text).to_json
-
-          expect(log.provider_id).to eq(model.provider_id)
-          expect(log.user_id).to eq(user.id)
-          expect(log.raw_request_payload).to eq(request_body)
-          expect(log.raw_response_payload).to eq(response_body)
-          expect(log.request_tokens).to eq(model.prompt_size(prompt))
-          expect(log.response_tokens).to eq(model.tokenizer.size(response_text))
-        end
-      end
-
-      context "with functions" do
-        before do
-          generic_prompt.tools = [tool]
-          stub_response(prompt, tool_call, tool_call: true)
-        end
-
-        it "returns a function invocation" do
-          completion_response = model.perform_completion!(dialect, user)
-
-          expect(completion_response).to eq(invocation)
-        end
-      end
-    end
-
-    context "when using stream mode" do
-      context "with simple prompts" do
-        let(:deltas) { ["Mount", "ain", " ", "Tree ", "Frog"] }
-
-        before { stub_streamed_response(prompt, deltas) }
-
-        it "can complete a trivial prompt" do
-          completion_response = +""
-
-          model.perform_completion!(dialect, user) do |partial, cancel|
-            completion_response << partial
-            cancel.call if completion_response.split(" ").length == 2
-          end
-
-          expect(completion_response).to eq(deltas[0...-1].join)
-        end
-
-        it "creates an audit log and updates is on each read." do
-          completion_response = +""
-
-          model.perform_completion!(dialect, user) do |partial, cancel|
-            completion_response << partial
-            cancel.call if completion_response.split(" ").length == 2
-          end
-
-          expect(AiApiAuditLog.count).to eq(1)
-          log = AiApiAuditLog.first
-
-          expect(log.provider_id).to eq(model.provider_id)
-          expect(log.user_id).to eq(user.id)
-          expect(log.raw_request_payload).to eq(stream_request_body)
-          expect(log.raw_response_payload).to be_present
-          expect(log.request_tokens).to eq(model.prompt_size(prompt))
-          expect(log.response_tokens).to eq(model.tokenizer.size(deltas[0...-1].join))
-        end
-      end
-
-      context "with functions" do
-        before do
-          generic_prompt.tools = [tool]
-          stub_streamed_response(prompt, tool_deltas, tool_call: true)
-        end
-
-        it "waits for the invocation to finish before calling the partial" do
-          buffered_partial = ""
-
-          model.perform_completion!(dialect, user) do |partial, cancel|
-            buffered_partial = partial if partial.include?("</function_calls>")
-          end
-
-          expect(buffered_partial).to eq(invocation)
-        end
-      end
-    end
-  end
-end
diff --git a/spec/lib/completions/endpoints/gemini_spec.rb b/spec/lib/completions/endpoints/gemini_spec.rb
index 037b8c28..841dde50 100644
--- a/spec/lib/completions/endpoints/gemini_spec.rb
+++ b/spec/lib/completions/endpoints/gemini_spec.rb
@@ -1,69 +1,8 @@
 # frozen_string_literal: true
 
-require_relative "endpoint_examples"
-
-RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
-  subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
-
-  let(:model_name) { "gemini-pro" }
-  let(:dialect) { DiscourseAi::Completions::Dialects::Gemini.new(generic_prompt, model_name) }
-  let(:prompt) { dialect.translate }
-
-  let(:tool_id) { "get_weather" }
-
-  let(:tool_payload) do
-    {
-      name: "get_weather",
-      description: "Get the weather in a city",
-      parameters: {
-        type: "object",
-        required: %w[location unit],
-        properties: {
-          "location" => {
-            type: "string",
-            description: "the city name",
-          },
-          "unit" => {
-            type: "string",
-            description: "the unit of measurement celcius c or fahrenheit f",
-            enum: %w[c f],
-          },
-        },
-      },
-    }
-  end
-
-  let(:request_body) do
-    model
-      .default_options
-      .merge(contents: prompt)
-      .tap do |b|
-        b[:tools] = [{ function_declarations: [tool_payload] }] if generic_prompt.tools.present?
-      end
-      .to_json
-  end
-  let(:stream_request_body) do
-    model
-      .default_options
-      .merge(contents: prompt)
-      .tap do |b|
-        b[:tools] = [{ function_declarations: [tool_payload] }] if generic_prompt.tools.present?
-      end
-      .to_json
-  end
-
-  let(:tool_deltas) do
-    [
-      { "functionCall" => { name: "get_weather", args: {} } },
-      { "functionCall" => { name: "get_weather", args: { location: "" } } },
-      { "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } },
-    ]
-  end
-
-  let(:tool_call) do
-    { "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } }
-  end
+require_relative "endpoint_compliance"
 
+class GeminiMock < EndpointMock
   def response(content, tool_call: false)
     {
       candidates: [
@@ -97,9 +36,9 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
     WebMock
       .stub_request(
         :post,
-        "https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:generateContent?key=#{SiteSetting.ai_gemini_api_key}",
+        "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key=#{SiteSetting.ai_gemini_api_key}",
       )
-      .with(body: request_body)
+      .with(body: request_body(prompt, tool_call))
      .to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
   end
 
@@ -139,11 +78,93 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
     WebMock
       .stub_request(
         :post,
-        "https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:streamGenerateContent?key=#{SiteSetting.ai_gemini_api_key}",
+        "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent?key=#{SiteSetting.ai_gemini_api_key}",
       )
-      .with(body: stream_request_body)
+      .with(body: request_body(prompt, tool_call))
       .to_return(status: 200, body: chunks)
   end
 
-  it_behaves_like "an endpoint that can communicate with a completion service"
+  def tool_payload
+    {
+      name: "get_weather",
+      description: "Get the weather in a city",
+      parameters: {
+        type: "object",
+        required: %w[location unit],
+        properties: {
+          "location" => {
+            type: "string",
+            description: "the city name",
+          },
+          "unit" => {
+            type: "string",
+            description: "the unit of measurement celsius c or fahrenheit f",
+            enum: %w[c f],
+          },
+        },
+      },
+    }
+  end
+
+  def request_body(prompt, tool_call)
+    model
+      .default_options
+      .merge(contents: prompt)
+      .tap { |b| b[:tools] = [{ function_declarations: [tool_payload] }] if tool_call }
+      .to_json
+  end
+
+  def tool_deltas
+    [
+      { "functionCall" => { name: "get_weather", args: {} } },
+      { "functionCall" => { name: "get_weather", args: { location: "" } } },
+      { "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } },
+    ]
+  end
+
+  def tool_response
+    { "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } }
+  end
+end
+
+RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
+  subject(:endpoint) { described_class.new("gemini-pro", DiscourseAi::Tokenizer::OpenAiTokenizer) }
+
+  fab!(:user) { Fabricate(:user) }
+
+  let(:gemini_mock) { GeminiMock.new(endpoint) }
+
+  let(:compliance) do
+    EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Gemini, user)
+  end
+
+  describe "#perform_completion!" do
+    context "when using regular mode" do
+      context "with simple prompts" do
+        it "completes a trivial prompt and logs the response" do
+          compliance.regular_mode_simple_prompt(gemini_mock)
+        end
+      end
+
+      context "with tools" do
+        it "returns a function invocation" do
+          compliance.regular_mode_tools(gemini_mock)
+        end
+      end
+    end
+
+    describe "when using streaming mode" do
+      context "with simple prompts" do
+        it "completes a trivial prompt and logs the response" do
+          compliance.streaming_mode_simple_prompt(gemini_mock)
+        end
+      end
+
+      context "with tools" do
+        it "returns a function invocation" do
+          compliance.streaming_mode_tools(gemini_mock)
+        end
+      end
+    end
+  end
 end
diff --git a/spec/lib/completions/endpoints/hugging_face_spec.rb b/spec/lib/completions/endpoints/hugging_face_spec.rb
index 0520b661..b68ecbe4 100644
--- a/spec/lib/completions/endpoints/hugging_face_spec.rb
+++ b/spec/lib/completions/endpoints/hugging_face_spec.rb
@@ -1,42 +1,8 @@
 # frozen_string_literal: true
 
-require_relative "endpoint_examples"
-
-RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
-  subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::Llama2Tokenizer) }
-
-  let(:model_name) { "Llama2-*-chat-hf" }
-  let(:dialect) do
-    DiscourseAi::Completions::Dialects::Llama2Classic.new(generic_prompt, model_name)
-  end
-  let(:prompt) { dialect.translate }
-
-  let(:tool_id) { "get_weather" }
-
-  let(:request_body) do
-    model
-      .default_options
-      .merge(inputs: prompt)
-      .tap do |payload|
-        payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
-          model.prompt_size(prompt)
-      end
-      .to_json
-  end
-  let(:stream_request_body) do
-    model
-      .default_options
-      .merge(inputs: prompt)
-      .tap do |payload|
-        payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
-          model.prompt_size(prompt)
-        payload[:stream] = true
-      end
-      .to_json
-  end
-
-  before { SiteSetting.ai_hugging_face_api_url = "https://test.dev" }
+require_relative "endpoint_compliance"
 
+class HuggingFaceMock < EndpointMock
   def response(content)
     [{ generated_text: content }]
   end
@@ -44,7 +10,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
   def stub_response(prompt, response_text, tool_call: false)
     WebMock
       .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
-      .with(body: request_body)
+      .with(body: request_body(prompt))
       .to_return(status: 200, body: JSON.dump(response(response_text)))
   end
 
@@ -75,33 +41,65 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
 
     WebMock
       .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
-      .with(body: stream_request_body)
+      .with(body: request_body(prompt, stream: true))
       .to_return(status: 200, body: chunks)
   end
 
-  let(:tool_deltas) { ["<function_calls>", <<~REPLY, <<~REPLY] }
-    <invoke>
-    <tool_name>get_weather</tool_name>
-    <parameters>
-    <location>Sydney</location>
-    <unit>c</unit>
-    </parameters>
-    </invoke>
-    </function_calls>
-    REPLY
-    <function_calls>
-    <invoke>
-    <tool_name>get_weather</tool_name>
-    <parameters>
-    <location>Sydney</location>
-    <unit>c</unit>
-    </parameters>
-    </invoke>
-    </function_calls>
-    REPLY
-
-  let(:tool_call) { invocation }
-
-  it_behaves_like "an endpoint that can communicate with a completion service"
+  def request_body(prompt, stream: false)
+    model
+      .default_options
+      .merge(inputs: prompt)
+      .tap do |payload|
+        payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
+          model.prompt_size(prompt)
+        payload[:stream] = true if stream
+      end
+      .to_json
+  end
+end
+
+RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
+  subject(:endpoint) do
+    described_class.new("Llama2-*-chat-hf", DiscourseAi::Tokenizer::Llama2Tokenizer)
+  end
+
+  before { SiteSetting.ai_hugging_face_api_url = "https://test.dev" }
"https://test.dev" } + + fab!(:user) { Fabricate(:user) } + + let(:hf_mock) { HuggingFaceMock.new(endpoint) } + + let(:compliance) do + EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Llama2Classic, user) + end + + describe "#perform_completion!" do + context "when using regular mode" do + context "with simple prompts" do + it "completes a trivial prompt and logs the response" do + compliance.regular_mode_simple_prompt(hf_mock) + end + end + + context "with tools" do + it "returns a function invocation" do + compliance.regular_mode_tools(hf_mock) + end + end + end + + describe "when using streaming mode" do + context "with simple prompts" do + it "completes a trivial prompt and logs the response" do + compliance.streaming_mode_simple_prompt(hf_mock) + end + end + + context "with tools" do + it "returns a function invoncation" do + compliance.streaming_mode_tools(hf_mock) + end + end + end + end end diff --git a/spec/lib/completions/endpoints/open_ai_spec.rb b/spec/lib/completions/endpoints/open_ai_spec.rb index 3723d9b9..5124dab1 100644 --- a/spec/lib/completions/endpoints/open_ai_spec.rb +++ b/spec/lib/completions/endpoints/open_ai_spec.rb @@ -1,53 +1,8 @@ # frozen_string_literal: true -require_relative "endpoint_examples" - -RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do - subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) } - - let(:model_name) { "gpt-3.5-turbo" } - let(:dialect) { DiscourseAi::Completions::Dialects::ChatGpt.new(generic_prompt, model_name) } - let(:prompt) { dialect.translate } - - let(:tool_id) { "eujbuebfe" } - - let(:tool_deltas) do - [ - { id: tool_id, function: {} }, - { id: tool_id, function: { name: "get_weather", arguments: "" } }, - { id: tool_id, function: { name: "get_weather", arguments: "" } }, - { id: tool_id, function: { name: "get_weather", arguments: "{" } }, - { id: tool_id, function: { name: "get_weather", arguments: " \"location\": \"Sydney\"" } }, - { id: tool_id, function: { name: "get_weather", arguments: " ,\"unit\": \"c\" }" } }, - ] - end - - let(:tool_call) do - { - id: tool_id, - function: { - name: "get_weather", - arguments: { location: "Sydney", unit: "c" }.to_json, - }, - } - end - - let(:request_body) do - model - .default_options - .merge(messages: prompt) - .tap { |b| b[:tools] = dialect.tools if generic_prompt.tools.present? } - .to_json - end - - let(:stream_request_body) do - model - .default_options - .merge(messages: prompt, stream: true) - .tap { |b| b[:tools] = dialect.tools if generic_prompt.tools.present? 
-      .to_json
-  end
+require_relative "endpoint_compliance"
 
+class OpenAiMock < EndpointMock
   def response(content, tool_call: false)
     message_content =
       if tool_call
@@ -75,7 +30,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
   def stub_response(prompt, response_text, tool_call: false)
     WebMock
       .stub_request(:post, "https://api.openai.com/v1/chat/completions")
-      .with(body: request_body)
+      .with(body: request_body(prompt, tool_call: tool_call))
       .to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
   end
 
@@ -112,114 +67,144 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
     WebMock
       .stub_request(:post, "https://api.openai.com/v1/chat/completions")
-      .with(body: stream_request_body)
+      .with(body: request_body(prompt, stream: true, tool_call: tool_call))
       .to_return(status: 200, body: chunks)
   end
 
-  it_behaves_like "an endpoint that can communicate with a completion service"
+  def tool_deltas
+    [
+      { id: tool_id, function: {} },
+      { id: tool_id, function: { name: "get_weather", arguments: "" } },
+      { id: tool_id, function: { name: "get_weather", arguments: "" } },
+      { id: tool_id, function: { name: "get_weather", arguments: "{" } },
+      { id: tool_id, function: { name: "get_weather", arguments: " \"location\": \"Sydney\"" } },
+      { id: tool_id, function: { name: "get_weather", arguments: " ,\"unit\": \"c\" }" } },
+    ]
+  end
 
-  context "when chunked encoding returns partial chunks" do
-    # See: https://github.com/bblimke/webmock/issues/629
-    let(:mock_net_http) do
-      Class.new(Net::HTTP) do
-        def request(*)
-          super do |response|
-            response.instance_eval do
-              def read_body(*, &)
-                @body.each(&)
-              end
-            end
+  def tool_response
+    {
+      id: tool_id,
+      function: {
+        name: "get_weather",
+        arguments: { location: "Sydney", unit: "c" }.to_json,
+      },
+    }
+  end
 
-            yield response if block_given?
+  def tool_id
+    "eujbuebfe"
+  end
 
-            response
-          end
+  def tool_payload
+    {
+      type: "function",
+      function: {
+        name: "get_weather",
+        description: "Get the weather in a city",
+        parameters: {
+          type: "object",
+          properties: {
+            location: {
+              type: "string",
+              description: "the city name",
+            },
+            unit: {
+              type: "string",
+              description: "the unit of measurement celsius c or fahrenheit f",
+              enum: %w[c f],
+            },
+          },
+          required: %w[location unit],
+        },
+      },
+    }
+  end
+
+  def request_body(prompt, stream: false, tool_call: false)
+    model
+      .default_options
+      .merge(messages: prompt)
+      .tap do |b|
+        b[:stream] = true if stream
+        b[:tools] = [tool_payload] if tool_call
+      end
+      .to_json
+  end
+end
+
+RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
+  subject(:endpoint) do
+    described_class.new("gpt-3.5-turbo", DiscourseAi::Tokenizer::OpenAiTokenizer)
+  end
+
+  fab!(:user) { Fabricate(:user) }
+
+  let(:open_ai_mock) { OpenAiMock.new(endpoint) }
+
+  let(:compliance) do
+    EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::ChatGpt, user)
+  end
+
+  describe "#perform_completion!" do
+    context "when using regular mode" do
+      context "with simple prompts" do
+        it "completes a trivial prompt and logs the response" do
+          compliance.regular_mode_simple_prompt(open_ai_mock)
+        end
+      end
+
+      context "with tools" do
+        it "returns a function invocation" do
+          compliance.regular_mode_tools(open_ai_mock)
+        end
+      end
+    end
-        end
-      end
-    end
 
-  let(:remove_original_net_http) { Net.send(:remove_const, :HTTP) }
-  let(:original_http) { remove_original_net_http }
-  let(:stub_net_http) { Net.send(:const_set, :HTTP, mock_net_http) }
+    describe "when using streaming mode" do
+      context "with simple prompts" do
+        it "completes a trivial prompt and logs the response" do
+          compliance.streaming_mode_simple_prompt(open_ai_mock)
+        end
 
-  let(:remove_stubbed_net_http) { Net.send(:remove_const, :HTTP) }
-  let(:restore_net_http) { Net.send(:const_set, :HTTP, original_http) }
+        it "will automatically recover from a bad payload" do
+          # this should not happen, but lets ensure nothing bad happens
+          # the row with test1 is invalid json
+          raw_data = <<~TEXT.strip
+            d|a|t|a|:| |{|"choices":[{"delta":{"content":"test,"}}]}
+
+            data: {"choices":[{"delta":{"content":"test1,"}}]
+
+            data: {"choices":[{"delta":|{"content":"test2,"}}]}
+
+            data: {"choices":[{"delta":{"content":"test3,"}}]|}
+
+            data: {"choices":[{|"|d|elta":{"content":"test4"}}]|}
+
+            data: [D|ONE]
+          TEXT
 
-  before do
-    mock_net_http
-    remove_original_net_http
-    stub_net_http
-  end
+          chunks = raw_data.split("|")
 
-  after do
-    remove_stubbed_net_http
-    restore_net_http
-  end
+          open_ai_mock.with_chunk_array_support do
+            stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
+              status: 200,
+              body: chunks,
+            )
 
-  it "will automatically recover from a bad payload" do
-    # this should not happen, but lets ensure nothing bad happens
-    # the row with test1 is invalid json
-    raw_data = <<~TEXT
-d|a|t|a|:| |{|"choices":[{"delta":{"content":"test,"}}]}
+            partials = []
 
-data: {"choices":[{"delta":{"content":"test1,"}}]
+            endpoint.perform_completion!(compliance.dialect, user) do |partial|
+              partials << partial
+            end
 
-data: {"choices":[{"delta":|{"content":"test2,"}}]}
+            expect(partials.join).to eq("test,test2,test3,test4")
+          end
+        end
+      end
 
-data: {"choices":[{"delta":{"content":"test3,"}}]|}
-
-data: {"choices":[{|"|d|elta":{"content":"test4"}}]|}
-
-data: [D|ONE]
-    TEXT
-
-    chunks = raw_data.split("|")
-
-    stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
-      status: 200,
-      body: chunks,
-    )
-
-    partials = []
-    llm = DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo")
-    llm.generate(
-      DiscourseAi::Completions::Prompt.new("test"),
-      user: Discourse.system_user,
-    ) { |partial| partials << partial }
-
-    expect(partials.join).to eq("test,test2,test3,test4")
-  end
-
-  it "supports chunked encoding properly" do
-    raw_data = <<~TEXT
-da|ta: {"choices":[{"delta":{"content":"test,"}}]}
-
-data: {"choices":[{"delta":{"content":"test1,"}}]}
-
-data: {"choices":[{"delta":|{"content":"test2,"}}]}
-
-data: {"choices":[{"delta":{"content":"test3,"}}]|}
-
-data: {"choices":[{|"|d|elta":{"content":"test4"}}]|}
-
-data: [D|ONE]
-    TEXT
-
-    chunks = raw_data.split("|")
-
-    stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
-      status: 200,
-      body: chunks,
-    )
-
-    partials = []
-    llm = DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo")
-    llm.generate(
-      DiscourseAi::Completions::Prompt.new("test"),
-      user: Discourse.system_user,
-    ) { |partial| partials << partial }
-
-    expect(partials.join).to eq("test,test1,test2,test3,test4")
-  end
-
+      context "with tools" do
+        it "returns a function invocation" do
+          compliance.streaming_mode_tools(open_ai_mock)
+        end
+      end
+    end
   end
 end
diff --git a/spec/lib/completions/endpoints/vllm_spec.rb b/spec/lib/completions/endpoints/vllm_spec.rb
index 143e40b9..245e816a 100644
--- a/spec/lib/completions/endpoints/vllm_spec.rb
+++ b/spec/lib/completions/endpoints/vllm_spec.rb
@@ -1,27 +1,14 @@
 # frozen_string_literal: true
 
-require_relative "endpoint_examples"
-
-RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
-  subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::MixtralTokenizer) }
-
-  let(:model_name) { "mistralai/Mixtral-8x7B-Instruct-v0.1" }
-  let(:dialect) { DiscourseAi::Completions::Dialects::Mixtral.new(generic_prompt, model_name) }
-  let(:prompt) { dialect.translate }
-
-  let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
-  let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
-
-  before { SiteSetting.ai_vllm_endpoint = "https://test.dev" }
-
-  let(:tool_id) { "get_weather" }
+require_relative "endpoint_compliance"
 
+class VllmMock < EndpointMock
   def response(content)
     {
       id: "cmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
       object: "text_completion",
       created: 1_678_464_820,
-      model: model_name,
+      model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
       usage: {
         prompt_tokens: 337,
         completion_tokens: 162,
@@ -34,7 +21,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
   def stub_response(prompt, response_text, tool_call: false)
     WebMock
       .stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions")
-      .with(body: request_body)
+      .with(body: model.default_options.merge(prompt: prompt).to_json)
      .to_return(status: 200, body: JSON.dump(response(response_text)))
   end
 
@@ -42,7 +29,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
     +"data: " << {
       id: "cmpl-#{SecureRandom.hex}",
       created: 1_681_283_881,
-      model: model_name,
+      model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
       choices: [{ text: delta, finish_reason: finish_reason, index: 0 }],
       index: 0,
     }.to_json
@@ -62,33 +49,62 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
 
     WebMock
       .stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions")
-      .with(body: stream_request_body)
+      .with(body: model.default_options.merge(prompt: prompt, stream: true).to_json)
       .to_return(status: 200, body: chunks)
   end
-
-  let(:tool_deltas) { ["<function_calls>", <<~REPLY, <<~REPLY] }
-    <invoke>
-    <tool_name>get_weather</tool_name>
-    <parameters>
-    <location>Sydney</location>
-    <unit>c</unit>
-    </parameters>
-    </invoke>
-    </function_calls>
-    REPLY
-    <function_calls>
-    <invoke>
-    <tool_name>get_weather</tool_name>
-    <parameters>
-    <location>Sydney</location>
-    <unit>c</unit>
-    </parameters>
-    </invoke>
-    </function_calls>
-    REPLY
-
-  let(:tool_call) { invocation }
-
-  it_behaves_like "an endpoint that can communicate with a completion service"
+end
+
+RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
+  subject(:endpoint) do
+    described_class.new(
+      "mistralai/Mixtral-8x7B-Instruct-v0.1",
+      DiscourseAi::Tokenizer::MixtralTokenizer,
+    )
+  end
+
+  fab!(:user) { Fabricate(:user) }
+
+  let(:vllm_mock) { VllmMock.new(endpoint) }
+
+  let(:compliance) do
+    EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Mixtral, user)
+  end
+
+  before { SiteSetting.ai_vllm_endpoint = "https://test.dev" }
+
+  describe "#perform_completion!" do
+    context "when using regular mode" do
+      context "with simple prompts" do
+        it "completes a trivial prompt and logs the response" do
+          compliance.regular_mode_simple_prompt(vllm_mock)
+        end
+      end
+
+      context "with tools" do
+        it "returns a function invocation" do
+          compliance.regular_mode_tools(vllm_mock)
+        end
+      end
+    end
+
+    describe "when using streaming mode" do
+      context "with simple prompts" do
+        it "completes a trivial prompt and logs the response" do
+          compliance.streaming_mode_simple_prompt(vllm_mock)
+        end
+      end
+
+      context "with tools" do
+        it "returns a function invocation" do
+          compliance.streaming_mode_tools(vllm_mock)
+        end
+      end
+    end
+  end
 end
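With the shared suite in place, covering a future provider is mostly wiring: subclass EndpointMock with the provider-specific response shape and WebMock stubs, then hand it to EndpointsCompliance. A minimal sketch of that wiring, assuming a hypothetical Foo endpoint that reuses the Claude dialect (FooMock, the foo.example URL, and the "foo-1" model name are illustrative, not part of this change):

    # frozen_string_literal: true

    require_relative "endpoint_compliance"

    class FooMock < EndpointMock
      # Shape of the hypothetical provider's response body.
      def response(content)
        { completion: content }
      end

      # Stub the provider's HTTP API so perform_completion! sees a canned body.
      def stub_response(prompt, response_text, tool_call: false)
        WebMock
          .stub_request(:post, "https://foo.example/v1/complete") # illustrative URL
          .with(body: model.default_options.merge(prompt: prompt).to_json)
          .to_return(status: 200, body: JSON.dump(response(response_text)))
      end
    end

    RSpec.describe DiscourseAi::Completions::Endpoints::Foo do
      subject(:endpoint) { described_class.new("foo-1", DiscourseAi::Tokenizer::OpenAiTokenizer) }

      fab!(:user) { Fabricate(:user) }

      let(:compliance) do
        EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Claude, user)
      end

      it "completes a trivial prompt and logs the response" do
        compliance.regular_mode_simple_prompt(FooMock.new(endpoint))
      end
    end

The compliance helper owns the assertions (return value, AiApiAuditLog rows, token counts), so each per-endpoint file stays focused on its request and response shape.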