discourse-ai/lib/inference/hugging_face_text_embeddings.rb
Roman Rizzi 1f1c94e5c6
FEATURE: AI Bot RAG support. (#537)
This PR lets you associate uploads with an AI persona; we split those uploads into fragments and generate embeddings from them. When building the system prompt for a bot reply, we run a similarity search followed by a re-ranking step (if a reranker is available). This lets us find the most relevant fragments from the body of knowledge associated with the persona, resulting in better, more informed responses.

For now, we'll only allow plain-text files, but this will change in the future.

Commits:

* FEATURE: RAG embeddings for the AI Bot

This first commit introduces a UI where admins can upload text files, which we'll store, split into fragments,
and generate embeddings for. In a later commit, we'll use those embeddings to give the bot additional information during
conversations.

* Basic asymmetric similarity search to provide guidance in system prompt

* Fix tests and lint

* Apply reranker to fragments

* Uploads filter, css adjustments and file validations

* Add placeholder for rag fragments

* Update annotations
2024-04-01 13:43:34 -03:00

72 lines
2.6 KiB
Ruby

# frozen_string_literal: true
module ::DiscourseAi
  module Inference
    # Client for a Hugging Face Text Embeddings Inference (TEI) server.
    #
    # Two independently configured endpoints are used:
    # * the embeddings endpoint (`perform!`), and
    # * the reranker endpoint (`rerank`, which POSTs to `<endpoint>/rerank`).
    #
    # Each endpoint can be given either as a plain URL site setting or as a DNS SRV
    # record name; when the SRV setting is present it takes precedence and the URL
    # is built as `https://<target>:<port>` from the SRV lookup result.
    class HuggingFaceTextEmbeddings
      class << self
        # Requests embeddings for +content+ from the configured TEI endpoint.
        #
        # @param content [Object] value sent as the request's `inputs` field
        #   (presumably a String or an Array of Strings — confirm against callers)
        # @return [Object] the parsed JSON response body, with symbolized keys
        # @raise [Net::HTTPBadResponse] when the server responds with a non-200 status;
        #   the message contains the status code and the response body
        def perform!(content)
          api_endpoint =
            resolve_endpoint(
              srv: SiteSetting.ai_hugging_face_tei_endpoint_srv,
              url: SiteSetting.ai_hugging_face_tei_endpoint,
            )

          post_json(
            api_endpoint,
            { inputs: content, truncate: true },
            api_key: SiteSetting.ai_hugging_face_tei_api_key,
          )
        end

        # Asks the configured TEI reranker to score +candidates+ against +content+.
        #
        # @param content [Object] value sent as the request's `query` field
        # @param candidates [Object] value sent as the request's `texts` field
        #   (presumably an Array of Strings — confirm against callers)
        # @return [Object] the parsed JSON response body, with symbolized keys
        # @raise [Net::HTTPBadResponse] when the server responds with a non-200 status;
        #   the message contains the status code and the response body
        def rerank(content, candidates)
          api_endpoint =
            resolve_endpoint(
              srv: SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv,
              url: SiteSetting.ai_hugging_face_tei_reranker_endpoint,
            )

          post_json(
            "#{api_endpoint}/rerank",
            { query: content, texts: candidates, truncate: true },
            api_key: SiteSetting.ai_hugging_face_tei_reranker_api_key,
          )
        end

        # @return [Boolean] true when either the reranker URL or its SRV record is set
        def reranker_configured?
          SiteSetting.ai_hugging_face_tei_reranker_endpoint.present? ||
            SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv.present?
        end

        # @return [Boolean] true when either the embeddings URL or its SRV record is set
        def configured?
          SiteSetting.ai_hugging_face_tei_endpoint.present? ||
            SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
        end

        private

        # Returns the base URL for a service.
        # When +srv+ is present, it is resolved via DNS SRV and wins over +url+.
        def resolve_endpoint(srv:, url:)
          return url unless srv.present?

          service = DiscourseAi::Utils::DnsSrv.lookup(srv)
          "https://#{service.target}:#{service.port}"
        end

        # Builds the JSON request headers. The API key header is only sent when a key is configured.
        def request_headers(api_key)
          headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
          headers["X-API-KEY"] = api_key if api_key.present?
          headers
        end

        # POSTs +payload+ as JSON to +url+ through FinalDestination's Faraday adapter
        # and returns the parsed response body with symbolized keys.
        def post_json(url, payload, api_key:)
          conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
          response = conn.post(url, payload.to_json, request_headers(api_key))

          # Include status and body so failures from either endpoint are diagnosable.
          if response.status != 200
            raise Net::HTTPBadResponse, "Status: #{response.status}\n\n#{response.body}"
          end

          JSON.parse(response.body, symbolize_names: true)
        end
      end
    end
  end
end