FIX: Many fixes for huggingface and llama2 inference (#335)

This commit is contained in:
Rafael dos Santos Silva 2023-12-06 11:22:42 -03:00 committed by GitHub
parent 24370a9ca6
commit d8267d8da0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 33 additions and 20 deletions

View File

@@ -5,7 +5,7 @@ module DiscourseAi
module Dialects
class Llama2Classic
def self.can_translate?(model_name)
"Llama2-*-chat-hf" == model_name
%w[Llama2-*-chat-hf Llama2-chat-hf].include?(model_name)
end
def translate(generic_prompt)

View File

@@ -90,6 +90,7 @@ module DiscourseAi
begin
partial = extract_completion_from(raw_partial)
next if partial.nil?
leftover = ""
response_data << partial

View File

@@ -5,7 +5,9 @@ module DiscourseAi
module Endpoints
class HuggingFace < Base
def self.can_contact?(model_name)
%w[StableBeluga2 Upstage-Llama-2-*-instruct-v2 Llama2-*-chat-hf].include?(model_name)
%w[StableBeluga2 Upstage-Llama-2-*-instruct-v2 Llama2-*-chat-hf Llama2-chat-hf].include?(
model_name,
)
end
def default_options
@@ -19,9 +21,7 @@ module DiscourseAi
private
def model_uri
URI(SiteSetting.ai_hugging_face_api_url).tap do |uri|
uri.path = @streaming_mode ? "/generate_stream" : "/generate"
end
URI(SiteSetting.ai_hugging_face_api_url)
end
def prepare_payload(prompt, model_params)
@@ -30,9 +30,11 @@ module DiscourseAi
.tap do |payload|
payload[:parameters].merge!(model_params)
token_limit = 2_000 || SiteSetting.ai_hugging_face_token_limit
token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000
payload[:parameters][:max_new_tokens] = token_limit - prompt_size(prompt)
payload[:stream] = true if @streaming_mode
end
end
@@ -56,7 +58,7 @@ module DiscourseAi
parsed.dig(:token, :text).to_s
else
parsed[:generated_text].to_s
parsed[0][:generated_text].to_s
end
end
@@ -64,7 +66,7 @@ module DiscourseAi
decoded_chunk
.split("\n")
.map do |line|
data = line.split("data: ", 2)[1]
data = line.split("data:", 2)[1]
data&.squish == "[DONE]" ? nil : data
end
.compact

View File

@@ -22,11 +22,6 @@ module ::DiscourseAi
raise CompletionFailed if model.blank?
url = URI(SiteSetting.ai_hugging_face_api_url)
if block_given?
url.path = "/generate_stream"
else
url.path = "/generate"
end
headers = { "Content-Type" => "application/json" }
if SiteSetting.ai_hugging_face_api_key.present?
@@ -46,6 +41,8 @@ module ::DiscourseAi
parameters[:temperature] = temperature if temperature
parameters[:repetition_penalty] = repetition_penalty if repetition_penalty
payload[:stream] = true if block_given?
Net::HTTP.start(
url.host,
url.port,
@@ -80,7 +77,7 @@ module ::DiscourseAi
log.update!(
raw_response_payload: response_body,
request_tokens: tokenizer.size(prompt),
response_tokens: tokenizer.size(parsed_response[:generated_text]),
response_tokens: tokenizer.size(parsed_response.first[:generated_text]),
)
return parsed_response
end

View File

@@ -15,20 +15,33 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
model
.default_options
.merge(inputs: prompt)
.tap { |payload| payload[:parameters][:max_new_tokens] = 2_000 - model.prompt_size(prompt) }
.tap do |payload|
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
model.prompt_size(prompt)
end
.to_json
end
let(:stream_request_body) do
model
.default_options
.merge(inputs: prompt)
.tap do |payload|
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
model.prompt_size(prompt)
payload[:stream] = true
end
.to_json
end
let(:stream_request_body) { request_body }
before { SiteSetting.ai_hugging_face_api_url = "https://test.dev" }
def response(content)
{ generated_text: content }
[{ generated_text: content }]
end
def stub_response(prompt, response_text)
WebMock
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}/generate")
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
.with(body: request_body)
.to_return(status: 200, body: JSON.dump(response(response_text)))
end
@@ -59,8 +72,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
chunks = chunks.join("\n\n")
WebMock
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}/generate_stream")
.with(body: request_body)
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
.with(body: stream_request_body)
.to_return(status: 200, body: chunks)
end