FIX: Many fixes for huggingface and llama2 inference (#335)
commit d8267d8da0 (parent 24370a9ca6)
@@ -5,7 +5,7 @@ module DiscourseAi
     module Dialects
       class Llama2Classic
         def self.can_translate?(model_name)
-          "Llama2-*-chat-hf" == model_name
+          %w[Llama2-*-chat-hf Llama2-chat-hf].include?(model_name)
         end

         def translate(generic_prompt)
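Note on the change above: the old guard compared for exact equality against the literal
string "Llama2-*-chat-hf", so the plain "Llama2-chat-hf" model name could never reach this
dialect. A minimal sketch of the difference, plain Ruby with illustrative model names:

  "Llama2-*-chat-hf" == "Llama2-chat-hf"                         #=> false, old check rejected it
  %w[Llama2-*-chat-hf Llama2-chat-hf].include?("Llama2-chat-hf") #=> true, new check accepts it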
@@ -90,6 +90,7 @@ module DiscourseAi

           begin
             partial = extract_completion_from(raw_partial)
+            next if partial.nil?
             leftover = ""
             response_data << partial

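Note on the added guard: a raw partial that yields no completion would otherwise push nil
into response_data and raise on the string append. A minimal sketch of the failure mode,
with a plain array standing in for the streamed partials:

  chunks = ["Hel", nil, "lo"] # hypothetical partials; nil mimics a chunk with no completion
  response_data = +""

  chunks.each do |partial|
    next if partial.nil? # the guard this commit adds
    response_data << partial # without the guard, the nil chunk raises TypeError here
  end

  response_data #=> "Hello"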
@@ -5,7 +5,9 @@ module DiscourseAi
     module Endpoints
       class HuggingFace < Base
         def self.can_contact?(model_name)
-          %w[StableBeluga2 Upstage-Llama-2-*-instruct-v2 Llama2-*-chat-hf].include?(model_name)
+          %w[StableBeluga2 Upstage-Llama-2-*-instruct-v2 Llama2-*-chat-hf Llama2-chat-hf].include?(
+            model_name,
+          )
         end

         def default_options
@@ -19,9 +21,7 @@ module DiscourseAi
         private

         def model_uri
-          URI(SiteSetting.ai_hugging_face_api_url).tap do |uri|
-            uri.path = @streaming_mode ? "/generate_stream" : "/generate"
-          end
+          URI(SiteSetting.ai_hugging_face_api_url)
         end

         def prepare_payload(prompt, model_params)
@@ -30,9 +30,11 @@ module DiscourseAi
           .tap do |payload|
             payload[:parameters].merge!(model_params)

-            token_limit = 2_000 || SiteSetting.ai_hugging_face_token_limit
+            token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000

             payload[:parameters][:max_new_tokens] = token_limit - prompt_size(prompt)

+            payload[:stream] = true if @streaming_mode
           end
         end

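Note on the token_limit fix: the old expression looked configurable but was not, because a
truthy left operand of || short-circuits in Ruby, so SiteSetting.ai_hugging_face_token_limit
was never read. A sketch with plain values standing in for the setting:

  2_000 || 8_000 #=> 2_000, old order: the hardcoded limit always won
  8_000 || 4_000 #=> 8_000, new order: a configured setting wins
  nil || 4_000   #=> 4_000, new order: falls back only when the setting is unset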
@@ -56,7 +58,7 @@ module DiscourseAi

             parsed.dig(:token, :text).to_s
           else
-            parsed[:generated_text].to_s
+            parsed[0][:generated_text].to_s
           end
         end

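Note on the new lookup: the non-streaming response body this branch handles arrives as a
JSON array, so the generated text must be read from the first element. A self-contained
sketch with an illustrative payload:

  require "json"

  raw = '[{"generated_text":"Hello there"}]' # illustrative response body
  parsed = JSON.parse(raw, symbolize_names: true)

  # parsed[:generated_text].to_s  # old lookup: raises TypeError, parsed is an Array
  parsed[0][:generated_text].to_s #=> "Hello there"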
@@ -64,7 +66,7 @@ module DiscourseAi
           decoded_chunk
             .split("\n")
             .map do |line|
-              data = line.split("data: ", 2)[1]
+              data = line.split("data:", 2)[1]
               data&.squish == "[DONE]" ? nil : data
             end
             .compact
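Note on the separator change: some servers emit "data:{...}" with no space after the colon,
which the old "data: " separator never matched, silently dropping those lines. A minimal
sketch on a hand-written chunk:

  chunk = "data:{\"token\":{\"text\":\"Hi\"}}\n\ndata: [DONE]" # first line: no space after "data:"

  chunk
    .split("\n")
    .map { |line| line.split("data:", 2)[1] } # with "data: " the first line would map to nil
    .compact
  #=> ["{\"token\":{\"text\":\"Hi\"}}", " [DONE]"]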
@@ -22,11 +22,6 @@ module ::DiscourseAi
         raise CompletionFailed if model.blank?

         url = URI(SiteSetting.ai_hugging_face_api_url)
-        if block_given?
-          url.path = "/generate_stream"
-        else
-          url.path = "/generate"
-        end
         headers = { "Content-Type" => "application/json" }

         if SiteSetting.ai_hugging_face_api_key.present?
@@ -46,6 +41,8 @@ module ::DiscourseAi
         parameters[:temperature] = temperature if temperature
         parameters[:repetition_penalty] = repetition_penalty if repetition_penalty

+        payload[:stream] = true if block_given?
+
         Net::HTTP.start(
           url.host,
           url.port,
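Note on the added flag: together with the URL-path branching removed above, streaming is now
selected in the request body rather than by posting to a separate /generate_stream path. A
sketch of the resulting request shape, with placeholder values rather than the module's exact
payload:

  require "json"

  payload = {
    inputs: "Why is the sky blue?", # placeholder prompt
    parameters: { temperature: 0.7, repetition_penalty: 1.1 },
  }
  payload[:stream] = true # set only when a block is given, per the change above

  payload.to_json # one endpoint URL; this body flag decides streaming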
@@ -80,7 +77,7 @@ module ::DiscourseAi
         log.update!(
           raw_response_payload: response_body,
           request_tokens: tokenizer.size(prompt),
-          response_tokens: tokenizer.size(parsed_response[:generated_text]),
+          response_tokens: tokenizer.size(parsed_response.first[:generated_text]),
         )
         return parsed_response
       end
@@ -15,20 +15,33 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
     model
       .default_options
       .merge(inputs: prompt)
-      .tap { |payload| payload[:parameters][:max_new_tokens] = 2_000 - model.prompt_size(prompt) }
+      .tap do |payload|
+        payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
+          model.prompt_size(prompt)
+      end
+      .to_json
+  end
+
+  let(:stream_request_body) do
+    model
+      .default_options
+      .merge(inputs: prompt)
+      .tap do |payload|
+        payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
+          model.prompt_size(prompt)
+        payload[:stream] = true
+      end
       .to_json
   end
-  let(:stream_request_body) { request_body }

   before { SiteSetting.ai_hugging_face_api_url = "https://test.dev" }

   def response(content)
-    { generated_text: content }
+    [{ generated_text: content }]
   end

   def stub_response(prompt, response_text)
     WebMock
-      .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}/generate")
+      .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
       .with(body: request_body)
       .to_return(status: 200, body: JSON.dump(response(response_text)))
   end
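Note on the spec changes: with the path split gone, both stubs target the same bare URL and
are told apart by request body, which is why a dedicated stream_request_body is now defined.
A minimal sketch of the pattern, assuming WebMock inside an RSpec example and an illustrative
"https://test.dev" URL:

  stub_request(:post, "https://test.dev")
    .with(body: { stream: true }.to_json) # streaming request
    .to_return(status: 200, body: "data: {}\n\n")

  stub_request(:post, "https://test.dev")
    .with(body: {}.to_json) # non-streaming request
    .to_return(status: 200, body: [{ generated_text: "ok" }].to_json)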
@@ -59,8 +72,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
       chunks = chunks.join("\n\n")

       WebMock
-        .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}/generate_stream")
-        .with(body: request_body)
+        .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
+        .with(body: stream_request_body)
         .to_return(status: 200, body: chunks)
     end
