diff --git a/lib/completions/dialects/llama2_classic.rb b/lib/completions/dialects/llama2_classic.rb
index b6c58c8b..542e5f57 100644
--- a/lib/completions/dialects/llama2_classic.rb
+++ b/lib/completions/dialects/llama2_classic.rb
@@ -5,7 +5,7 @@ module DiscourseAi
     module Dialects
       class Llama2Classic
         def self.can_translate?(model_name)
-          "Llama2-*-chat-hf" == model_name
+          %w[Llama2-*-chat-hf Llama2-chat-hf].include?(model_name)
         end

         def translate(generic_prompt)
diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index 56882de9..43468c15 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -90,6 +90,7 @@ module DiscourseAi

             begin
               partial = extract_completion_from(raw_partial)
+              next if partial.nil?
               leftover = ""
               response_data << partial

diff --git a/lib/completions/endpoints/hugging_face.rb b/lib/completions/endpoints/hugging_face.rb
index bd418380..48c1f607 100644
--- a/lib/completions/endpoints/hugging_face.rb
+++ b/lib/completions/endpoints/hugging_face.rb
@@ -5,7 +5,9 @@ module DiscourseAi
     module Endpoints
       class HuggingFace < Base
         def self.can_contact?(model_name)
-          %w[StableBeluga2 Upstage-Llama-2-*-instruct-v2 Llama2-*-chat-hf].include?(model_name)
+          %w[StableBeluga2 Upstage-Llama-2-*-instruct-v2 Llama2-*-chat-hf Llama2-chat-hf].include?(
+            model_name,
+          )
         end

         def default_options
@@ -19,9 +21,7 @@ module DiscourseAi
         private

         def model_uri
-          URI(SiteSetting.ai_hugging_face_api_url).tap do |uri|
-            uri.path = @streaming_mode ? "/generate_stream" : "/generate"
-          end
+          URI(SiteSetting.ai_hugging_face_api_url)
         end

         def prepare_payload(prompt, model_params)
@@ -30,9 +30,11 @@ module DiscourseAi
             .tap do |payload|
               payload[:parameters].merge!(model_params)

-              token_limit = 2_000 || SiteSetting.ai_hugging_face_token_limit
+              token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000

               payload[:parameters][:max_new_tokens] = token_limit - prompt_size(prompt)
+
+              payload[:stream] = true if @streaming_mode
             end
         end

@@ -56,7 +58,7 @@ module DiscourseAi

             parsed.dig(:token, :text).to_s
           else
-            parsed[:generated_text].to_s
+            parsed[0][:generated_text].to_s
           end
         end

@@ -64,7 +66,7 @@ module DiscourseAi
           decoded_chunk
             .split("\n")
             .map do |line|
-              data = line.split("data: ", 2)[1]
+              data = line.split("data:", 2)[1]
               data&.squish == "[DONE]" ? nil : data
             end
             .compact
diff --git a/lib/inference/hugging_face_text_generation.rb b/lib/inference/hugging_face_text_generation.rb
index 9a8cd22e..1ea35d3f 100644
--- a/lib/inference/hugging_face_text_generation.rb
+++ b/lib/inference/hugging_face_text_generation.rb
@@ -22,11 +22,6 @@ module ::DiscourseAi
         raise CompletionFailed if model.blank?

         url = URI(SiteSetting.ai_hugging_face_api_url)
-        if block_given?
-          url.path = "/generate_stream"
-        else
-          url.path = "/generate"
-        end
         headers = { "Content-Type" => "application/json" }

         if SiteSetting.ai_hugging_face_api_key.present?
@@ -46,6 +41,8 @@ module ::DiscourseAi
         parameters[:temperature] = temperature if temperature
         parameters[:repetition_penalty] = repetition_penalty if repetition_penalty

+        payload[:stream] = true if block_given?
+
         Net::HTTP.start(
           url.host,
           url.port,
@@ -80,7 +77,7 @@ module ::DiscourseAi
             log.update!(
               raw_response_payload: response_body,
               request_tokens: tokenizer.size(prompt),
-              response_tokens: tokenizer.size(parsed_response[:generated_text]),
+              response_tokens: tokenizer.size(parsed_response.first[:generated_text]),
             )
             return parsed_response
           end
diff --git a/spec/lib/completions/endpoints/hugging_face_spec.rb b/spec/lib/completions/endpoints/hugging_face_spec.rb
index cfe76e76..b11413ce 100644
--- a/spec/lib/completions/endpoints/hugging_face_spec.rb
+++ b/spec/lib/completions/endpoints/hugging_face_spec.rb
@@ -15,20 +15,33 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
     model
       .default_options
       .merge(inputs: prompt)
-      .tap { |payload| payload[:parameters][:max_new_tokens] = 2_000 - model.prompt_size(prompt) }
+      .tap do |payload|
+        payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
+          model.prompt_size(prompt)
+      end
+      .to_json
+  end
+
+  let(:stream_request_body) do
+    model
+      .default_options
+      .merge(inputs: prompt)
+      .tap do |payload|
+        payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
+          model.prompt_size(prompt)
+        payload[:stream] = true
+      end
       .to_json
   end
-  let(:stream_request_body) { request_body }

   before { SiteSetting.ai_hugging_face_api_url = "https://test.dev" }

   def response(content)
-    { generated_text: content }
+    [{ generated_text: content }]
   end

   def stub_response(prompt, response_text)
     WebMock
-      .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}/generate")
+      .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
       .with(body: request_body)
       .to_return(status: 200, body: JSON.dump(response(response_text)))
   end
@@ -59,8 +72,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
     chunks = chunks.join("\n\n")

     WebMock
-      .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}/generate_stream")
-      .with(body: request_body)
+      .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
+      .with(body: stream_request_body)
       .to_return(status: 200, body: chunks)
   end
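For reference, a minimal sketch (not part of the patch) of the request/response shapes the updated code assumes, against a hypothetical TGI host. The array-shaped response, the `stream: true` flag, and the `data:`-prefixed chunks mirror exactly what the endpoint and specs above exercise:

```ruby
require "net/http"
require "json"

# Hypothetical TGI server URL, for illustration only.
uri = URI("https://tgi.example.com")
headers = { "Content-Type" => "application/json" }

# Non-streaming: POST to the bare URL (no /generate path). The server answers
# with an array, which is why the code above reads parsed[0][:generated_text]
# and parsed_response.first[:generated_text].
payload = { inputs: "Hello", parameters: { max_new_tokens: 64 } }
response = Net::HTTP.post(uri, payload.to_json, headers)
parsed = JSON.parse(response.body, symbolize_names: true)
puts parsed.first[:generated_text]

# Streaming: same URL, but the payload carries stream: true. Chunks arrive as
# SSE lines shaped like
#   data:{"token":{"text":" Hello"}}
# so partial_from splits each line on "data:" (no trailing space) and
# extract_completion_from digs out token.text, stopping at a [DONE] sentinel.
```

Targeting the bare URL instead of hard-coded `/generate` and `/generate_stream` paths is what lets `model_uri` collapse to a single `URI(...)` call.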