FIX: Many fixes for huggingface and llama2 inference (#335)
parent 24370a9ca6
commit d8267d8da0
@@ -5,7 +5,7 @@ module DiscourseAi
   module Dialects
     class Llama2Classic
       def self.can_translate?(model_name)
-        "Llama2-*-chat-hf" == model_name
+        %w[Llama2-*-chat-hf Llama2-chat-hf].include?(model_name)
       end

       def translate(generic_prompt)
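Note on the matching here: despite the "*" in the names, can_translate? does literal string membership checks, not glob expansion, which is why "Llama2-chat-hf" has to be added as its own entry. A minimal illustration (the fnmatch line shows a pattern-based alternative, not what the plugin does):

    # Despite the "*", these are literal membership checks, not glob patterns:
    %w[Llama2-*-chat-hf Llama2-chat-hf].include?("Llama2-chat-hf")     # => true
    %w[Llama2-*-chat-hf Llama2-chat-hf].include?("Llama2-13b-chat-hf") # => false

    # A glob-style alternative would be:
    File.fnmatch("Llama2-*-chat-hf", "Llama2-13b-chat-hf") # => true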
@@ -90,6 +90,7 @@ module DiscourseAi
           begin
             partial = extract_completion_from(raw_partial)
             next if partial.nil?
+            leftover = ""
             response_data << partial
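Resetting leftover after a successful parse pairs with the surrounding begin/rescue: when a streamed chunk carries only part of a JSON payload, extraction fails, the fragment is stashed, and it is prepended to the next chunk. A sketch of that idiom, assuming the rescue clause (which sits outside this hunk) looks roughly like this:

    leftover = ""
    chunks.each do |raw_partial|
      raw_partial = leftover + raw_partial if leftover.present?
      begin
        partial = extract_completion_from(raw_partial) # may raise on a half chunk
        next if partial.nil?
        leftover = "" # parsed cleanly: drop the stashed fragment
        response_data << partial
      rescue JSON::ParserError
        leftover = raw_partial # keep the fragment for the next chunk
      end
    end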
@@ -5,7 +5,9 @@ module DiscourseAi
   module Endpoints
     class HuggingFace < Base
       def self.can_contact?(model_name)
-        %w[StableBeluga2 Upstage-Llama-2-*-instruct-v2 Llama2-*-chat-hf].include?(model_name)
+        %w[StableBeluga2 Upstage-Llama-2-*-instruct-v2 Llama2-*-chat-hf Llama2-chat-hf].include?(
+          model_name,
+        )
       end

       def default_options
@@ -19,9 +21,7 @@ module DiscourseAi
       private

       def model_uri
-        URI(SiteSetting.ai_hugging_face_api_url).tap do |uri|
-          uri.path = @streaming_mode ? "/generate_stream" : "/generate"
-        end
+        URI(SiteSetting.ai_hugging_face_api_url)
       end

       def prepare_payload(prompt, model_params)
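With the /generate vs /generate_stream routing removed, the configured URL is hit as-is in both modes and the request body selects streaming instead (see prepare_payload below). Assuming a text-generation-inference-style server, the two requests now differ only in the payload (values illustrative):

    # non-streaming: POST to SiteSetting.ai_hugging_face_api_url with
    { inputs: prompt, parameters: { max_new_tokens: 1_500 } }

    # streaming: same URL, the stream flag picks the mode
    { inputs: prompt, parameters: { max_new_tokens: 1_500 }, stream: true }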
@@ -30,9 +30,11 @@ module DiscourseAi
           .tap do |payload|
             payload[:parameters].merge!(model_params)

-            token_limit = 2_000 || SiteSetting.ai_hugging_face_token_limit
+            token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000

             payload[:parameters][:max_new_tokens] = token_limit - prompt_size(prompt)
+
+            payload[:stream] = true if @streaming_mode
           end
         end
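The original token_limit line was a genuine bug: in Ruby, || returns its left operand whenever that operand is truthy, and an Integer is always truthy, so the site setting on the right was never consulted. The fix puts the setting first, with 4_000 as the fallback:

    2_000 || SiteSetting.ai_hugging_face_token_limit # => always 2_000; the setting is dead code
    SiteSetting.ai_hugging_face_token_limit || 4_000 # the setting wins; 4_000 only if it is nil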
@@ -56,7 +58,7 @@ module DiscourseAi
           parsed.dig(:token, :text).to_s
         else
-          parsed[:generated_text].to_s
+          parsed[0][:generated_text].to_s
         end
       end
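The else branch handles the non-streaming case, where the inference server returns a JSON array with one object per generated sequence rather than a bare object; hence parsed[0]. The same array shape is why the logging code further down switches to parsed_response.first. Roughly:

    parsed = JSON.parse('[{"generated_text":"hello world"}]', symbolize_names: true)
    parsed[0][:generated_text] # => "hello world"
    parsed[:generated_text]    # TypeError: arrays are indexed by Integer, not Symbol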
@@ -64,7 +66,7 @@ module DiscourseAi
         decoded_chunk
           .split("\n")
           .map do |line|
-            data = line.split("data: ", 2)[1]
+            data = line.split("data:", 2)[1]
             data&.squish == "[DONE]" ? nil : data
           end
           .compact
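Splitting on "data:" instead of "data: " matters because server-sent events allow the field colon with or without a trailing space; against a server that omits the space, the old separator never matched and every partial was dropped. For example:

    line = 'data:{"token":{"text":"hi"}}' # no space after the colon
    line.split("data: ", 2)[1] # => nil, the separator never matches
    line.split("data:", 2)[1]  # => '{"token":{"text":"hi"}}'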
@@ -22,11 +22,6 @@ module ::DiscourseAi
       raise CompletionFailed if model.blank?

       url = URI(SiteSetting.ai_hugging_face_api_url)
-      if block_given?
-        url.path = "/generate_stream"
-      else
-        url.path = "/generate"
-      end
       headers = { "Content-Type" => "application/json" }

       if SiteSetting.ai_hugging_face_api_key.present?
@@ -46,6 +41,8 @@ module ::DiscourseAi
       parameters[:temperature] = temperature if temperature
       parameters[:repetition_penalty] = repetition_penalty if repetition_penalty

+      payload[:stream] = true if block_given?
+
       Net::HTTP.start(
         url.host,
         url.port,
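In this older inference module the caller opts into streaming by passing a block, so block_given? now drives the stream flag in the payload rather than the URL path, mirroring the endpoint class above. A sketch of the idiom (method and argument names abbreviated from the surrounding code):

    def perform!(prompt, model, temperature: nil)
      payload = { inputs: prompt, parameters: {} }
      payload[:parameters][:temperature] = temperature if temperature
      payload[:stream] = true if block_given? # a block means the caller wants partials
      # ... POST the payload, yielding each decoded partial when streaming ...
    end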
@@ -80,7 +77,7 @@ module ::DiscourseAi
       log.update!(
         raw_response_payload: response_body,
         request_tokens: tokenizer.size(prompt),
-        response_tokens: tokenizer.size(parsed_response[:generated_text]),
+        response_tokens: tokenizer.size(parsed_response.first[:generated_text]),
       )
       return parsed_response
     end
@@ -15,20 +15,33 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
     model
       .default_options
       .merge(inputs: prompt)
-      .tap { |payload| payload[:parameters][:max_new_tokens] = 2_000 - model.prompt_size(prompt) }
+      .tap do |payload|
+        payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
+          model.prompt_size(prompt)
+      end
       .to_json
   end

-  let(:stream_request_body) { request_body }
+  let(:stream_request_body) do
+    model
+      .default_options
+      .merge(inputs: prompt)
+      .tap do |payload|
+        payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
+          model.prompt_size(prompt)
+        payload[:stream] = true
+      end
+      .to_json
+  end

   before { SiteSetting.ai_hugging_face_api_url = "https://test.dev" }

   def response(content)
-    { generated_text: content }
+    [{ generated_text: content }]
   end

   def stub_response(prompt, response_text)
     WebMock
-      .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}/generate")
+      .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
       .with(body: request_body)
       .to_return(status: 200, body: JSON.dump(response(response_text)))
   end
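The specs now compute max_new_tokens the same way prepare_payload does, and stream_request_body builds its own payload with stream: true instead of aliasing request_body. That last change is load-bearing: with the path-based routing gone, both stubs point at the same bare URL, so the body is the only thing WebMock can discriminate on. Roughly:

    # Both stubs POST to the same URL; the stream: true key tells them apart.
    WebMock.stub_request(:post, "https://test.dev").with(body: request_body)        # non-streaming
    WebMock.stub_request(:post, "https://test.dev").with(body: stream_request_body) # streaming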
@@ -59,8 +72,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
     chunks = chunks.join("\n\n")

     WebMock
-      .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}/generate_stream")
-      .with(body: request_body)
+      .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
+      .with(body: stream_request_body)
       .to_return(status: 200, body: chunks)
   end
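Here the stubbed response body is the raw SSE stream itself; given the "data:" split fix above, fixtures with or without a space after the colon should both decode. A plausible shape for the chunks being joined (token text illustrative; the "[DONE]" sentinel is what the decoder filters out):

    chunks =
      [
        'data: {"token":{"text":"he"}}',
        'data: {"token":{"text":"llo"}}',
        "data: [DONE]",
      ].join("\n\n")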