DEV: Build sentiment clients outside of promises (#1117)

This commit is contained in:
Roman Rizzi 2025-02-06 13:11:10 -03:00 committed by GitHub
parent e52045ebdc
commit 90bcb8b503
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 40 additions and 39 deletions

View File

@ -12,11 +12,6 @@ module ::DiscourseAi
attr_reader :endpoint, :key, :referer
class << self
def configured?
SiteSetting.ai_hugging_face_tei_endpoint.present? ||
SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
end
def reranker_configured?
SiteSetting.ai_hugging_face_tei_reranker_endpoint.present? ||
SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv.present?
@ -50,32 +45,23 @@ module ::DiscourseAi
JSON.parse(response.body, symbolize_names: true)
end
def classify(content, model_config, base_url = Discourse.base_url)
headers = { "Referer" => base_url, "Content-Type" => "application/json" }
headers["X-API-KEY"] = model_config.api_key
headers["Authorization"] = "Bearer #{model_config.api_key}"
body = { inputs: content, truncate: true }.to_json
api_endpoint = model_config.endpoint
if api_endpoint.present? && api_endpoint.start_with?("srv://")
service = DiscourseAi::Utils::DnsSrv.lookup(api_endpoint.delete_prefix("srv://"))
api_endpoint = "https://#{service.target}:#{service.port}"
end
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
response = conn.post(api_endpoint, body, headers)
if response.status != 200
raise Net::HTTPBadResponse.new("Status: #{response.status}\n\n#{response.body}")
end
def classify_by_sentiment!(content)
response = do_request!(content)
JSON.parse(response.body, symbolize_names: true)
end
end
def perform!(content)
response = do_request!(content)
JSON.parse(response.body, symbolize_names: true).first
end
private
def do_request!(content)
headers = { "Referer" => referer, "Content-Type" => "application/json" }
body = { inputs: content, truncate: true }.to_json
@ -89,7 +75,7 @@ module ::DiscourseAi
raise Net::HTTPBadResponse.new(response.body.to_s) if ![200].include?(response.status)
JSON.parse(response.body, symbolize_names: true).first
response
end
end
end

View File

@ -55,7 +55,6 @@ module DiscourseAi
available_classifiers = classifiers
return if available_classifiers.blank?
base_url = Discourse.base_url
promised_classifications =
relation
@ -70,12 +69,14 @@ module DiscourseAi
already_classified = w_text[:target].sentiment_classifications.map(&:model_used)
classifiers_for_target =
available_classifiers.reject { |ac| already_classified.include?(ac.model_name) }
available_classifiers.reject do |ac|
already_classified.include?(ac[:model_name])
end
promised_target_results =
classifiers_for_target.map do |c|
classifiers_for_target.map do |cft|
Concurrent::Promises.future_on(pool) do
results[c.model_name] = request_with(w_text[:text], c, base_url)
results[cft[:model_name]] = request_with(cft[:client], w_text[:text])
end
end
@ -98,18 +99,19 @@ module DiscourseAi
def classify!(target)
return if target.blank?
return if classifiers.blank?
available_classifiers = classifiers
return if available_classifiers.blank?
to_classify = prepare_text(target)
return if to_classify.blank?
already_classified = target.sentiment_classifications.map(&:model_used)
classifiers_for_target =
classifiers.reject { |ac| already_classified.include?(ac.model_name) }
available_classifiers.reject { |ac| already_classified.include?(ac[:model_name]) }
results =
classifiers_for_target.reduce({}) do |memo, model|
memo[model.model_name] = request_with(to_classify, model)
classifiers_for_target.reduce({}) do |memo, cft|
memo[cft[:model_name]] = request_with(cft[:client], to_classify)
memo
end
@ -117,7 +119,20 @@ module DiscourseAi
end
def classifiers
DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values
DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values.map do |config|
api_endpoint = config.endpoint
if api_endpoint.present? && api_endpoint.start_with?("srv://")
service = DiscourseAi::Utils::DnsSrv.lookup(api_endpoint.delete_prefix("srv://"))
api_endpoint = "https://#{service.target}:#{service.port}"
end
{
model_name: config.model_name,
client:
DiscourseAi::Inference::HuggingFaceTextEmbeddings.new(api_endpoint, config.api_key),
}
end
end
def has_classifiers?
@ -137,9 +152,9 @@ module DiscourseAi
Tokenizer::BertTokenizer.truncate(content, 512)
end
def request_with(content, config, base_url = Discourse.base_url)
result =
DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, config, base_url)
def request_with(client, content)
result = client.classify_by_sentiment!(content)
transform_result(result)
end