mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-03-06 17:30:20 +00:00
FEATURE: configurable embeddings (#1049)
* Use AR model for embeddings features * endpoints * Embeddings CRUD UI * Add presets. Hide a couple more settings * system specs * Seed embedding definition from old settings * Generate search bit index on the fly. cleanup orphaned data * support for seeded models * Fix run test for new embedding * fix selected model not set correctly
This commit is contained in:
parent
fad4b65d4f
commit
f5cf1019fb
@ -0,0 +1,21 @@
|
|||||||
|
import DiscourseRoute from "discourse/routes/discourse";
|
||||||
|
|
||||||
|
export default class AdminPluginsShowDiscourseAiEmbeddingsEdit extends DiscourseRoute {
|
||||||
|
async model(params) {
|
||||||
|
const allEmbeddings = this.modelFor(
|
||||||
|
"adminPlugins.show.discourse-ai-embeddings"
|
||||||
|
);
|
||||||
|
const id = parseInt(params.id, 10);
|
||||||
|
const record = allEmbeddings.findBy("id", id);
|
||||||
|
record.provider_params = record.provider_params || {};
|
||||||
|
return record;
|
||||||
|
}
|
||||||
|
|
||||||
|
setupController(controller, model) {
|
||||||
|
super.setupController(controller, model);
|
||||||
|
controller.set(
|
||||||
|
"allEmbeddings",
|
||||||
|
this.modelFor("adminPlugins.show.discourse-ai-embeddings")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,17 @@
|
|||||||
|
import DiscourseRoute from "discourse/routes/discourse";
|
||||||
|
|
||||||
|
export default class AdminPluginsShowDiscourseAiEmbeddingsNew extends DiscourseRoute {
|
||||||
|
async model() {
|
||||||
|
const record = this.store.createRecord("ai-embedding");
|
||||||
|
record.provider_params = {};
|
||||||
|
return record;
|
||||||
|
}
|
||||||
|
|
||||||
|
setupController(controller, model) {
|
||||||
|
super.setupController(controller, model);
|
||||||
|
controller.set(
|
||||||
|
"allEmbeddings",
|
||||||
|
this.modelFor("adminPlugins.show.discourse-ai-embeddings")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,7 @@
|
|||||||
|
import DiscourseRoute from "discourse/routes/discourse";
|
||||||
|
|
||||||
|
export default class DiscourseAiAiEmbeddingsRoute extends DiscourseRoute {
|
||||||
|
model() {
|
||||||
|
return this.store.findAll("ai-embedding");
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
<AiEmbeddingsListEditor
|
||||||
|
@embeddings={{this.allEmbeddings}}
|
||||||
|
@currentEmbedding={{this.model}}
|
||||||
|
/>
|
@ -0,0 +1 @@
|
|||||||
|
<AiEmbeddingsListEditor @embeddings={{this.model}} />
|
@ -0,0 +1,4 @@
|
|||||||
|
<AiEmbeddingsListEditor
|
||||||
|
@embeddings={{this.allEmbeddings}}
|
||||||
|
@currentEmbedding={{this.model}}
|
||||||
|
/>
|
130
app/controllers/discourse_ai/admin/ai_embeddings_controller.rb
Normal file
130
app/controllers/discourse_ai/admin/ai_embeddings_controller.rb
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Admin
|
||||||
|
class AiEmbeddingsController < ::Admin::AdminController
|
||||||
|
requires_plugin ::DiscourseAi::PLUGIN_NAME
|
||||||
|
|
||||||
|
def index
|
||||||
|
embedding_defs = EmbeddingDefinition.all.order(:display_name)
|
||||||
|
|
||||||
|
render json: {
|
||||||
|
ai_embeddings:
|
||||||
|
ActiveModel::ArraySerializer.new(
|
||||||
|
embedding_defs,
|
||||||
|
each_serializer: AiEmbeddingDefinitionSerializer,
|
||||||
|
root: false,
|
||||||
|
).as_json,
|
||||||
|
meta: {
|
||||||
|
provider_params: EmbeddingDefinition.provider_params,
|
||||||
|
providers: EmbeddingDefinition.provider_names,
|
||||||
|
distance_functions: EmbeddingDefinition.distance_functions,
|
||||||
|
tokenizers:
|
||||||
|
EmbeddingDefinition.tokenizer_names.map { |tn|
|
||||||
|
{ id: tn, name: tn.split("::").last }
|
||||||
|
},
|
||||||
|
presets: EmbeddingDefinition.presets,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
|
def new
|
||||||
|
end
|
||||||
|
|
||||||
|
def edit
|
||||||
|
embedding_def = EmbeddingDefinition.find(params[:id])
|
||||||
|
render json: AiEmbeddingDefinitionSerializer.new(embedding_def)
|
||||||
|
end
|
||||||
|
|
||||||
|
def create
|
||||||
|
embedding_def = EmbeddingDefinition.new(ai_embeddings_params)
|
||||||
|
|
||||||
|
if embedding_def.save
|
||||||
|
render json: AiEmbeddingDefinitionSerializer.new(embedding_def), status: :created
|
||||||
|
else
|
||||||
|
render_json_error embedding_def
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def update
|
||||||
|
embedding_def = EmbeddingDefinition.find(params[:id])
|
||||||
|
|
||||||
|
if embedding_def.seeded?
|
||||||
|
return(
|
||||||
|
render_json_error(I18n.t("discourse_ai.embeddings.cannot_edit_builtin"), status: 403)
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
|
if embedding_def.update(ai_embeddings_params.except(:dimensions))
|
||||||
|
render json: AiEmbeddingDefinitionSerializer.new(embedding_def)
|
||||||
|
else
|
||||||
|
render_json_error embedding_def
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def destroy
|
||||||
|
embedding_def = EmbeddingDefinition.find(params[:id])
|
||||||
|
|
||||||
|
if embedding_def.seeded?
|
||||||
|
return(
|
||||||
|
render_json_error(I18n.t("discourse_ai.embeddings.cannot_edit_builtin"), status: 403)
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
|
if embedding_def.id == SiteSetting.ai_embeddings_selected_model.to_i
|
||||||
|
return render_json_error(I18n.t("discourse_ai.embeddings.delete_failed"), status: 409)
|
||||||
|
end
|
||||||
|
|
||||||
|
if embedding_def.destroy
|
||||||
|
head :no_content
|
||||||
|
else
|
||||||
|
render_json_error embedding_def
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def test
|
||||||
|
RateLimiter.new(
|
||||||
|
current_user,
|
||||||
|
"ai_embeddings_test_#{current_user.id}",
|
||||||
|
3,
|
||||||
|
1.minute,
|
||||||
|
).performed!
|
||||||
|
|
||||||
|
embedding_def = EmbeddingDefinition.new(ai_embeddings_params)
|
||||||
|
DiscourseAi::Embeddings::Vector.new(embedding_def).vector_from("this is a test")
|
||||||
|
|
||||||
|
render json: { success: true }
|
||||||
|
rescue Net::HTTPBadResponse => e
|
||||||
|
render json: { success: false, error: e.message }
|
||||||
|
end
|
||||||
|
|
||||||
|
private
|
||||||
|
|
||||||
|
def ai_embeddings_params
|
||||||
|
permitted =
|
||||||
|
params.require(:ai_embedding).permit(
|
||||||
|
:display_name,
|
||||||
|
:dimensions,
|
||||||
|
:max_sequence_length,
|
||||||
|
:pg_function,
|
||||||
|
:provider,
|
||||||
|
:url,
|
||||||
|
:api_key,
|
||||||
|
:tokenizer_class,
|
||||||
|
)
|
||||||
|
|
||||||
|
extra_field_names = EmbeddingDefinition.provider_params.dig(permitted[:provider]&.to_sym)
|
||||||
|
if extra_field_names.present?
|
||||||
|
received_prov_params =
|
||||||
|
params.dig(:ai_embedding, :provider_params)&.slice(*extra_field_names.keys)
|
||||||
|
|
||||||
|
if received_prov_params.present?
|
||||||
|
permitted[:provider_params] = received_prov_params.permit!
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
permitted
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
@ -18,7 +18,7 @@ module ::Jobs
|
|||||||
target = target_type.constantize.find_by(id: target_id)
|
target = target_type.constantize.find_by(id: target_id)
|
||||||
return if !target
|
return if !target
|
||||||
|
|
||||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
vector_rep = DiscourseAi::Embeddings::Vector.instance
|
||||||
|
|
||||||
tokenizer = vector_rep.tokenizer
|
tokenizer = vector_rep.tokenizer
|
||||||
chunk_tokens = target.rag_chunk_tokens
|
chunk_tokens = target.rag_chunk_tokens
|
||||||
|
13
app/jobs/regular/manage_embedding_def_search_index.rb
Normal file
13
app/jobs/regular/manage_embedding_def_search_index.rb
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module ::Jobs
|
||||||
|
class ManageEmbeddingDefSearchIndex < ::Jobs::Base
|
||||||
|
def execute(args)
|
||||||
|
embedding_def = EmbeddingDefinition.find_by(id: args[:id])
|
||||||
|
return if embedding_def.nil?
|
||||||
|
return if DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_def)
|
||||||
|
|
||||||
|
DiscourseAi::Embeddings::Schema.prepare_search_indexes(embedding_def)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
11
app/jobs/scheduled/remove_orphaned_embeddings.rb
Normal file
11
app/jobs/scheduled/remove_orphaned_embeddings.rb
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module Jobs
|
||||||
|
class RemoveOrphanedEmbeddings < ::Jobs::Scheduled
|
||||||
|
every 1.week
|
||||||
|
|
||||||
|
def execute(_args)
|
||||||
|
DiscourseAi::Embeddings::Schema.remove_orphaned_data
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
231
app/models/embedding_definition.rb
Normal file
231
app/models/embedding_definition.rb
Normal file
@ -0,0 +1,231 @@
|
|||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
class EmbeddingDefinition < ActiveRecord::Base
|
||||||
|
CLOUDFLARE = "cloudflare"
|
||||||
|
HUGGING_FACE = "hugging_face"
|
||||||
|
OPEN_AI = "open_ai"
|
||||||
|
GOOGLE = "google"
|
||||||
|
|
||||||
|
class << self
|
||||||
|
def provider_names
|
||||||
|
[CLOUDFLARE, HUGGING_FACE, OPEN_AI, GOOGLE]
|
||||||
|
end
|
||||||
|
|
||||||
|
def distance_functions
|
||||||
|
%w[<#> <=>]
|
||||||
|
end
|
||||||
|
|
||||||
|
# Fully-qualified class names of the tokenizers an embedding definition
# may use. Powers the tokenizer dropdown in the admin UI and the
# `tokenizer_class` inclusion validation.
#
# Fix: DiscourseAi::Tokenizer::OpenAiTokenizer was listed twice, which
# produced a duplicate entry in the UI dropdown.
#
# @return [Array<String>] tokenizer class names
def tokenizer_names
  [
    DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer,
    DiscourseAi::Tokenizer::BgeLargeEnTokenizer,
    DiscourseAi::Tokenizer::BgeM3Tokenizer,
    DiscourseAi::Tokenizer::OpenAiTokenizer,
    DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer,
  ].map(&:name)
end
|
||||||
|
|
||||||
|
def provider_params
|
||||||
|
{ open_ai: { model_name: :text } }
|
||||||
|
end
|
||||||
|
|
||||||
|
def presets
|
||||||
|
@presets ||=
|
||||||
|
begin
|
||||||
|
[
|
||||||
|
{
|
||||||
|
preset_id: "bge-large-en",
|
||||||
|
display_name: "bge-large-en",
|
||||||
|
dimensions: 1024,
|
||||||
|
max_sequence_length: 512,
|
||||||
|
pg_function: "<#>",
|
||||||
|
tokenizer_class: "DiscourseAi::Tokenizer::BgeLargeEnTokenizer",
|
||||||
|
provider: HUGGING_FACE,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
preset_id: "bge-m3",
|
||||||
|
display_name: "bge-m3",
|
||||||
|
dimensions: 1024,
|
||||||
|
max_sequence_length: 8192,
|
||||||
|
pg_function: "<#>",
|
||||||
|
tokenizer_class: "DiscourseAi::Tokenizer::BgeM3Tokenizer",
|
||||||
|
provider: HUGGING_FACE,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
preset_id: "gemini-embedding-001",
|
||||||
|
display_name: "Gemini's embedding-001",
|
||||||
|
dimensions: 768,
|
||||||
|
max_sequence_length: 1536,
|
||||||
|
pg_function: "<=>",
|
||||||
|
url:
|
||||||
|
"https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent",
|
||||||
|
tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
|
||||||
|
provider: GOOGLE,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
preset_id: "multilingual-e5-large",
|
||||||
|
display_name: "multilingual-e5-large",
|
||||||
|
dimensions: 1024,
|
||||||
|
max_sequence_length: 512,
|
||||||
|
pg_function: "<=>",
|
||||||
|
tokenizer_class: "DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer",
|
||||||
|
provider: HUGGING_FACE,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
preset_id: "text-embedding-3-large",
|
||||||
|
display_name: "OpenAI's text-embedding-3-large",
|
||||||
|
dimensions: 2000,
|
||||||
|
max_sequence_length: 8191,
|
||||||
|
pg_function: "<=>",
|
||||||
|
tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
|
||||||
|
url: "https://api.openai.com/v1/embeddings",
|
||||||
|
provider: OPEN_AI,
|
||||||
|
provider_params: {
|
||||||
|
model_name: "text-embedding-3-large",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
preset_id: "text-embedding-3-small",
|
||||||
|
display_name: "OpenAI's text-embedding-3-small",
|
||||||
|
dimensions: 1536,
|
||||||
|
max_sequence_length: 8191,
|
||||||
|
pg_function: "<=>",
|
||||||
|
tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
|
||||||
|
url: "https://api.openai.com/v1/embeddings",
|
||||||
|
provider: OPEN_AI,
|
||||||
|
provider_params: {
|
||||||
|
model_name: "text-embedding-3-small",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
preset_id: "text-embedding-ada-002",
|
||||||
|
display_name: "OpenAI's text-embedding-ada-002",
|
||||||
|
dimensions: 1536,
|
||||||
|
max_sequence_length: 8191,
|
||||||
|
pg_function: "<=>",
|
||||||
|
tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
|
||||||
|
url: "https://api.openai.com/v1/embeddings",
|
||||||
|
provider: OPEN_AI,
|
||||||
|
provider_params: {
|
||||||
|
model_name: "text-embedding-ada-002",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
validates :provider, presence: true, inclusion: provider_names
|
||||||
|
validates :display_name, presence: true, length: { maximum: 100 }
|
||||||
|
validates :tokenizer_class, presence: true, inclusion: tokenizer_names
|
||||||
|
validates_presence_of :url, :api_key, :dimensions, :max_sequence_length, :pg_function
|
||||||
|
|
||||||
|
after_create :create_indexes
|
||||||
|
|
||||||
|
def create_indexes
|
||||||
|
Jobs.enqueue(:manage_embedding_def_search_index, id: self.id)
|
||||||
|
end
|
||||||
|
|
||||||
|
def tokenizer
|
||||||
|
tokenizer_class.constantize
|
||||||
|
end
|
||||||
|
|
||||||
|
# Builds the provider-specific client used to request embeddings for this
# definition. Dispatches on the `provider` column to one of the private
# client factory methods below.
#
# Fix: corrected the "Uknown" typo in the raised error message.
#
# @return [Object] an inference client for the configured provider
# @raise [RuntimeError] when `provider` is not one of the known providers
def inference_client
  case provider
  when CLOUDFLARE
    cloudflare_client
  when HUGGING_FACE
    hugging_face_client
  when OPEN_AI
    open_ai_client
  when GOOGLE
    gemini_client
  else
    raise "Unknown embeddings provider"
  end
end
|
||||||
|
|
||||||
|
def lookup_custom_param(key)
|
||||||
|
provider_params&.dig(key)
|
||||||
|
end
|
||||||
|
|
||||||
|
def endpoint_url
|
||||||
|
return url if !url.starts_with?("srv://")
|
||||||
|
|
||||||
|
service = DiscourseAi::Utils::DnsSrv.lookup(url.sub("srv://", ""))
|
||||||
|
"https://#{service.target}:#{service.port}"
|
||||||
|
end
|
||||||
|
|
||||||
|
def prepare_query_text(text, asymetric: false)
|
||||||
|
strategy.prepare_query_text(text, self, asymetric: asymetric)
|
||||||
|
end
|
||||||
|
|
||||||
|
def prepare_target_text(target)
|
||||||
|
strategy.prepare_target_text(target, self)
|
||||||
|
end
|
||||||
|
|
||||||
|
def strategy_id
|
||||||
|
strategy.id
|
||||||
|
end
|
||||||
|
|
||||||
|
def strategy_version
|
||||||
|
strategy.version
|
||||||
|
end
|
||||||
|
|
||||||
|
def api_key
|
||||||
|
if seeded?
|
||||||
|
env_key = "DISCOURSE_AI_SEEDED_EMBEDDING_API_KEY"
|
||||||
|
ENV[env_key] || self[:api_key]
|
||||||
|
else
|
||||||
|
self[:api_key]
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
private
|
||||||
|
|
||||||
|
def strategy
|
||||||
|
@strategy ||= DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||||
|
end
|
||||||
|
|
||||||
|
def cloudflare_client
|
||||||
|
DiscourseAi::Inference::CloudflareWorkersAi.new(endpoint_url, api_key)
|
||||||
|
end
|
||||||
|
|
||||||
|
def hugging_face_client
|
||||||
|
DiscourseAi::Inference::HuggingFaceTextEmbeddings.new(endpoint_url, api_key)
|
||||||
|
end
|
||||||
|
|
||||||
|
def open_ai_client
|
||||||
|
DiscourseAi::Inference::OpenAiEmbeddings.new(
|
||||||
|
endpoint_url,
|
||||||
|
api_key,
|
||||||
|
lookup_custom_param("model_name"),
|
||||||
|
dimensions,
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
|
def gemini_client
|
||||||
|
DiscourseAi::Inference::GeminiEmbeddings.new(endpoint_url, api_key)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
# == Schema Information
|
||||||
|
#
|
||||||
|
# Table name: embedding_definitions
|
||||||
|
#
|
||||||
|
# id :bigint not null, primary key
|
||||||
|
# display_name :string not null
|
||||||
|
# dimensions :integer not null
|
||||||
|
# max_sequence_length :integer not null
|
||||||
|
# version :integer default(1), not null
|
||||||
|
# pg_function :string not null
|
||||||
|
# provider :string not null
|
||||||
|
# tokenizer_class :string not null
|
||||||
|
# url :string not null
|
||||||
|
# api_key :string
|
||||||
|
# seeded :boolean default(FALSE), not null
|
||||||
|
# provider_params :jsonb
|
||||||
|
# created_at :datetime not null
|
||||||
|
# updated_at :datetime not null
|
||||||
|
#
|
29
app/serializers/ai_embedding_definition_serializer.rb
Normal file
29
app/serializers/ai_embedding_definition_serializer.rb
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
class AiEmbeddingDefinitionSerializer < ApplicationSerializer
|
||||||
|
root "ai_embedding"
|
||||||
|
|
||||||
|
attributes :id,
|
||||||
|
:display_name,
|
||||||
|
:dimensions,
|
||||||
|
:max_sequence_length,
|
||||||
|
:pg_function,
|
||||||
|
:provider,
|
||||||
|
:url,
|
||||||
|
:api_key,
|
||||||
|
:seeded,
|
||||||
|
:tokenizer_class,
|
||||||
|
:provider_params
|
||||||
|
|
||||||
|
def api_key
|
||||||
|
object.seeded? ? "********" : object.api_key
|
||||||
|
end
|
||||||
|
|
||||||
|
def url
|
||||||
|
object.seeded? ? "********" : object.url
|
||||||
|
end
|
||||||
|
|
||||||
|
def provider
|
||||||
|
object.seeded? ? "CDCK" : object.provider
|
||||||
|
end
|
||||||
|
end
|
@ -20,5 +20,14 @@ export default {
|
|||||||
});
|
});
|
||||||
this.route("discourse-ai-spam", { path: "ai-spam" });
|
this.route("discourse-ai-spam", { path: "ai-spam" });
|
||||||
this.route("discourse-ai-usage", { path: "ai-usage" });
|
this.route("discourse-ai-usage", { path: "ai-usage" });
|
||||||
|
|
||||||
|
this.route(
|
||||||
|
"discourse-ai-embeddings",
|
||||||
|
{ path: "ai-embeddings" },
|
||||||
|
function () {
|
||||||
|
this.route("new");
|
||||||
|
this.route("edit", { path: "/:id/edit" });
|
||||||
|
}
|
||||||
|
);
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
21
assets/javascripts/discourse/admin/adapters/ai-embedding.js
Normal file
21
assets/javascripts/discourse/admin/adapters/ai-embedding.js
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
import RestAdapter from "discourse/adapters/rest";
|
||||||
|
|
||||||
|
export default class Adapter extends RestAdapter {
|
||||||
|
jsonMode = true;
|
||||||
|
|
||||||
|
basePath() {
|
||||||
|
return "/admin/plugins/discourse-ai/";
|
||||||
|
}
|
||||||
|
|
||||||
|
pathFor(store, type, findArgs) {
|
||||||
|
// removes underscores which are implemented in base
|
||||||
|
let path =
|
||||||
|
this.basePath(store, type, findArgs) +
|
||||||
|
store.pluralize(this.apiNameFor(type));
|
||||||
|
return this.appendQueryParams(path, findArgs);
|
||||||
|
}
|
||||||
|
|
||||||
|
apiNameFor() {
|
||||||
|
return "ai-embedding";
|
||||||
|
}
|
||||||
|
}
|
38
assets/javascripts/discourse/admin/models/ai-embedding.js
Normal file
38
assets/javascripts/discourse/admin/models/ai-embedding.js
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
import { ajax } from "discourse/lib/ajax";
|
||||||
|
import RestModel from "discourse/models/rest";
|
||||||
|
|
||||||
|
export default class AiEmbedding extends RestModel {
|
||||||
|
createProperties() {
|
||||||
|
return this.getProperties(
|
||||||
|
"id",
|
||||||
|
"display_name",
|
||||||
|
"dimensions",
|
||||||
|
"provider",
|
||||||
|
"tokenizer_class",
|
||||||
|
"dimensions",
|
||||||
|
"url",
|
||||||
|
"api_key",
|
||||||
|
"max_sequence_length",
|
||||||
|
"provider_params",
|
||||||
|
"pg_function"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
updateProperties() {
|
||||||
|
const attrs = this.createProperties();
|
||||||
|
attrs.id = this.id;
|
||||||
|
|
||||||
|
return attrs;
|
||||||
|
}
|
||||||
|
|
||||||
|
async testConfig() {
|
||||||
|
return await ajax(`/admin/plugins/discourse-ai/ai-embeddings/test.json`, {
|
||||||
|
data: { ai_embedding: this.createProperties() },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
workingCopy() {
|
||||||
|
const attrs = this.createProperties();
|
||||||
|
return this.store.createRecord("ai-embedding", attrs);
|
||||||
|
}
|
||||||
|
}
|
376
assets/javascripts/discourse/components/ai-embedding-editor.gjs
Normal file
376
assets/javascripts/discourse/components/ai-embedding-editor.gjs
Normal file
@ -0,0 +1,376 @@
|
|||||||
|
import Component from "@glimmer/component";
|
||||||
|
import { tracked } from "@glimmer/tracking";
|
||||||
|
import { Input } from "@ember/component";
|
||||||
|
import { concat, get } from "@ember/helper";
|
||||||
|
import { on } from "@ember/modifier";
|
||||||
|
import { action, computed } from "@ember/object";
|
||||||
|
import didInsert from "@ember/render-modifiers/modifiers/did-insert";
|
||||||
|
import didUpdate from "@ember/render-modifiers/modifiers/did-update";
|
||||||
|
import { later } from "@ember/runloop";
|
||||||
|
import { service } from "@ember/service";
|
||||||
|
import BackButton from "discourse/components/back-button";
|
||||||
|
import DButton from "discourse/components/d-button";
|
||||||
|
import icon from "discourse/helpers/d-icon";
|
||||||
|
import { popupAjaxError } from "discourse/lib/ajax-error";
|
||||||
|
import { i18n } from "discourse-i18n";
|
||||||
|
import ComboBox from "select-kit/components/combo-box";
|
||||||
|
import DTooltip from "float-kit/components/d-tooltip";
|
||||||
|
import not from "truth-helpers/helpers/not";
|
||||||
|
|
||||||
|
export default class AiEmbeddingEditor extends Component {
|
||||||
|
@service toasts;
|
||||||
|
@service router;
|
||||||
|
@service dialog;
|
||||||
|
@service store;
|
||||||
|
|
||||||
|
@tracked isSaving = false;
|
||||||
|
@tracked selectedPreset = null;
|
||||||
|
|
||||||
|
@tracked testRunning = false;
|
||||||
|
@tracked testResult = null;
|
||||||
|
@tracked testError = null;
|
||||||
|
@tracked apiKeySecret = true;
|
||||||
|
@tracked editingModel = null;
|
||||||
|
|
||||||
|
get selectedProviders() {
|
||||||
|
const t = (provName) => {
|
||||||
|
return i18n(`discourse_ai.embeddings.providers.${provName}`);
|
||||||
|
};
|
||||||
|
|
||||||
|
return this.args.embeddings.resultSetMeta.providers.map((prov) => {
|
||||||
|
return { id: prov, name: t(prov) };
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
get distanceFunctions() {
|
||||||
|
const t = (df) => {
|
||||||
|
return i18n(`discourse_ai.embeddings.distance_functions.${df}`);
|
||||||
|
};
|
||||||
|
|
||||||
|
return this.args.embeddings.resultSetMeta.distance_functions.map((df) => {
|
||||||
|
return { id: df, name: t(df) };
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
get presets() {
|
||||||
|
const presets = this.args.embeddings.resultSetMeta.presets.map((preset) => {
|
||||||
|
return {
|
||||||
|
name: preset.display_name,
|
||||||
|
id: preset.preset_id,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
presets.pushObject({
|
||||||
|
name: i18n("discourse_ai.embeddings.configure_manually"),
|
||||||
|
id: "manual",
|
||||||
|
});
|
||||||
|
|
||||||
|
return presets;
|
||||||
|
}
|
||||||
|
|
||||||
|
get showPresets() {
|
||||||
|
return !this.selectedPreset && this.args.model.isNew;
|
||||||
|
}
|
||||||
|
|
||||||
|
@computed("editingModel.provider")
|
||||||
|
get metaProviderParams() {
|
||||||
|
return (
|
||||||
|
this.args.embeddings.resultSetMeta.provider_params[
|
||||||
|
this.editingModel?.provider
|
||||||
|
] || {}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
get testErrorMessage() {
|
||||||
|
return i18n("discourse_ai.llms.tests.failure", { error: this.testError });
|
||||||
|
}
|
||||||
|
|
||||||
|
get displayTestResult() {
|
||||||
|
return this.testRunning || this.testResult !== null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@action
|
||||||
|
configurePreset() {
|
||||||
|
this.selectedPreset =
|
||||||
|
this.args.embeddings.resultSetMeta.presets.findBy(
|
||||||
|
"preset_id",
|
||||||
|
this.presetId
|
||||||
|
) || {};
|
||||||
|
|
||||||
|
this.editingModel = this.store
|
||||||
|
.createRecord("ai-embedding", this.selectedPreset)
|
||||||
|
.workingCopy();
|
||||||
|
}
|
||||||
|
|
||||||
|
@action
|
||||||
|
updateModel() {
|
||||||
|
this.editingModel = this.args.model.workingCopy();
|
||||||
|
}
|
||||||
|
|
||||||
|
@action
|
||||||
|
makeApiKeySecret() {
|
||||||
|
this.apiKeySecret = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@action
|
||||||
|
toggleApiKeySecret() {
|
||||||
|
this.apiKeySecret = !this.apiKeySecret;
|
||||||
|
}
|
||||||
|
|
||||||
|
@action
|
||||||
|
async save() {
|
||||||
|
this.isSaving = true;
|
||||||
|
const isNew = this.args.model.isNew;
|
||||||
|
|
||||||
|
try {
|
||||||
|
await this.editingModel.save();
|
||||||
|
|
||||||
|
if (isNew) {
|
||||||
|
this.args.embeddings.addObject(this.editingModel);
|
||||||
|
this.router.transitionTo(
|
||||||
|
"adminPlugins.show.discourse-ai-embeddings.index"
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
this.toasts.success({
|
||||||
|
data: { message: i18n("discourse_ai.embeddings.saved") },
|
||||||
|
duration: 2000,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
popupAjaxError(e);
|
||||||
|
} finally {
|
||||||
|
later(() => {
|
||||||
|
this.isSaving = false;
|
||||||
|
}, 1000);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@action
|
||||||
|
async test() {
|
||||||
|
this.testRunning = true;
|
||||||
|
|
||||||
|
try {
|
||||||
|
const configTestResult = await this.editingModel.testConfig();
|
||||||
|
this.testResult = configTestResult.success;
|
||||||
|
|
||||||
|
if (this.testResult) {
|
||||||
|
this.testError = null;
|
||||||
|
} else {
|
||||||
|
this.testError = configTestResult.error;
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
popupAjaxError(e);
|
||||||
|
} finally {
|
||||||
|
later(() => {
|
||||||
|
this.testRunning = false;
|
||||||
|
}, 1000);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@action
|
||||||
|
delete() {
|
||||||
|
return this.dialog.confirm({
|
||||||
|
message: i18n("discourse_ai.embeddings.confirm_delete"),
|
||||||
|
didConfirm: () => {
|
||||||
|
return this.args.model
|
||||||
|
.destroyRecord()
|
||||||
|
.then(() => {
|
||||||
|
this.args.llms.removeObject(this.args.model);
|
||||||
|
this.router.transitionTo(
|
||||||
|
"adminPlugins.show.discourse-ai-embeddings.index"
|
||||||
|
);
|
||||||
|
})
|
||||||
|
.catch(popupAjaxError);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
<template>
|
||||||
|
<BackButton
|
||||||
|
@route="adminPlugins.show.discourse-ai-embeddings"
|
||||||
|
@label="discourse_ai.embeddings.back"
|
||||||
|
/>
|
||||||
|
|
||||||
|
<form
|
||||||
|
{{didInsert this.updateModel @model.id}}
|
||||||
|
{{didUpdate this.updateModel @model.id}}
|
||||||
|
class="form-horizontal ai-embedding-editor"
|
||||||
|
>
|
||||||
|
{{#if this.showPresets}}
|
||||||
|
<div class="control-group">
|
||||||
|
<label>{{i18n "discourse_ai.embeddings.presets"}}</label>
|
||||||
|
<ComboBox
|
||||||
|
@value={{this.presetId}}
|
||||||
|
@content={{this.presets}}
|
||||||
|
class="ai-embedding-editor__presets"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="control-group ai-llm-editor__action_panel">
|
||||||
|
<DButton
|
||||||
|
@action={{this.configurePreset}}
|
||||||
|
@label="discourse_ai.tools.next.title"
|
||||||
|
class="ai-embedding-editor__next"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
{{else}}
|
||||||
|
<div class="control-group">
|
||||||
|
<label>{{i18n "discourse_ai.embeddings.display_name"}}</label>
|
||||||
|
<Input
|
||||||
|
class="ai-embedding-editor-input ai-embedding-editor__display-name"
|
||||||
|
@type="text"
|
||||||
|
@value={{this.editingModel.display_name}}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="control-group">
|
||||||
|
<label>{{i18n "discourse_ai.embeddings.provider"}}</label>
|
||||||
|
<ComboBox
|
||||||
|
@value={{this.editingModel.provider}}
|
||||||
|
@content={{this.selectedProviders}}
|
||||||
|
@class="ai-embedding-editor__provider"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="control-group">
|
||||||
|
<label>{{i18n "discourse_ai.embeddings.url"}}</label>
|
||||||
|
<Input
|
||||||
|
class="ai-embedding-editor-input ai-embedding-editor__url"
|
||||||
|
@type="text"
|
||||||
|
@value={{this.editingModel.url}}
|
||||||
|
required="true"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="control-group">
|
||||||
|
<label>{{i18n "discourse_ai.embeddings.api_key"}}</label>
|
||||||
|
<div class="ai-embedding-editor__secret-api-key-group">
|
||||||
|
<Input
|
||||||
|
@value={{this.editingModel.api_key}}
|
||||||
|
class="ai-embedding-editor-input ai-embedding-editor__api-key"
|
||||||
|
@type={{if this.apiKeySecret "password" "text"}}
|
||||||
|
required="true"
|
||||||
|
{{on "focusout" this.makeApiKeySecret}}
|
||||||
|
/>
|
||||||
|
<DButton
|
||||||
|
@action={{this.toggleApiKeySecret}}
|
||||||
|
@icon="far-eye-slash"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="control-group">
|
||||||
|
<label>{{i18n "discourse_ai.embeddings.tokenizer"}}</label>
|
||||||
|
<ComboBox
|
||||||
|
@value={{this.editingModel.tokenizer_class}}
|
||||||
|
@content={{@embeddings.resultSetMeta.tokenizers}}
|
||||||
|
@class="ai-embedding-editor__tokenizer"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="control-group">
|
||||||
|
<label>{{i18n "discourse_ai.embeddings.dimensions"}}</label>
|
||||||
|
<Input
|
||||||
|
@type="number"
|
||||||
|
class="ai-embedding-editor-input ai-embedding-editor__dimensions"
|
||||||
|
step="any"
|
||||||
|
min="0"
|
||||||
|
lang="en"
|
||||||
|
@value={{this.editingModel.dimensions}}
|
||||||
|
required="true"
|
||||||
|
disabled={{not this.editingModel.isNew}}
|
||||||
|
/>
|
||||||
|
{{#if this.editingModel.isNew}}
|
||||||
|
<DTooltip
|
||||||
|
@icon="circle-exclamation"
|
||||||
|
@content={{i18n
|
||||||
|
"discourse_ai.embeddings.hints.dimensions_warning"
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
{{/if}}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="control-group">
|
||||||
|
<label>{{i18n "discourse_ai.embeddings.max_sequence_length"}}</label>
|
||||||
|
<Input
|
||||||
|
@type="number"
|
||||||
|
class="ai-embedding-editor-input ai-embedding-editor__max_sequence_length"
|
||||||
|
step="any"
|
||||||
|
min="0"
|
||||||
|
lang="en"
|
||||||
|
@value={{this.editingModel.max_sequence_length}}
|
||||||
|
required="true"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="control-group">
|
||||||
|
<label>{{i18n "discourse_ai.embeddings.distance_function"}}</label>
|
||||||
|
<ComboBox
|
||||||
|
@value={{this.editingModel.pg_function}}
|
||||||
|
@content={{this.distanceFunctions}}
|
||||||
|
@class="ai-embedding-editor__distance_functions"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{{#each-in this.metaProviderParams as |field type|}}
|
||||||
|
<div
|
||||||
|
class="control-group ai-embedding-editor-provider-param__{{type}}"
|
||||||
|
>
|
||||||
|
<label>
|
||||||
|
{{i18n (concat "discourse_ai.embeddings.provider_fields." field)}}
|
||||||
|
</label>
|
||||||
|
<Input
|
||||||
|
@type="text"
|
||||||
|
class="ai-embedding-editor-input ai-embedding-editor__{{field}}"
|
||||||
|
@value={{mut (get this.editingModel.provider_params field)}}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
{{/each-in}}
|
||||||
|
|
||||||
|
<div class="control-group ai-embedding-editor__action_panel">
|
||||||
|
<DButton
|
||||||
|
class="ai-embedding-editor__test"
|
||||||
|
@action={{this.test}}
|
||||||
|
@disabled={{this.testRunning}}
|
||||||
|
@label="discourse_ai.embeddings.tests.title"
|
||||||
|
/>
|
||||||
|
|
||||||
|
<DButton
|
||||||
|
class="btn-primary ai-embedding-editor__save"
|
||||||
|
@action={{this.save}}
|
||||||
|
@disabled={{this.isSaving}}
|
||||||
|
@label="discourse_ai.embeddings.save"
|
||||||
|
/>
|
||||||
|
{{#unless this.editingModel.isNew}}
|
||||||
|
<DButton
|
||||||
|
@action={{this.delete}}
|
||||||
|
class="btn-danger ai-embedding-editor__delete"
|
||||||
|
@label="discourse_ai.embeddings.delete"
|
||||||
|
/>
|
||||||
|
{{/unless}}
|
||||||
|
|
||||||
|
<div class="control-group ai-embedding-editor-tests">
|
||||||
|
{{#if this.displayTestResult}}
|
||||||
|
{{#if this.testRunning}}
|
||||||
|
<div class="spinner small"></div>
|
||||||
|
{{i18n "discourse_ai.embeddings.tests.running"}}
|
||||||
|
{{else}}
|
||||||
|
{{#if this.testResult}}
|
||||||
|
<div class="ai-embedding-editor-tests__success">
|
||||||
|
{{icon "check"}}
|
||||||
|
{{i18n "discourse_ai.embeddings.tests.success"}}
|
||||||
|
</div>
|
||||||
|
{{else}}
|
||||||
|
<div class="ai-embedding-editor-tests__failure">
|
||||||
|
{{icon "xmark"}}
|
||||||
|
{{this.testErrorMessage}}
|
||||||
|
</div>
|
||||||
|
{{/if}}
|
||||||
|
{{/if}}
|
||||||
|
{{/if}}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{{/if}}
|
||||||
|
</form>
|
||||||
|
</template>
|
||||||
|
}
|
@ -0,0 +1,114 @@
|
|||||||
|
import Component from "@glimmer/component";
|
||||||
|
import { concat } from "@ember/helper";
|
||||||
|
import { service } from "@ember/service";
|
||||||
|
import DBreadcrumbsItem from "discourse/components/d-breadcrumbs-item";
|
||||||
|
import DButton from "discourse/components/d-button";
|
||||||
|
import DPageSubheader from "discourse/components/d-page-subheader";
|
||||||
|
import { i18n } from "discourse-i18n";
|
||||||
|
import AdminConfigAreaEmptyList from "admin/components/admin-config-area-empty-list";
|
||||||
|
import DTooltip from "float-kit/components/d-tooltip";
|
||||||
|
import AiEmbeddingEditor from "./ai-embedding-editor";
|
||||||
|
|
||||||
|
export default class AiEmbeddingsListEditor extends Component {
|
||||||
|
@service adminPluginNavManager;
|
||||||
|
|
||||||
|
get hasEmbeddingElements() {
|
||||||
|
return this.args.embeddings.length !== 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
<template>
|
||||||
|
<DBreadcrumbsItem
|
||||||
|
@path="/admin/plugins/{{this.adminPluginNavManager.currentPlugin.name}}/ai-embeddings"
|
||||||
|
@label={{i18n "discourse_ai.embeddings.short_title"}}
|
||||||
|
/>
|
||||||
|
<section class="ai-embeddings-list-editor admin-detail">
|
||||||
|
{{#if @currentEmbedding}}
|
||||||
|
<AiEmbeddingEditor
|
||||||
|
@model={{@currentEmbedding}}
|
||||||
|
@embeddings={{@embeddings}}
|
||||||
|
/>
|
||||||
|
{{else}}
|
||||||
|
<DPageSubheader
|
||||||
|
@titleLabel={{i18n "discourse_ai.embeddings.short_title"}}
|
||||||
|
@descriptionLabel={{i18n "discourse_ai.embeddings.description"}}
|
||||||
|
@learnMoreUrl="https://meta.discourse.org/t/discourse-ai-embeddings/259603"
|
||||||
|
>
|
||||||
|
<:actions as |actions|>
|
||||||
|
<actions.Primary
|
||||||
|
@label="discourse_ai.embeddings.new"
|
||||||
|
@route="adminPlugins.show.discourse-ai-embeddings.new"
|
||||||
|
@icon="plus"
|
||||||
|
class="ai-embeddings-list-editor__new-button"
|
||||||
|
/>
|
||||||
|
</:actions>
|
||||||
|
</DPageSubheader>
|
||||||
|
|
||||||
|
{{#if this.hasEmbeddingElements}}
|
||||||
|
<table class="d-admin-table">
|
||||||
|
<thead>
|
||||||
|
<tr>
|
||||||
|
<th>{{i18n "discourse_ai.embeddings.display_name"}}</th>
|
||||||
|
<th>{{i18n "discourse_ai.embeddings.provider"}}</th>
|
||||||
|
<th></th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{{#each @embeddings as |embedding|}}
|
||||||
|
<tr class="ai-embeddings-list__row d-admin-row__content">
|
||||||
|
<td class="d-admin-row__overview">
|
||||||
|
<div class="ai-embeddings-list__name">
|
||||||
|
<strong>
|
||||||
|
{{embedding.display_name}}
|
||||||
|
</strong>
|
||||||
|
</div>
|
||||||
|
</td>
|
||||||
|
<td class="d-admin-row__detail">
|
||||||
|
<div class="d-admin-row__mobile-label">
|
||||||
|
{{i18n "discourse_ai.embeddings.provider"}}
|
||||||
|
</div>
|
||||||
|
{{i18n
|
||||||
|
(concat
|
||||||
|
"discourse_ai.embeddings.providers." embedding.provider
|
||||||
|
)
|
||||||
|
}}
|
||||||
|
</td>
|
||||||
|
<td class="d-admin-row__controls">
|
||||||
|
{{#if embedding.seeded}}
|
||||||
|
<DTooltip
|
||||||
|
class="ai-embeddings-list__edit-disabled-tooltip"
|
||||||
|
>
|
||||||
|
<:trigger>
|
||||||
|
<DButton
|
||||||
|
class="btn btn-default btn-small disabled"
|
||||||
|
@label="discourse_ai.embeddings.edit"
|
||||||
|
/>
|
||||||
|
</:trigger>
|
||||||
|
<:content>
|
||||||
|
{{i18n "discourse_ai.embeddings.seeded_warning"}}
|
||||||
|
</:content>
|
||||||
|
</DTooltip>
|
||||||
|
{{else}}
|
||||||
|
<DButton
|
||||||
|
class="btn btn-default btn-small ai-embeddings-list__edit-button"
|
||||||
|
@label="discourse_ai.embeddings.edit"
|
||||||
|
@route="adminPlugins.show.discourse-ai-embeddings.edit"
|
||||||
|
@routeModels={{embedding.id}}
|
||||||
|
/>
|
||||||
|
{{/if}}
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
{{/each}}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
{{else}}
|
||||||
|
<AdminConfigAreaEmptyList
|
||||||
|
@ctaLabel="discourse_ai.embeddings.new"
|
||||||
|
@ctaRoute="adminPlugins.show.discourse-ai-embeddings.new"
|
||||||
|
@ctaClass="ai-embeddings-list-editor__empty-new-button"
|
||||||
|
@emptyLabel="discourse_ai.embeddings.empty"
|
||||||
|
/>
|
||||||
|
{{/if}}
|
||||||
|
{{/if}}
|
||||||
|
</section>
|
||||||
|
</template>
|
||||||
|
}
|
@ -12,6 +12,10 @@ export default {
|
|||||||
|
|
||||||
withPluginApi("1.1.0", (api) => {
|
withPluginApi("1.1.0", (api) => {
|
||||||
api.addAdminPluginConfigurationNav("discourse-ai", PLUGIN_NAV_MODE_TOP, [
|
api.addAdminPluginConfigurationNav("discourse-ai", PLUGIN_NAV_MODE_TOP, [
|
||||||
|
{
|
||||||
|
label: "discourse_ai.embeddings.short_title",
|
||||||
|
route: "adminPlugins.show.discourse-ai-embeddings",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
label: "discourse_ai.llms.short_title",
|
label: "discourse_ai.llms.short_title",
|
||||||
route: "adminPlugins.show.discourse-ai-llms",
|
route: "adminPlugins.show.discourse-ai-llms",
|
||||||
|
@ -0,0 +1,26 @@
|
|||||||
|
.ai-embedding-editor {
|
||||||
|
padding-left: 0.5em;
|
||||||
|
|
||||||
|
.ai-embedding-editor-input {
|
||||||
|
width: 350px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.ai-embedding-editor-tests {
|
||||||
|
&__failure {
|
||||||
|
color: var(--danger);
|
||||||
|
}
|
||||||
|
|
||||||
|
&__success {
|
||||||
|
color: var(--success);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
&__api-key {
|
||||||
|
margin-right: 0.5em;
|
||||||
|
}
|
||||||
|
|
||||||
|
&__secret-api-key-group {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
}
|
||||||
|
}
|
@ -502,6 +502,49 @@ en:
|
|||||||
accuracy: "Accuracy:"
|
accuracy: "Accuracy:"
|
||||||
|
|
||||||
embeddings:
|
embeddings:
|
||||||
|
short_title: "Embeddings"
|
||||||
|
description: "Embeddings are a crucial component of the Discourse AI plugin, enabling features like related topics and semantic search."
|
||||||
|
new: "New embedding"
|
||||||
|
back: "Back"
|
||||||
|
save: "Save"
|
||||||
|
saved: "Embedding configuration saved"
|
||||||
|
delete: "Delete"
|
||||||
|
confirm_delete: Are you sure you want to remove this embedding configuration?
|
||||||
|
empty: "You haven't setup embeddings yet"
|
||||||
|
presets: "Select a preset..."
|
||||||
|
configure_manually: "Configure manually"
|
||||||
|
edit: "Edit"
|
||||||
|
seeded_warning: "This is pre-configured on your site and cannot be edited."
|
||||||
|
tests:
|
||||||
|
title: "Run test"
|
||||||
|
running: "Running test..."
|
||||||
|
success: "Success!"
|
||||||
|
failure: "Attempting to generate an embedding resulted in: %{error}"
|
||||||
|
hints:
|
||||||
|
dimensions_warning: "Once saved, this value can't be changed."
|
||||||
|
|
||||||
|
display_name: "Name"
|
||||||
|
provider: "Provider"
|
||||||
|
url: "Embeddings service URL"
|
||||||
|
api_key: "Embeddings service API Key"
|
||||||
|
tokenizer: "Tokenizer"
|
||||||
|
dimensions: "Embedding dimensions"
|
||||||
|
max_sequence_length: "Sequence length"
|
||||||
|
|
||||||
|
distance_function: "Distance function"
|
||||||
|
distance_functions:
|
||||||
|
<#>: "Negative inner product (<#>)"
|
||||||
|
<=>: "Cosine distance (<=>)"
|
||||||
|
providers:
|
||||||
|
hugging_face: "Hugging Face"
|
||||||
|
open_ai: "OpenAI"
|
||||||
|
google: "Google"
|
||||||
|
cloudflare: "Cloudflare"
|
||||||
|
CDCK: "CDCK"
|
||||||
|
provider_fields:
|
||||||
|
model_name: "Model name"
|
||||||
|
|
||||||
|
|
||||||
semantic_search: "Topics (Semantic)"
|
semantic_search: "Topics (Semantic)"
|
||||||
semantic_search_loading: "Searching for more results using AI"
|
semantic_search_loading: "Searching for more results using AI"
|
||||||
semantic_search_results:
|
semantic_search_results:
|
||||||
|
@ -49,10 +49,7 @@ en:
|
|||||||
ai_nsfw_flag_threshold_sexy: "Threshold for an image classified as sexy to be considered NSFW."
|
ai_nsfw_flag_threshold_sexy: "Threshold for an image classified as sexy to be considered NSFW."
|
||||||
ai_nsfw_models: "Models to use for NSFW inference."
|
ai_nsfw_models: "Models to use for NSFW inference."
|
||||||
|
|
||||||
ai_openai_embeddings_url: "Custom URL used for the OpenAI embeddings API. (in the case of Azure it can be: https://COMPANY.openai.azure.com/openai/deployments/DEPLOYMENT/embeddings?api-version=2023-05-15)"
|
ai_openai_api_key: "API key for OpenAI API. ONLY used for Dall-E. For GPT use the LLM config tab"
|
||||||
ai_openai_api_key: "API key for OpenAI API. ONLY used for embeddings and Dall-E. For GPT use the LLM config tab"
|
|
||||||
ai_hugging_face_tei_endpoint: URL where the API is running for the Hugging Face text embeddings inference
|
|
||||||
ai_hugging_face_tei_api_key: API key for Hugging Face text embeddings inference
|
|
||||||
|
|
||||||
ai_helper_enabled: "Enable the AI helper."
|
ai_helper_enabled: "Enable the AI helper."
|
||||||
composer_ai_helper_allowed_groups: "Users on these groups will see the AI helper button in the composer."
|
composer_ai_helper_allowed_groups: "Users on these groups will see the AI helper button in the composer."
|
||||||
@ -67,15 +64,11 @@ en:
|
|||||||
ai_helper_image_caption_model: "Select the model to use for generating image captions"
|
ai_helper_image_caption_model: "Select the model to use for generating image captions"
|
||||||
ai_auto_image_caption_allowed_groups: "Users on these groups can toggle automatic image captioning."
|
ai_auto_image_caption_allowed_groups: "Users on these groups can toggle automatic image captioning."
|
||||||
|
|
||||||
ai_embeddings_enabled: "Enable the embeddings module."
|
ai_embeddings_selected_model: "Use the selected model for generating embeddings."
|
||||||
ai_embeddings_discourse_service_api_endpoint: "URL where the API is running for the embeddings module"
|
|
||||||
ai_embeddings_discourse_service_api_key: "API key for the embeddings API"
|
|
||||||
ai_embeddings_model: "Use all-mpnet-base-v2 for local and fast inference in english, text-embedding-ada-002 to use OpenAI API (need API key) and multilingual-e5-large for local multilingual embeddings"
|
|
||||||
ai_embeddings_generate_for_pms: "Generate embeddings for personal messages."
|
ai_embeddings_generate_for_pms: "Generate embeddings for personal messages."
|
||||||
ai_embeddings_semantic_related_topics_enabled: "Use Semantic Search for related topics."
|
ai_embeddings_semantic_related_topics_enabled: "Use Semantic Search for related topics."
|
||||||
ai_embeddings_semantic_related_topics: "Maximum number of topics to show in related topic section."
|
ai_embeddings_semantic_related_topics: "Maximum number of topics to show in related topic section."
|
||||||
ai_embeddings_backfill_batch_size: "Number of embeddings to backfill every 15 minutes."
|
ai_embeddings_backfill_batch_size: "Number of embeddings to backfill every 15 minutes."
|
||||||
ai_embeddings_pg_connection_string: "PostgreSQL connection string for the embeddings module. Needs pgvector extension enabled and a series of tables created. See docs for more info."
|
|
||||||
ai_embeddings_semantic_search_enabled: "Enable full-page semantic search."
|
ai_embeddings_semantic_search_enabled: "Enable full-page semantic search."
|
||||||
ai_embeddings_semantic_quick_search_enabled: "Enable semantic search option in search menu popup."
|
ai_embeddings_semantic_quick_search_enabled: "Enable semantic search option in search menu popup."
|
||||||
ai_embeddings_semantic_related_include_closed_topics: "Include closed topics in semantic search results"
|
ai_embeddings_semantic_related_include_closed_topics: "Include closed topics in semantic search results"
|
||||||
@ -437,13 +430,11 @@ en:
|
|||||||
cannot_edit_builtin: "You can't edit a built-in model."
|
cannot_edit_builtin: "You can't edit a built-in model."
|
||||||
|
|
||||||
embeddings:
|
embeddings:
|
||||||
|
delete_failed: "This model is currently in use. Update the `ai embeddings selected model` first."
|
||||||
|
cannot_edit_builtin: "You can't edit a built-in model."
|
||||||
configuration:
|
configuration:
|
||||||
disable_embeddings: "You have to disable 'ai embeddings enabled' first."
|
disable_embeddings: "You have to disable 'ai embeddings enabled' first."
|
||||||
choose_model: "Set 'ai embeddings model' first."
|
choose_model: "Set 'ai embeddings selected model' first."
|
||||||
model_unreachable: "We failed to generate a test embedding with this model. Check your settings are correct."
|
|
||||||
hint:
|
|
||||||
one: "Make sure the `%{settings}` setting was configured."
|
|
||||||
other: "Make sure the settings of the provider you want were configured. Options are: %{settings}"
|
|
||||||
|
|
||||||
llm_models:
|
llm_models:
|
||||||
missing_provider_param: "%{param} can't be blank"
|
missing_provider_param: "%{param} can't be blank"
|
||||||
|
@ -96,6 +96,13 @@ Discourse::Application.routes.draw do
|
|||||||
controller: "discourse_ai/admin/ai_llm_quotas",
|
controller: "discourse_ai/admin/ai_llm_quotas",
|
||||||
path: "quotas",
|
path: "quotas",
|
||||||
only: %i[index create update destroy]
|
only: %i[index create update destroy]
|
||||||
|
|
||||||
|
resources :ai_embeddings,
|
||||||
|
only: %i[index new create edit update destroy],
|
||||||
|
path: "ai-embeddings",
|
||||||
|
controller: "discourse_ai/admin/ai_embeddings" do
|
||||||
|
collection { get :test }
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -28,11 +28,14 @@ discourse_ai:
|
|||||||
|
|
||||||
|
|
||||||
ai_openai_dall_e_3_url: "https://api.openai.com/v1/images/generations"
|
ai_openai_dall_e_3_url: "https://api.openai.com/v1/images/generations"
|
||||||
ai_openai_embeddings_url: "https://api.openai.com/v1/embeddings"
|
ai_openai_embeddings_url:
|
||||||
|
hidden: true
|
||||||
|
default: "https://api.openai.com/v1/embeddings"
|
||||||
ai_openai_organization:
|
ai_openai_organization:
|
||||||
default: ""
|
default: ""
|
||||||
hidden: true
|
hidden: true
|
||||||
ai_openai_api_key:
|
ai_openai_api_key:
|
||||||
|
hidden: true
|
||||||
default: ""
|
default: ""
|
||||||
secret: true
|
secret: true
|
||||||
ai_stability_api_key:
|
ai_stability_api_key:
|
||||||
@ -50,11 +53,14 @@ discourse_ai:
|
|||||||
- "stable-diffusion-768-v2-1"
|
- "stable-diffusion-768-v2-1"
|
||||||
- "stable-diffusion-v1-5"
|
- "stable-diffusion-v1-5"
|
||||||
ai_hugging_face_tei_endpoint:
|
ai_hugging_face_tei_endpoint:
|
||||||
|
hidden: true
|
||||||
default: ""
|
default: ""
|
||||||
ai_hugging_face_tei_endpoint_srv:
|
ai_hugging_face_tei_endpoint_srv:
|
||||||
default: ""
|
default: ""
|
||||||
hidden: true
|
hidden: true
|
||||||
ai_hugging_face_tei_api_key: ""
|
ai_hugging_face_tei_api_key:
|
||||||
|
default: ""
|
||||||
|
hidden: true
|
||||||
ai_hugging_face_tei_reranker_endpoint:
|
ai_hugging_face_tei_reranker_endpoint:
|
||||||
default: ""
|
default: ""
|
||||||
ai_hugging_face_tei_reranker_endpoint_srv:
|
ai_hugging_face_tei_reranker_endpoint_srv:
|
||||||
@ -69,12 +75,14 @@ discourse_ai:
|
|||||||
ai_cloudflare_workers_account_id:
|
ai_cloudflare_workers_account_id:
|
||||||
default: ""
|
default: ""
|
||||||
secret: true
|
secret: true
|
||||||
|
hidden: true
|
||||||
ai_cloudflare_workers_api_token:
|
ai_cloudflare_workers_api_token:
|
||||||
default: ""
|
default: ""
|
||||||
secret: true
|
secret: true
|
||||||
|
hidden: true
|
||||||
ai_gemini_api_key:
|
ai_gemini_api_key:
|
||||||
default: ""
|
default: ""
|
||||||
hidden: false
|
hidden: true
|
||||||
ai_strict_token_counting:
|
ai_strict_token_counting:
|
||||||
default: false
|
default: false
|
||||||
hidden: true
|
hidden: true
|
||||||
@ -158,27 +166,12 @@ discourse_ai:
|
|||||||
default: false
|
default: false
|
||||||
client: true
|
client: true
|
||||||
validator: "DiscourseAi::Configuration::EmbeddingsModuleValidator"
|
validator: "DiscourseAi::Configuration::EmbeddingsModuleValidator"
|
||||||
ai_embeddings_discourse_service_api_endpoint: ""
|
ai_embeddings_selected_model:
|
||||||
ai_embeddings_discourse_service_api_endpoint_srv:
|
|
||||||
default: ""
|
|
||||||
hidden: true
|
|
||||||
ai_embeddings_discourse_service_api_key:
|
|
||||||
default: ""
|
|
||||||
secret: true
|
|
||||||
ai_embeddings_model:
|
|
||||||
type: enum
|
type: enum
|
||||||
default: "bge-large-en"
|
default: ""
|
||||||
allow_any: false
|
allow_any: false
|
||||||
choices:
|
enum: "DiscourseAi::Configuration::EmbeddingDefsEnumerator"
|
||||||
- all-mpnet-base-v2
|
validator: "DiscourseAi::Configuration::EmbeddingDefsValidator"
|
||||||
- text-embedding-ada-002
|
|
||||||
- text-embedding-3-small
|
|
||||||
- text-embedding-3-large
|
|
||||||
- multilingual-e5-large
|
|
||||||
- bge-large-en
|
|
||||||
- gemini
|
|
||||||
- bge-m3
|
|
||||||
validator: "DiscourseAi::Configuration::EmbeddingsModelValidator"
|
|
||||||
ai_embeddings_per_post_enabled:
|
ai_embeddings_per_post_enabled:
|
||||||
default: false
|
default: false
|
||||||
hidden: true
|
hidden: true
|
||||||
@ -191,9 +184,6 @@ discourse_ai:
|
|||||||
ai_embeddings_backfill_batch_size:
|
ai_embeddings_backfill_batch_size:
|
||||||
default: 250
|
default: 250
|
||||||
hidden: true
|
hidden: true
|
||||||
ai_embeddings_pg_connection_string:
|
|
||||||
default: ""
|
|
||||||
hidden: true
|
|
||||||
ai_embeddings_semantic_search_enabled:
|
ai_embeddings_semantic_search_enabled:
|
||||||
default: false
|
default: false
|
||||||
client: true
|
client: true
|
||||||
@ -213,6 +203,35 @@ discourse_ai:
|
|||||||
default: false
|
default: false
|
||||||
client: true
|
client: true
|
||||||
hidden: true
|
hidden: true
|
||||||
|
|
||||||
|
ai_embeddings_discourse_service_api_endpoint:
|
||||||
|
default: ""
|
||||||
|
hidden: true
|
||||||
|
ai_embeddings_discourse_service_api_endpoint_srv:
|
||||||
|
default: ""
|
||||||
|
hidden: true
|
||||||
|
ai_embeddings_discourse_service_api_key:
|
||||||
|
hidden: true
|
||||||
|
default: ""
|
||||||
|
secret: true
|
||||||
|
ai_embeddings_model:
|
||||||
|
hidden: true
|
||||||
|
type: enum
|
||||||
|
default: "bge-large-en"
|
||||||
|
allow_any: false
|
||||||
|
choices:
|
||||||
|
- all-mpnet-base-v2
|
||||||
|
- text-embedding-ada-002
|
||||||
|
- text-embedding-3-small
|
||||||
|
- text-embedding-3-large
|
||||||
|
- multilingual-e5-large
|
||||||
|
- bge-large-en
|
||||||
|
- gemini
|
||||||
|
- bge-m3
|
||||||
|
ai_embeddings_pg_connection_string:
|
||||||
|
default: ""
|
||||||
|
hidden: true
|
||||||
|
|
||||||
ai_summarization_enabled:
|
ai_summarization_enabled:
|
||||||
default: false
|
default: false
|
||||||
client: true
|
client: true
|
||||||
|
19
db/migrate/20241217164540_create_embedding_definitions.rb
Normal file
19
db/migrate/20241217164540_create_embedding_definitions.rb
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
# frozen_string_literal: true
|
||||||
|
class CreateEmbeddingDefinitions < ActiveRecord::Migration[7.2]
|
||||||
|
def change
|
||||||
|
create_table :embedding_definitions do |t|
|
||||||
|
t.string :display_name, null: false
|
||||||
|
t.integer :dimensions, null: false
|
||||||
|
t.integer :max_sequence_length, null: false
|
||||||
|
t.integer :version, null: false, default: 1
|
||||||
|
t.string :pg_function, null: false
|
||||||
|
t.string :provider, null: false
|
||||||
|
t.string :tokenizer_class, null: false
|
||||||
|
t.string :url, null: false
|
||||||
|
t.string :api_key
|
||||||
|
t.boolean :seeded, null: false, default: false
|
||||||
|
t.jsonb :provider_params
|
||||||
|
t.timestamps
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
204
db/migrate/20250110114305_embedding_config_data_migration.rb
Normal file
204
db/migrate/20250110114305_embedding_config_data_migration.rb
Normal file
@ -0,0 +1,204 @@
|
|||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
class EmbeddingConfigDataMigration < ActiveRecord::Migration[7.0]
|
||||||
|
def up
|
||||||
|
current_model = fetch_setting("ai_embeddings_model") || "bge-large-en"
|
||||||
|
provider = provider_for(current_model)
|
||||||
|
|
||||||
|
if provider.present?
|
||||||
|
attrs = creds_for(provider)
|
||||||
|
|
||||||
|
if attrs.present?
|
||||||
|
attrs = attrs.merge(model_attrs(current_model))
|
||||||
|
attrs[:display_name] = current_model
|
||||||
|
attrs[:provider] = provider
|
||||||
|
persist_config(attrs)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def down
|
||||||
|
end
|
||||||
|
|
||||||
|
# Utils
|
||||||
|
|
||||||
|
def fetch_setting(name)
|
||||||
|
DB.query_single(
|
||||||
|
"SELECT value FROM site_settings WHERE name = :setting_name",
|
||||||
|
setting_name: name,
|
||||||
|
).first || ENV["DISCOURSE_#{name&.upcase}"]
|
||||||
|
end
|
||||||
|
|
||||||
|
def provider_for(model)
|
||||||
|
cloudflare_api_token = fetch_setting("ai_cloudflare_workers_api_token")
|
||||||
|
|
||||||
|
return "cloudflare" if model == "bge-large-en" && cloudflare_api_token.present?
|
||||||
|
|
||||||
|
tei_models = %w[bge-large-en bge-m3 multilingual-e5-large]
|
||||||
|
return "hugging_face" if tei_models.include?(model)
|
||||||
|
|
||||||
|
return "google" if model == "gemini"
|
||||||
|
|
||||||
|
if %w[text-embedding-3-large text-embedding-3-small text-embedding-ada-002].include?(model)
|
||||||
|
return "open_ai"
|
||||||
|
end
|
||||||
|
|
||||||
|
nil
|
||||||
|
end
|
||||||
|
|
||||||
|
def creds_for(provider)
|
||||||
|
# CF
|
||||||
|
if provider == "cloudflare"
|
||||||
|
api_key = fetch_setting("ai_cloudflare_workers_api_token")
|
||||||
|
account_id = fetch_setting("ai_cloudflare_workers_account_id")
|
||||||
|
|
||||||
|
return if api_key.blank? || account_id.blank?
|
||||||
|
|
||||||
|
{
|
||||||
|
url:
|
||||||
|
"https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/baai/bge-large-en-v1.5",
|
||||||
|
api_key: api_key,
|
||||||
|
}
|
||||||
|
# TEI
|
||||||
|
elsif provider == "hugging_face"
|
||||||
|
seeded = false
|
||||||
|
endpoint = fetch_setting("ai_hugging_face_tei_endpoint")
|
||||||
|
|
||||||
|
if endpoint.blank?
|
||||||
|
endpoint = fetch_setting("ai_hugging_face_tei_endpoint_srv")
|
||||||
|
if endpoint.present?
|
||||||
|
endpoint = "srv://#{endpoint}"
|
||||||
|
seeded = true
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
api_key = fetch_setting("ai_hugging_face_tei_api_key")
|
||||||
|
|
||||||
|
return if endpoint.blank? || api_key.blank?
|
||||||
|
|
||||||
|
{ url: endpoint, api_key: api_key, seeded: seeded }
|
||||||
|
# Gemini
|
||||||
|
elsif provider == "google"
|
||||||
|
api_key = fetch_setting("ai_gemini_api_key")
|
||||||
|
|
||||||
|
return if api_key.blank?
|
||||||
|
|
||||||
|
{
|
||||||
|
url: "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent",
|
||||||
|
api_key: api_key,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Open AI
|
||||||
|
elsif provider == "open_ai"
|
||||||
|
endpoint = fetch_setting("ai_openai_embeddings_url")
|
||||||
|
api_key = fetch_setting("ai_openai_api_key")
|
||||||
|
|
||||||
|
return if endpoint.blank? || api_key.blank?
|
||||||
|
|
||||||
|
{ url: endpoint, api_key: api_key }
|
||||||
|
else
|
||||||
|
nil
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def model_attrs(model_name)
|
||||||
|
if model_name == "bge-large-en"
|
||||||
|
{
|
||||||
|
dimensions: 1024,
|
||||||
|
max_sequence_length: 512,
|
||||||
|
id: 4,
|
||||||
|
pg_function: "<#>",
|
||||||
|
tokenizer_class: "DiscourseAi::Tokenizer::BgeLargeEnTokenizer",
|
||||||
|
}
|
||||||
|
elsif model_name == "bge-m3"
|
||||||
|
{
|
||||||
|
dimensions: 1024,
|
||||||
|
max_sequence_length: 8192,
|
||||||
|
id: 8,
|
||||||
|
pg_function: "<#>",
|
||||||
|
tokenizer_class: "DiscourseAi::Tokenizer::BgeM3Tokenizer",
|
||||||
|
}
|
||||||
|
elsif model_name == "gemini"
|
||||||
|
{
|
||||||
|
dimensions: 768,
|
||||||
|
max_sequence_length: 1536,
|
||||||
|
id: 5,
|
||||||
|
pg_function: "<=>",
|
||||||
|
tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
|
||||||
|
}
|
||||||
|
elsif model_name == "multilingual-e5-large"
|
||||||
|
{
|
||||||
|
dimensions: 1024,
|
||||||
|
max_sequence_length: 512,
|
||||||
|
id: 3,
|
||||||
|
pg_function: "<=>",
|
||||||
|
tokenizer_class: "DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer",
|
||||||
|
}
|
||||||
|
elsif model_name == "text-embedding-3-large"
|
||||||
|
{
|
||||||
|
dimensions: 2000,
|
||||||
|
max_sequence_length: 8191,
|
||||||
|
id: 7,
|
||||||
|
pg_function: "<=>",
|
||||||
|
tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
|
||||||
|
provider_params: {
|
||||||
|
model_name: "text-embedding-3-large",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
elsif model_name == "text-embedding-3-small"
|
||||||
|
{
|
||||||
|
dimensions: 1536,
|
||||||
|
max_sequence_length: 8191,
|
||||||
|
id: 6,
|
||||||
|
pg_function: "<=>",
|
||||||
|
tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
|
||||||
|
provider_params: {
|
||||||
|
model_name: "text-embedding-3-small",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
dimensions: 1536,
|
||||||
|
max_sequence_length: 8191,
|
||||||
|
id: 2,
|
||||||
|
pg_function: "<=>",
|
||||||
|
tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
|
||||||
|
provider_params: {
|
||||||
|
model_name: "text-embedding-ada-002",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def persist_config(attrs)
|
||||||
|
DB.exec(
|
||||||
|
<<~SQL,
|
||||||
|
INSERT INTO embedding_definitions (id, display_name, dimensions, max_sequence_length, version, pg_function, provider, tokenizer_class, url, api_key, provider_params, seeded, created_at, updated_at)
|
||||||
|
VALUES (:id, :display_name, :dimensions, :max_sequence_length, 1, :pg_function, :provider, :tokenizer_class, :url, :api_key, :provider_params, :seeded, :now, :now)
|
||||||
|
SQL
|
||||||
|
id: attrs[:id],
|
||||||
|
display_name: attrs[:display_name],
|
||||||
|
dimensions: attrs[:dimensions],
|
||||||
|
max_sequence_length: attrs[:max_sequence_length],
|
||||||
|
pg_function: attrs[:pg_function],
|
||||||
|
provider: attrs[:provider],
|
||||||
|
tokenizer_class: attrs[:tokenizer_class],
|
||||||
|
url: attrs[:url],
|
||||||
|
api_key: attrs[:api_key],
|
||||||
|
provider_params: attrs[:provider_params],
|
||||||
|
seeded: !!attrs[:seeded],
|
||||||
|
now: Time.zone.now,
|
||||||
|
)
|
||||||
|
|
||||||
|
# We hardcoded the ID to match with already generated embeddings. Let's restart the seq to avoid conflicts.
|
||||||
|
DB.exec(
|
||||||
|
"ALTER SEQUENCE embedding_definitions_id_seq RESTART WITH :new_seq",
|
||||||
|
new_seq: attrs[:id].to_i + 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
DB.exec(<<~SQL, new_value: attrs[:id])
|
||||||
|
INSERT INTO site_settings(name, data_type, value, created_at, updated_at)
|
||||||
|
VALUES ('ai_embeddings_selected_model', 3, :new_value, NOW(), NOW())
|
||||||
|
SQL
|
||||||
|
end
|
||||||
|
end
|
@ -196,7 +196,7 @@ module DiscourseAi
|
|||||||
)
|
)
|
||||||
|
|
||||||
plugin.on(:site_setting_changed) do |name, old_value, new_value|
|
plugin.on(:site_setting_changed) do |name, old_value, new_value|
|
||||||
if name == :ai_embeddings_model && SiteSetting.ai_embeddings_enabled? &&
|
if name == :ai_embeddings_selected_model && SiteSetting.ai_embeddings_enabled? &&
|
||||||
new_value != old_value
|
new_value != old_value
|
||||||
RagDocumentFragment.delete_all
|
RagDocumentFragment.delete_all
|
||||||
UploadReference
|
UploadReference
|
||||||
|
@ -327,7 +327,7 @@ module DiscourseAi
|
|||||||
rag_conversation_chunks
|
rag_conversation_chunks
|
||||||
end
|
end
|
||||||
|
|
||||||
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector_def: vector.vdef)
|
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment)
|
||||||
|
|
||||||
candidate_fragment_ids =
|
candidate_fragment_ids =
|
||||||
schema
|
schema
|
||||||
|
@ -93,7 +93,7 @@ module DiscourseAi
|
|||||||
|
|
||||||
def nearest_neighbors(limit: 100)
|
def nearest_neighbors(limit: 100)
|
||||||
vector = DiscourseAi::Embeddings::Vector.instance
|
vector = DiscourseAi::Embeddings::Vector.instance
|
||||||
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef)
|
schema = DiscourseAi::Embeddings::Schema.for(Topic)
|
||||||
|
|
||||||
raw_vector = vector.vector_from(@text)
|
raw_vector = vector.vector_from(@text)
|
||||||
|
|
||||||
|
20
lib/configuration/embedding_defs_enumerator.rb
Normal file
20
lib/configuration/embedding_defs_enumerator.rb
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
require "enum_site_setting"
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Configuration
|
||||||
|
class EmbeddingDefsEnumerator < ::EnumSiteSetting
|
||||||
|
def self.valid_value?(val)
|
||||||
|
true
|
||||||
|
end
|
||||||
|
|
||||||
|
def self.values
|
||||||
|
DB.query_hash(<<~SQL).map(&:symbolize_keys)
|
||||||
|
SELECT display_name AS name, id AS value
|
||||||
|
FROM embedding_definitions
|
||||||
|
SQL
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
19
lib/configuration/embedding_defs_validator.rb
Normal file
19
lib/configuration/embedding_defs_validator.rb
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Configuration
|
||||||
|
class EmbeddingDefsValidator
|
||||||
|
def initialize(opts = {})
|
||||||
|
@opts = opts
|
||||||
|
end
|
||||||
|
|
||||||
|
def valid_value?(val)
|
||||||
|
val.blank? || EmbeddingDefinition.exists?(id: val)
|
||||||
|
end
|
||||||
|
|
||||||
|
def error_message
|
||||||
|
""
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
@ -11,41 +11,11 @@ module DiscourseAi
|
|||||||
return true if val == "f"
|
return true if val == "f"
|
||||||
return true if Rails.env.test?
|
return true if Rails.env.test?
|
||||||
|
|
||||||
chosen_model = SiteSetting.ai_embeddings_model
|
SiteSetting.ai_embeddings_selected_model.present?
|
||||||
|
|
||||||
return false if !chosen_model
|
|
||||||
|
|
||||||
representation =
|
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::Base.find_representation(chosen_model)
|
|
||||||
|
|
||||||
return false if representation.nil?
|
|
||||||
|
|
||||||
if !representation.correctly_configured?
|
|
||||||
@representation = representation
|
|
||||||
return false
|
|
||||||
end
|
|
||||||
|
|
||||||
if !can_generate_embeddings?(chosen_model)
|
|
||||||
@unreachable = true
|
|
||||||
return false
|
|
||||||
end
|
|
||||||
|
|
||||||
true
|
|
||||||
end
|
end
|
||||||
|
|
||||||
def error_message
|
def error_message
|
||||||
return(I18n.t("discourse_ai.embeddings.configuration.model_unreachable")) if @unreachable
|
I18n.t("discourse_ai.embeddings.configuration.choose_model")
|
||||||
|
|
||||||
@representation&.configuration_hint
|
|
||||||
end
|
|
||||||
|
|
||||||
def can_generate_embeddings?(val)
|
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::Base
|
|
||||||
.find_representation(val)
|
|
||||||
.new
|
|
||||||
.inference_client
|
|
||||||
.perform!("this is a test")
|
|
||||||
.present?
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -12,19 +12,80 @@ module DiscourseAi
|
|||||||
POSTS_TABLE = "ai_posts_embeddings"
|
POSTS_TABLE = "ai_posts_embeddings"
|
||||||
RAG_DOCS_TABLE = "ai_document_fragments_embeddings"
|
RAG_DOCS_TABLE = "ai_document_fragments_embeddings"
|
||||||
|
|
||||||
def self.for(
|
EMBEDDING_TARGETS = %w[topics posts document_fragments]
|
||||||
target_klass,
|
EMBEDDING_TABLES = [TOPICS_TABLE, POSTS_TABLE, RAG_DOCS_TABLE]
|
||||||
vector_def: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
|
||||||
)
|
MissingEmbeddingError = Class.new(StandardError)
|
||||||
case target_klass&.name
|
|
||||||
when "Topic"
|
class << self
|
||||||
new(TOPICS_TABLE, "topic_id", vector_def)
|
def for(target_klass)
|
||||||
when "Post"
|
vector_def = EmbeddingDefinition.find_by(id: SiteSetting.ai_embeddings_selected_model)
|
||||||
new(POSTS_TABLE, "post_id", vector_def)
|
raise "Invalid embeddings selected model" if vector_def.nil?
|
||||||
when "RagDocumentFragment"
|
|
||||||
new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector_def)
|
case target_klass&.name
|
||||||
else
|
when "Topic"
|
||||||
raise ArgumentError, "Invalid target type for embeddings"
|
new(TOPICS_TABLE, "topic_id", vector_def)
|
||||||
|
when "Post"
|
||||||
|
new(POSTS_TABLE, "post_id", vector_def)
|
||||||
|
when "RagDocumentFragment"
|
||||||
|
new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector_def)
|
||||||
|
else
|
||||||
|
raise ArgumentError, "Invalid target type for embeddings"
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def search_index_name(table, def_id)
|
||||||
|
"ai_#{table}_embeddings_#{def_id}_1_search_bit"
|
||||||
|
end
|
||||||
|
|
||||||
|
def prepare_search_indexes(vector_def)
|
||||||
|
EMBEDDING_TARGETS.each { |target| DB.exec <<~SQL }
|
||||||
|
CREATE INDEX IF NOT EXISTS #{search_index_name(target, vector_def.id)} ON ai_#{target}_embeddings
|
||||||
|
USING hnsw ((binary_quantize(embeddings)::bit(#{vector_def.dimensions})) bit_hamming_ops)
|
||||||
|
WHERE model_id = #{vector_def.id} AND strategy_id = 1;
|
||||||
|
SQL
|
||||||
|
end
|
||||||
|
|
||||||
|
def correctly_indexed?(vector_def)
|
||||||
|
index_names = EMBEDDING_TARGETS.map { |t| search_index_name(t, vector_def.id) }
|
||||||
|
indexdefs =
|
||||||
|
DB.query_single(
|
||||||
|
"SELECT indexdef FROM pg_indexes WHERE indexname IN (:names)",
|
||||||
|
names: index_names,
|
||||||
|
)
|
||||||
|
|
||||||
|
return false if indexdefs.length < index_names.length
|
||||||
|
|
||||||
|
indexdefs.all? do |defs|
|
||||||
|
defs.include? "(binary_quantize(embeddings))::bit(#{vector_def.dimensions})"
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def remove_orphaned_data
|
||||||
|
removed_defs_ids =
|
||||||
|
DB.query_single(
|
||||||
|
"SELECT DISTINCT(model_id) FROM #{TOPICS_TABLE} te LEFT JOIN embedding_definitions ed ON te.model_id = ed.id WHERE ed.id IS NULL",
|
||||||
|
)
|
||||||
|
|
||||||
|
EMBEDDING_TABLES.each do |t|
|
||||||
|
DB.exec(
|
||||||
|
"DELETE FROM #{t} WHERE model_id IN (:removed_defs)",
|
||||||
|
removed_defs: removed_defs_ids,
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
|
drop_index_statement =
|
||||||
|
EMBEDDING_TARGETS
|
||||||
|
.reduce([]) do |memo, et|
|
||||||
|
removed_defs_ids.each do |rdi|
|
||||||
|
memo << "DROP INDEX IF EXISTS #{search_index_name(et, rdi)};"
|
||||||
|
end
|
||||||
|
|
||||||
|
memo
|
||||||
|
end
|
||||||
|
.join("\n")
|
||||||
|
|
||||||
|
DB.exec(drop_index_statement)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -117,7 +178,7 @@ module DiscourseAi
|
|||||||
offset: offset,
|
offset: offset,
|
||||||
)
|
)
|
||||||
rescue PG::Error => e
|
rescue PG::Error => e
|
||||||
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
|
Rails.logger.error("Error #{e} querying embeddings for model #{vector_def.display_name}")
|
||||||
raise MissingEmbeddingError
|
raise MissingEmbeddingError
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -168,7 +229,7 @@ module DiscourseAi
|
|||||||
|
|
||||||
builder.query(vid: vector_def.id, vsid: vector_def.strategy_id, target_id: record.id)
|
builder.query(vid: vector_def.id, vsid: vector_def.strategy_id, target_id: record.id)
|
||||||
rescue PG::Error => e
|
rescue PG::Error => e
|
||||||
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
|
Rails.logger.error("Error #{e} querying embeddings for model #{vector_def.display_name}")
|
||||||
raise MissingEmbeddingError
|
raise MissingEmbeddingError
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -82,7 +82,7 @@ module DiscourseAi
|
|||||||
|
|
||||||
over_selection_limit = limit * OVER_SELECTION_FACTOR
|
over_selection_limit = limit * OVER_SELECTION_FACTOR
|
||||||
|
|
||||||
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef)
|
schema = DiscourseAi::Embeddings::Schema.for(Topic)
|
||||||
|
|
||||||
candidate_topic_ids =
|
candidate_topic_ids =
|
||||||
schema.asymmetric_similarity_search(
|
schema.asymmetric_similarity_search(
|
||||||
@ -132,7 +132,7 @@ module DiscourseAi
|
|||||||
|
|
||||||
candidate_post_ids =
|
candidate_post_ids =
|
||||||
DiscourseAi::Embeddings::Schema
|
DiscourseAi::Embeddings::Schema
|
||||||
.for(Post, vector_def: vector.vdef)
|
.for(Post)
|
||||||
.asymmetric_similarity_search(
|
.asymmetric_similarity_search(
|
||||||
search_term_embedding,
|
search_term_embedding,
|
||||||
limit: max_semantic_results_per_page,
|
limit: max_semantic_results_per_page,
|
||||||
|
@ -4,13 +4,18 @@ module DiscourseAi
|
|||||||
module Embeddings
|
module Embeddings
|
||||||
class Vector
|
class Vector
|
||||||
def self.instance
|
def self.instance
|
||||||
new(DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation)
|
vector_def = EmbeddingDefinition.find_by(id: SiteSetting.ai_embeddings_selected_model)
|
||||||
|
raise "Invalid embeddings selected model" if vector_def.nil?
|
||||||
|
|
||||||
|
new(vector_def)
|
||||||
end
|
end
|
||||||
|
|
||||||
def initialize(vector_definition)
|
def initialize(vector_definition)
|
||||||
@vdef = vector_definition
|
@vdef = vector_definition
|
||||||
end
|
end
|
||||||
|
|
||||||
|
delegate :tokenizer, to: :vdef
|
||||||
|
|
||||||
def gen_bulk_reprensentations(relation)
|
def gen_bulk_reprensentations(relation)
|
||||||
http_pool_size = 100
|
http_pool_size = 100
|
||||||
pool =
|
pool =
|
||||||
@ -20,7 +25,7 @@ module DiscourseAi
|
|||||||
idletime: 30,
|
idletime: 30,
|
||||||
)
|
)
|
||||||
|
|
||||||
schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector_def: vdef)
|
schema = DiscourseAi::Embeddings::Schema.for(relation.first.class)
|
||||||
|
|
||||||
embedding_gen = vdef.inference_client
|
embedding_gen = vdef.inference_client
|
||||||
promised_embeddings =
|
promised_embeddings =
|
||||||
@ -53,7 +58,7 @@ module DiscourseAi
|
|||||||
text = vdef.prepare_target_text(target)
|
text = vdef.prepare_target_text(target)
|
||||||
return if text.blank?
|
return if text.blank?
|
||||||
|
|
||||||
schema = DiscourseAi::Embeddings::Schema.for(target.class, vector_def: vdef)
|
schema = DiscourseAi::Embeddings::Schema.for(target.class)
|
||||||
|
|
||||||
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
|
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
|
||||||
return if schema.find_by_target(target)&.digest == new_digest
|
return if schema.find_by_target(target)&.digest == new_digest
|
||||||
|
@ -1,56 +0,0 @@
|
|||||||
# frozen_string_literal: true
|
|
||||||
|
|
||||||
module DiscourseAi
|
|
||||||
module Embeddings
|
|
||||||
module VectorRepresentations
|
|
||||||
class AllMpnetBaseV2 < Base
|
|
||||||
class << self
|
|
||||||
def name
|
|
||||||
"all-mpnet-base-v2"
|
|
||||||
end
|
|
||||||
|
|
||||||
def correctly_configured?
|
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
|
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
|
||||||
end
|
|
||||||
|
|
||||||
def dependant_setting_names
|
|
||||||
%w[
|
|
||||||
ai_embeddings_discourse_service_api_key
|
|
||||||
ai_embeddings_discourse_service_api_endpoint_srv
|
|
||||||
ai_embeddings_discourse_service_api_endpoint
|
|
||||||
]
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
def dimensions
|
|
||||||
768
|
|
||||||
end
|
|
||||||
|
|
||||||
def max_sequence_length
|
|
||||||
384
|
|
||||||
end
|
|
||||||
|
|
||||||
def id
|
|
||||||
1
|
|
||||||
end
|
|
||||||
|
|
||||||
def version
|
|
||||||
1
|
|
||||||
end
|
|
||||||
|
|
||||||
def pg_function
|
|
||||||
"<#>"
|
|
||||||
end
|
|
||||||
|
|
||||||
def tokenizer
|
|
||||||
DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
|
|
||||||
end
|
|
||||||
|
|
||||||
def inference_client
|
|
||||||
DiscourseAi::Inference::DiscourseClassifier.instance(self.class.name)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
@ -1,103 +0,0 @@
|
|||||||
# frozen_string_literal: true
|
|
||||||
|
|
||||||
module DiscourseAi
|
|
||||||
module Embeddings
|
|
||||||
module VectorRepresentations
|
|
||||||
class Base
|
|
||||||
class << self
|
|
||||||
def find_representation(model_name)
|
|
||||||
# we are explicit here cause the loader may have not
|
|
||||||
# loaded the subclasses yet
|
|
||||||
[
|
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2,
|
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn,
|
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::BgeM3,
|
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::Gemini,
|
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large,
|
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
|
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
|
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
|
|
||||||
].find { _1.name == model_name }
|
|
||||||
end
|
|
||||||
|
|
||||||
def current_representation
|
|
||||||
find_representation(SiteSetting.ai_embeddings_model).new
|
|
||||||
end
|
|
||||||
|
|
||||||
def correctly_configured?
|
|
||||||
raise NotImplementedError
|
|
||||||
end
|
|
||||||
|
|
||||||
def dependant_setting_names
|
|
||||||
raise NotImplementedError
|
|
||||||
end
|
|
||||||
|
|
||||||
def configuration_hint
|
|
||||||
settings = dependant_setting_names
|
|
||||||
I18n.t(
|
|
||||||
"discourse_ai.embeddings.configuration.hint",
|
|
||||||
settings: settings.join(", "),
|
|
||||||
count: settings.length,
|
|
||||||
)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
def name
|
|
||||||
raise NotImplementedError
|
|
||||||
end
|
|
||||||
|
|
||||||
def dimensions
|
|
||||||
raise NotImplementedError
|
|
||||||
end
|
|
||||||
|
|
||||||
def max_sequence_length
|
|
||||||
raise NotImplementedError
|
|
||||||
end
|
|
||||||
|
|
||||||
def id
|
|
||||||
raise NotImplementedError
|
|
||||||
end
|
|
||||||
|
|
||||||
def pg_function
|
|
||||||
raise NotImplementedError
|
|
||||||
end
|
|
||||||
|
|
||||||
def version
|
|
||||||
raise NotImplementedError
|
|
||||||
end
|
|
||||||
|
|
||||||
def tokenizer
|
|
||||||
raise NotImplementedError
|
|
||||||
end
|
|
||||||
|
|
||||||
def asymmetric_query_prefix
|
|
||||||
""
|
|
||||||
end
|
|
||||||
|
|
||||||
def strategy_id
|
|
||||||
strategy.id
|
|
||||||
end
|
|
||||||
|
|
||||||
def strategy_version
|
|
||||||
strategy.version
|
|
||||||
end
|
|
||||||
|
|
||||||
def prepare_query_text(text, asymetric: false)
|
|
||||||
strategy.prepare_query_text(text, self, asymetric: asymetric)
|
|
||||||
end
|
|
||||||
|
|
||||||
def prepare_target_text(target)
|
|
||||||
strategy.prepare_target_text(target, self)
|
|
||||||
end
|
|
||||||
|
|
||||||
def strategy
|
|
||||||
@strategy ||= DiscourseAi::Embeddings::Strategies::Truncation.new
|
|
||||||
end
|
|
||||||
|
|
||||||
def inference_client
|
|
||||||
raise NotImplementedError
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
@ -1,80 +0,0 @@
|
|||||||
# frozen_string_literal: true
|
|
||||||
|
|
||||||
module DiscourseAi
|
|
||||||
module Embeddings
|
|
||||||
module VectorRepresentations
|
|
||||||
class BgeLargeEn < Base
|
|
||||||
class << self
|
|
||||||
def name
|
|
||||||
"bge-large-en"
|
|
||||||
end
|
|
||||||
|
|
||||||
def correctly_configured?
|
|
||||||
SiteSetting.ai_cloudflare_workers_api_token.present? ||
|
|
||||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? ||
|
|
||||||
(
|
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
|
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
|
||||||
)
|
|
||||||
end
|
|
||||||
|
|
||||||
def dependant_setting_names
|
|
||||||
%w[
|
|
||||||
ai_cloudflare_workers_api_token
|
|
||||||
ai_hugging_face_tei_endpoint_srv
|
|
||||||
ai_hugging_face_tei_endpoint
|
|
||||||
ai_embeddings_discourse_service_api_key
|
|
||||||
ai_embeddings_discourse_service_api_endpoint_srv
|
|
||||||
ai_embeddings_discourse_service_api_endpoint
|
|
||||||
]
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
def dimensions
|
|
||||||
1024
|
|
||||||
end
|
|
||||||
|
|
||||||
def max_sequence_length
|
|
||||||
512
|
|
||||||
end
|
|
||||||
|
|
||||||
def id
|
|
||||||
4
|
|
||||||
end
|
|
||||||
|
|
||||||
def version
|
|
||||||
1
|
|
||||||
end
|
|
||||||
|
|
||||||
def pg_function
|
|
||||||
"<#>"
|
|
||||||
end
|
|
||||||
|
|
||||||
def tokenizer
|
|
||||||
DiscourseAi::Tokenizer::BgeLargeEnTokenizer
|
|
||||||
end
|
|
||||||
|
|
||||||
def asymmetric_query_prefix
|
|
||||||
"Represent this sentence for searching relevant passages:"
|
|
||||||
end
|
|
||||||
|
|
||||||
def inference_client
|
|
||||||
inference_model_name = "baai/bge-large-en-v1.5"
|
|
||||||
|
|
||||||
if SiteSetting.ai_cloudflare_workers_api_token.present?
|
|
||||||
DiscourseAi::Inference::CloudflareWorkersAi.instance(inference_model_name)
|
|
||||||
elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
|
|
||||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.instance
|
|
||||||
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
|
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
|
||||||
DiscourseAi::Inference::DiscourseClassifier.instance(
|
|
||||||
inference_model_name.split("/").last,
|
|
||||||
)
|
|
||||||
else
|
|
||||||
raise "No inference endpoint configured"
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
@ -1,51 +0,0 @@
|
|||||||
# frozen_string_literal: true
|
|
||||||
|
|
||||||
module DiscourseAi
|
|
||||||
module Embeddings
|
|
||||||
module VectorRepresentations
|
|
||||||
class BgeM3 < Base
|
|
||||||
class << self
|
|
||||||
def name
|
|
||||||
"bge-m3"
|
|
||||||
end
|
|
||||||
|
|
||||||
def correctly_configured?
|
|
||||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
|
|
||||||
end
|
|
||||||
|
|
||||||
def dependant_setting_names
|
|
||||||
%w[ai_hugging_face_tei_endpoint_srv ai_hugging_face_tei_endpoint]
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
def dimensions
|
|
||||||
1024
|
|
||||||
end
|
|
||||||
|
|
||||||
def max_sequence_length
|
|
||||||
8192
|
|
||||||
end
|
|
||||||
|
|
||||||
def id
|
|
||||||
8
|
|
||||||
end
|
|
||||||
|
|
||||||
def version
|
|
||||||
1
|
|
||||||
end
|
|
||||||
|
|
||||||
def pg_function
|
|
||||||
"<#>"
|
|
||||||
end
|
|
||||||
|
|
||||||
def tokenizer
|
|
||||||
DiscourseAi::Tokenizer::BgeM3Tokenizer
|
|
||||||
end
|
|
||||||
|
|
||||||
def inference_client
|
|
||||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.instance
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
@ -1,54 +0,0 @@
|
|||||||
# frozen_string_literal: true
|
|
||||||
|
|
||||||
module DiscourseAi
|
|
||||||
module Embeddings
|
|
||||||
module VectorRepresentations
|
|
||||||
class Gemini < Base
|
|
||||||
class << self
|
|
||||||
def name
|
|
||||||
"gemini"
|
|
||||||
end
|
|
||||||
|
|
||||||
def correctly_configured?
|
|
||||||
SiteSetting.ai_gemini_api_key.present?
|
|
||||||
end
|
|
||||||
|
|
||||||
def dependant_setting_names
|
|
||||||
%w[ai_gemini_api_key]
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
def id
|
|
||||||
5
|
|
||||||
end
|
|
||||||
|
|
||||||
def version
|
|
||||||
1
|
|
||||||
end
|
|
||||||
|
|
||||||
def dimensions
|
|
||||||
768
|
|
||||||
end
|
|
||||||
|
|
||||||
def max_sequence_length
|
|
||||||
1536 # Gemini has a max sequence length of 2048, but the API has a limit of 10000 bytes, hence the lower value
|
|
||||||
end
|
|
||||||
|
|
||||||
def pg_function
|
|
||||||
"<=>"
|
|
||||||
end
|
|
||||||
|
|
||||||
# There is no public tokenizer for Gemini, and from the ones we already ship in the plugin
|
|
||||||
# OpenAI gets the closest results. Gemini Tokenizer results in ~10% less tokens, so it's safe
|
|
||||||
# to use OpenAI tokenizer since it will overestimate the number of tokens.
|
|
||||||
def tokenizer
|
|
||||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
|
||||||
end
|
|
||||||
|
|
||||||
def inference_client
|
|
||||||
DiscourseAi::Inference::GeminiEmbeddings.instance
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
@ -1,88 +0,0 @@
|
|||||||
# frozen_string_literal: true
|
|
||||||
|
|
||||||
module DiscourseAi
|
|
||||||
module Embeddings
|
|
||||||
module VectorRepresentations
|
|
||||||
class MultilingualE5Large < Base
|
|
||||||
class << self
|
|
||||||
def name
|
|
||||||
"multilingual-e5-large"
|
|
||||||
end
|
|
||||||
|
|
||||||
def correctly_configured?
|
|
||||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? ||
|
|
||||||
(
|
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
|
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
|
||||||
)
|
|
||||||
end
|
|
||||||
|
|
||||||
def dependant_setting_names
|
|
||||||
%w[
|
|
||||||
ai_hugging_face_tei_endpoint_srv
|
|
||||||
ai_hugging_face_tei_endpoint
|
|
||||||
ai_embeddings_discourse_service_api_key
|
|
||||||
ai_embeddings_discourse_service_api_endpoint_srv
|
|
||||||
ai_embeddings_discourse_service_api_endpoint
|
|
||||||
]
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
def id
|
|
||||||
3
|
|
||||||
end
|
|
||||||
|
|
||||||
def version
|
|
||||||
1
|
|
||||||
end
|
|
||||||
|
|
||||||
def dimensions
|
|
||||||
1024
|
|
||||||
end
|
|
||||||
|
|
||||||
def max_sequence_length
|
|
||||||
512
|
|
||||||
end
|
|
||||||
|
|
||||||
def pg_function
|
|
||||||
"<=>"
|
|
||||||
end
|
|
||||||
|
|
||||||
def tokenizer
|
|
||||||
DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
|
|
||||||
end
|
|
||||||
|
|
||||||
def inference_client
|
|
||||||
if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
|
|
||||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.instance
|
|
||||||
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
|
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
|
||||||
DiscourseAi::Inference::DiscourseClassifier.instance(self.class.name)
|
|
||||||
else
|
|
||||||
raise "No inference endpoint configured"
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
def prepare_text(text, asymetric: false)
|
|
||||||
prepared_text = super(text, asymetric: asymetric)
|
|
||||||
|
|
||||||
if prepared_text.present? && inference_client.class.name.include?("DiscourseClassifier")
|
|
||||||
return "query: #{prepared_text}"
|
|
||||||
end
|
|
||||||
|
|
||||||
prepared_text
|
|
||||||
end
|
|
||||||
|
|
||||||
def prepare_target_text(target)
|
|
||||||
prepared_text = super(target)
|
|
||||||
|
|
||||||
if prepared_text.present? && inference_client.class.name.include?("DiscourseClassifier")
|
|
||||||
return "query: #{prepared_text}"
|
|
||||||
end
|
|
||||||
|
|
||||||
prepared_text
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
@ -1,56 +0,0 @@
|
|||||||
# frozen_string_literal: true
|
|
||||||
|
|
||||||
module DiscourseAi
|
|
||||||
module Embeddings
|
|
||||||
module VectorRepresentations
|
|
||||||
class TextEmbedding3Large < Base
|
|
||||||
class << self
|
|
||||||
def name
|
|
||||||
"text-embedding-3-large"
|
|
||||||
end
|
|
||||||
|
|
||||||
def correctly_configured?
|
|
||||||
SiteSetting.ai_openai_api_key.present?
|
|
||||||
end
|
|
||||||
|
|
||||||
def dependant_setting_names
|
|
||||||
%w[ai_openai_api_key]
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
def id
|
|
||||||
7
|
|
||||||
end
|
|
||||||
|
|
||||||
def version
|
|
||||||
1
|
|
||||||
end
|
|
||||||
|
|
||||||
def dimensions
|
|
||||||
# real dimentions are 3072, but we only support up to 2000 in the
|
|
||||||
# indexes, so we downsample to 2000 via API
|
|
||||||
2000
|
|
||||||
end
|
|
||||||
|
|
||||||
def max_sequence_length
|
|
||||||
8191
|
|
||||||
end
|
|
||||||
|
|
||||||
def pg_function
|
|
||||||
"<=>"
|
|
||||||
end
|
|
||||||
|
|
||||||
def tokenizer
|
|
||||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
|
||||||
end
|
|
||||||
|
|
||||||
def inference_client
|
|
||||||
DiscourseAi::Inference::OpenAiEmbeddings.instance(
|
|
||||||
model: self.class.name,
|
|
||||||
dimensions: dimensions,
|
|
||||||
)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
@ -1,51 +0,0 @@
|
|||||||
# frozen_string_literal: true
|
|
||||||
|
|
||||||
module DiscourseAi
|
|
||||||
module Embeddings
|
|
||||||
module VectorRepresentations
|
|
||||||
class TextEmbedding3Small < Base
|
|
||||||
class << self
|
|
||||||
def name
|
|
||||||
"text-embedding-3-small"
|
|
||||||
end
|
|
||||||
|
|
||||||
def correctly_configured?
|
|
||||||
SiteSetting.ai_openai_api_key.present?
|
|
||||||
end
|
|
||||||
|
|
||||||
def dependant_setting_names
|
|
||||||
%w[ai_openai_api_key]
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
def id
|
|
||||||
6
|
|
||||||
end
|
|
||||||
|
|
||||||
def version
|
|
||||||
1
|
|
||||||
end
|
|
||||||
|
|
||||||
def dimensions
|
|
||||||
1536
|
|
||||||
end
|
|
||||||
|
|
||||||
def max_sequence_length
|
|
||||||
8191
|
|
||||||
end
|
|
||||||
|
|
||||||
def pg_function
|
|
||||||
"<=>"
|
|
||||||
end
|
|
||||||
|
|
||||||
def tokenizer
|
|
||||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
|
||||||
end
|
|
||||||
|
|
||||||
def inference_client
|
|
||||||
DiscourseAi::Inference::OpenAiEmbeddings.instance(model: self.class.name)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
@ -1,51 +0,0 @@
|
|||||||
# frozen_string_literal: true
|
|
||||||
|
|
||||||
module DiscourseAi
|
|
||||||
module Embeddings
|
|
||||||
module VectorRepresentations
|
|
||||||
class TextEmbeddingAda002 < Base
|
|
||||||
class << self
|
|
||||||
def name
|
|
||||||
"text-embedding-ada-002"
|
|
||||||
end
|
|
||||||
|
|
||||||
def correctly_configured?
|
|
||||||
SiteSetting.ai_openai_api_key.present?
|
|
||||||
end
|
|
||||||
|
|
||||||
def dependant_setting_names
|
|
||||||
%w[ai_openai_api_key]
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
def id
|
|
||||||
2
|
|
||||||
end
|
|
||||||
|
|
||||||
def version
|
|
||||||
1
|
|
||||||
end
|
|
||||||
|
|
||||||
def dimensions
|
|
||||||
1536
|
|
||||||
end
|
|
||||||
|
|
||||||
def max_sequence_length
|
|
||||||
8191
|
|
||||||
end
|
|
||||||
|
|
||||||
def pg_function
|
|
||||||
"<=>"
|
|
||||||
end
|
|
||||||
|
|
||||||
def tokenizer
|
|
||||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
|
||||||
end
|
|
||||||
|
|
||||||
def inference_client
|
|
||||||
DiscourseAi::Inference::OpenAiEmbeddings.instance(model: self.class.name)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
@ -3,22 +3,13 @@
|
|||||||
module ::DiscourseAi
|
module ::DiscourseAi
|
||||||
module Inference
|
module Inference
|
||||||
class CloudflareWorkersAi
|
class CloudflareWorkersAi
|
||||||
def initialize(account_id, api_token, model, referer = Discourse.base_url)
|
def initialize(endpoint, api_token, referer = Discourse.base_url)
|
||||||
@account_id = account_id
|
@endpoint = endpoint
|
||||||
@api_token = api_token
|
@api_token = api_token
|
||||||
@model = model
|
|
||||||
@referer = referer
|
@referer = referer
|
||||||
end
|
end
|
||||||
|
|
||||||
def self.instance(model)
|
attr_reader :endpoint, :api_token, :referer
|
||||||
new(
|
|
||||||
SiteSetting.ai_cloudflare_workers_account_id,
|
|
||||||
SiteSetting.ai_cloudflare_workers_api_token,
|
|
||||||
model,
|
|
||||||
)
|
|
||||||
end
|
|
||||||
|
|
||||||
attr_reader :account_id, :api_token, :model, :referer
|
|
||||||
|
|
||||||
def perform!(content)
|
def perform!(content)
|
||||||
headers = {
|
headers = {
|
||||||
@ -29,8 +20,6 @@ module ::DiscourseAi
|
|||||||
|
|
||||||
payload = { text: [content] }
|
payload = { text: [content] }
|
||||||
|
|
||||||
endpoint = "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/#{model}"
|
|
||||||
|
|
||||||
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
||||||
response = conn.post(endpoint, payload.to_json, headers)
|
response = conn.post(endpoint, payload.to_json, headers)
|
||||||
|
|
||||||
@ -43,7 +32,7 @@ module ::DiscourseAi
|
|||||||
Rails.logger.warn(
|
Rails.logger.warn(
|
||||||
"Cloudflare Workers AI Embeddings failed with status: #{response.status} body: #{response.body}",
|
"Cloudflare Workers AI Embeddings failed with status: #{response.status} body: #{response.body}",
|
||||||
)
|
)
|
||||||
raise Net::HTTPBadResponse
|
raise Net::HTTPBadResponse.new(response.body.to_s)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -1,47 +0,0 @@
|
|||||||
# frozen_string_literal: true
|
|
||||||
|
|
||||||
module ::DiscourseAi
|
|
||||||
module Inference
|
|
||||||
class DiscourseClassifier
|
|
||||||
def initialize(endpoint, api_key, model, referer = Discourse.base_url)
|
|
||||||
@endpoint = endpoint
|
|
||||||
@api_key = api_key
|
|
||||||
@model = model
|
|
||||||
@referer = referer
|
|
||||||
end
|
|
||||||
|
|
||||||
def self.instance(model)
|
|
||||||
endpoint =
|
|
||||||
if SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present?
|
|
||||||
service =
|
|
||||||
DiscourseAi::Utils::DnsSrv.lookup(
|
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv,
|
|
||||||
)
|
|
||||||
"https://#{service.target}:#{service.port}"
|
|
||||||
else
|
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint
|
|
||||||
end
|
|
||||||
|
|
||||||
new(
|
|
||||||
"#{endpoint}/api/v1/classify",
|
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_key,
|
|
||||||
model,
|
|
||||||
)
|
|
||||||
end
|
|
||||||
|
|
||||||
attr_reader :endpoint, :api_key, :model, :referer
|
|
||||||
|
|
||||||
def perform!(content)
|
|
||||||
headers = { "Referer" => referer, "Content-Type" => "application/json" }
|
|
||||||
headers["X-API-KEY"] = api_key if api_key.present?
|
|
||||||
|
|
||||||
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
|
||||||
response = conn.post(endpoint, { model: model, content: content }.to_json, headers)
|
|
||||||
|
|
||||||
raise Net::HTTPBadResponse if ![200, 415].include?(response.status)
|
|
||||||
|
|
||||||
JSON.parse(response.body, symbolize_names: true)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
@ -3,21 +3,17 @@
|
|||||||
module ::DiscourseAi
|
module ::DiscourseAi
|
||||||
module Inference
|
module Inference
|
||||||
class GeminiEmbeddings
|
class GeminiEmbeddings
|
||||||
def self.instance
|
def initialize(embedding_url, api_key, referer = Discourse.base_url)
|
||||||
new(SiteSetting.ai_gemini_api_key)
|
|
||||||
end
|
|
||||||
|
|
||||||
def initialize(api_key, referer = Discourse.base_url)
|
|
||||||
@api_key = api_key
|
@api_key = api_key
|
||||||
|
@embedding_url = embedding_url
|
||||||
@referer = referer
|
@referer = referer
|
||||||
end
|
end
|
||||||
|
|
||||||
attr_reader :api_key, :referer
|
attr_reader :embedding_url, :api_key, :referer
|
||||||
|
|
||||||
def perform!(content)
|
def perform!(content)
|
||||||
headers = { "Referer" => referer, "Content-Type" => "application/json" }
|
headers = { "Referer" => referer, "Content-Type" => "application/json" }
|
||||||
url =
|
url = "#{embedding_url}\?key\=#{api_key}"
|
||||||
"https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent\?key\=#{api_key}"
|
|
||||||
body = { content: { parts: [{ text: content }] } }
|
body = { content: { parts: [{ text: content }] } }
|
||||||
|
|
||||||
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
||||||
@ -32,7 +28,7 @@ module ::DiscourseAi
|
|||||||
Rails.logger.warn(
|
Rails.logger.warn(
|
||||||
"Google Gemini Embeddings failed with status: #{response.status} body: #{response.body}",
|
"Google Gemini Embeddings failed with status: #{response.status} body: #{response.body}",
|
||||||
)
|
)
|
||||||
raise Net::HTTPBadResponse
|
raise Net::HTTPBadResponse.new(response.body.to_s)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -12,19 +12,6 @@ module ::DiscourseAi
|
|||||||
attr_reader :endpoint, :key, :referer
|
attr_reader :endpoint, :key, :referer
|
||||||
|
|
||||||
class << self
|
class << self
|
||||||
def instance
|
|
||||||
endpoint =
|
|
||||||
if SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
|
|
||||||
service =
|
|
||||||
DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_hugging_face_tei_endpoint_srv)
|
|
||||||
"https://#{service.target}:#{service.port}"
|
|
||||||
else
|
|
||||||
SiteSetting.ai_hugging_face_tei_endpoint
|
|
||||||
end
|
|
||||||
|
|
||||||
new(endpoint, SiteSetting.ai_hugging_face_tei_api_key)
|
|
||||||
end
|
|
||||||
|
|
||||||
def configured?
|
def configured?
|
||||||
SiteSetting.ai_hugging_face_tei_endpoint.present? ||
|
SiteSetting.ai_hugging_face_tei_endpoint.present? ||
|
||||||
SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
|
SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
|
||||||
@ -100,7 +87,7 @@ module ::DiscourseAi
|
|||||||
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
||||||
response = conn.post(endpoint, body, headers)
|
response = conn.post(endpoint, body, headers)
|
||||||
|
|
||||||
raise Net::HTTPBadResponse if ![200].include?(response.status)
|
raise Net::HTTPBadResponse.new(response.body.to_s) if ![200].include?(response.status)
|
||||||
|
|
||||||
JSON.parse(response.body, symbolize_names: true).first
|
JSON.parse(response.body, symbolize_names: true).first
|
||||||
end
|
end
|
||||||
|
@ -12,10 +12,6 @@ module ::DiscourseAi
|
|||||||
|
|
||||||
attr_reader :endpoint, :api_key, :model, :dimensions
|
attr_reader :endpoint, :api_key, :model, :dimensions
|
||||||
|
|
||||||
def self.instance(model:, dimensions: nil)
|
|
||||||
new(SiteSetting.ai_openai_embeddings_url, SiteSetting.ai_openai_api_key, model, dimensions)
|
|
||||||
end
|
|
||||||
|
|
||||||
def perform!(content)
|
def perform!(content)
|
||||||
headers = { "Content-Type" => "application/json" }
|
headers = { "Content-Type" => "application/json" }
|
||||||
|
|
||||||
@ -29,7 +25,7 @@ module ::DiscourseAi
|
|||||||
payload[:dimensions] = dimensions if dimensions.present?
|
payload[:dimensions] = dimensions if dimensions.present?
|
||||||
|
|
||||||
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
||||||
response = conn.post(SiteSetting.ai_openai_embeddings_url, payload.to_json, headers)
|
response = conn.post(endpoint, payload.to_json, headers)
|
||||||
|
|
||||||
case response.status
|
case response.status
|
||||||
when 200
|
when 200
|
||||||
@ -40,7 +36,7 @@ module ::DiscourseAi
|
|||||||
Rails.logger.warn(
|
Rails.logger.warn(
|
||||||
"OpenAI Embeddings failed with status: #{response.status} body: #{response.body}",
|
"OpenAI Embeddings failed with status: #{response.status} body: #{response.body}",
|
||||||
)
|
)
|
||||||
raise Net::HTTPBadResponse
|
raise Net::HTTPBadResponse.new(response.body.to_s)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -35,6 +35,7 @@ register_asset "stylesheets/modules/embeddings/common/semantic-search.scss"
|
|||||||
register_asset "stylesheets/modules/sentiment/common/dashboard.scss"
|
register_asset "stylesheets/modules/sentiment/common/dashboard.scss"
|
||||||
|
|
||||||
register_asset "stylesheets/modules/llms/common/ai-llms-editor.scss"
|
register_asset "stylesheets/modules/llms/common/ai-llms-editor.scss"
|
||||||
|
register_asset "stylesheets/modules/embeddings/common/ai-embedding-editor.scss"
|
||||||
|
|
||||||
register_asset "stylesheets/modules/llms/common/usage.scss"
|
register_asset "stylesheets/modules/llms/common/usage.scss"
|
||||||
register_asset "stylesheets/modules/llms/common/spam.scss"
|
register_asset "stylesheets/modules/llms/common/spam.scss"
|
||||||
|
@ -1,17 +0,0 @@
|
|||||||
# frozen_string_literal: true
|
|
||||||
|
|
||||||
require_relative "../support/embeddings_generation_stubs"
|
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Configuration::EmbeddingsModelValidator do
|
|
||||||
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
|
|
||||||
|
|
||||||
describe "#can_generate_embeddings?" do
|
|
||||||
it "works" do
|
|
||||||
discourse_model = "all-mpnet-base-v2"
|
|
||||||
|
|
||||||
EmbeddingsGenerationStubs.discourse_service(discourse_model, "this is a test", [1] * 1024)
|
|
||||||
|
|
||||||
expect(subject.can_generate_embeddings?(discourse_model)).to eq(true)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
40
spec/fabricators/embedding_definition_fabricator.rb
Normal file
40
spec/fabricators/embedding_definition_fabricator.rb
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
Fabricator(:embedding_definition) do
|
||||||
|
display_name "Multilingual E5 Large"
|
||||||
|
provider "hugging_face"
|
||||||
|
tokenizer_class "DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer"
|
||||||
|
api_key "123"
|
||||||
|
url "https://test.com/embeddings"
|
||||||
|
provider_params nil
|
||||||
|
pg_function "<=>"
|
||||||
|
max_sequence_length 512
|
||||||
|
dimensions 1024
|
||||||
|
end
|
||||||
|
|
||||||
|
Fabricator(:cloudflare_embedding_def, from: :embedding_definition) do
|
||||||
|
display_name "BGE Large EN"
|
||||||
|
provider "cloudflare"
|
||||||
|
pg_function "<#>"
|
||||||
|
tokenizer_class "DiscourseAi::Tokenizer::BgeLargeEnTokenizer"
|
||||||
|
provider_params nil
|
||||||
|
end
|
||||||
|
|
||||||
|
Fabricator(:open_ai_embedding_def, from: :embedding_definition) do
|
||||||
|
display_name "ADA 002"
|
||||||
|
provider "open_ai"
|
||||||
|
url "https://api.openai.com/v1/embeddings"
|
||||||
|
tokenizer_class "DiscourseAi::Tokenizer::OpenAiTokenizer"
|
||||||
|
provider_params { { model_name: "text-embedding-ada-002" } }
|
||||||
|
max_sequence_length 8191
|
||||||
|
dimensions 1536
|
||||||
|
end
|
||||||
|
|
||||||
|
Fabricator(:gemini_embedding_def, from: :embedding_definition) do
|
||||||
|
display_name "Gemini's embedding-001"
|
||||||
|
provider "google"
|
||||||
|
dimensions 768
|
||||||
|
max_sequence_length 1536
|
||||||
|
tokenizer_class "DiscourseAi::Tokenizer::OpenAiTokenizer"
|
||||||
|
url "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent"
|
||||||
|
end
|
@ -6,9 +6,8 @@ RSpec.describe Jobs::DigestRagUpload do
|
|||||||
|
|
||||||
let(:document_file) { StringIO.new("some text" * 200) }
|
let(:document_file) { StringIO.new("some text" * 200) }
|
||||||
|
|
||||||
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
|
fab!(:cloudflare_embedding_def)
|
||||||
|
let(:expected_embedding) { [0.0038493] * cloudflare_embedding_def.dimensions }
|
||||||
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }
|
|
||||||
|
|
||||||
let(:document_with_metadata) { plugin_file_from_fixtures("doc_with_metadata.txt", "rag") }
|
let(:document_with_metadata) { plugin_file_from_fixtures("doc_with_metadata.txt", "rag") }
|
||||||
|
|
||||||
@ -21,15 +20,14 @@ RSpec.describe Jobs::DigestRagUpload do
|
|||||||
end
|
end
|
||||||
|
|
||||||
before do
|
before do
|
||||||
|
SiteSetting.ai_embeddings_selected_model = cloudflare_embedding_def.id
|
||||||
SiteSetting.ai_embeddings_enabled = true
|
SiteSetting.ai_embeddings_enabled = true
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
|
||||||
SiteSetting.ai_embeddings_model = "bge-large-en"
|
|
||||||
SiteSetting.authorized_extensions = "txt"
|
SiteSetting.authorized_extensions = "txt"
|
||||||
|
|
||||||
WebMock.stub_request(
|
WebMock.stub_request(:post, cloudflare_embedding_def.url).to_return(
|
||||||
:post,
|
status: 200,
|
||||||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
body: JSON.dump(expected_embedding),
|
||||||
).to_return(status: 200, body: JSON.dump(expected_embedding))
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
describe "#execute" do
|
describe "#execute" do
|
||||||
|
@ -2,23 +2,26 @@
|
|||||||
|
|
||||||
RSpec.describe Jobs::GenerateRagEmbeddings do
|
RSpec.describe Jobs::GenerateRagEmbeddings do
|
||||||
describe "#execute" do
|
describe "#execute" do
|
||||||
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
|
fab!(:vector_def) { Fabricate(:embedding_definition) }
|
||||||
|
|
||||||
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }
|
let(:expected_embedding) { [0.0038493] * vector_def.dimensions }
|
||||||
|
|
||||||
fab!(:ai_persona)
|
fab!(:ai_persona)
|
||||||
|
|
||||||
fab!(:rag_document_fragment_1) { Fabricate(:rag_document_fragment, target: ai_persona) }
|
let(:rag_document_fragment_1) { Fabricate(:rag_document_fragment, target: ai_persona) }
|
||||||
fab!(:rag_document_fragment_2) { Fabricate(:rag_document_fragment, target: ai_persona) }
|
let(:rag_document_fragment_2) { Fabricate(:rag_document_fragment, target: ai_persona) }
|
||||||
|
|
||||||
before do
|
before do
|
||||||
|
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||||
SiteSetting.ai_embeddings_enabled = true
|
SiteSetting.ai_embeddings_enabled = true
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
|
||||||
|
|
||||||
WebMock.stub_request(
|
rag_document_fragment_1
|
||||||
:post,
|
rag_document_fragment_2
|
||||||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
|
||||||
).to_return(status: 200, body: JSON.dump(expected_embedding))
|
WebMock.stub_request(:post, vector_def.url).to_return(
|
||||||
|
status: 200,
|
||||||
|
body: JSON.dump(expected_embedding),
|
||||||
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
it "generates a new vector for each fragment" do
|
it "generates a new vector for each fragment" do
|
||||||
|
56
spec/jobs/regular/manage_embedding_def_search_index_spec.rb
Normal file
56
spec/jobs/regular/manage_embedding_def_search_index_spec.rb
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
RSpec.describe Jobs::ManageEmbeddingDefSearchIndex do
|
||||||
|
fab!(:embedding_definition)
|
||||||
|
|
||||||
|
describe "#execute" do
|
||||||
|
context "when there is no embedding def" do
|
||||||
|
it "does nothing" do
|
||||||
|
invalid_id = 999_999_999
|
||||||
|
|
||||||
|
subject.execute(id: invalid_id)
|
||||||
|
|
||||||
|
expect(
|
||||||
|
DiscourseAi::Embeddings::Schema.correctly_indexed?(
|
||||||
|
EmbeddingDefinition.new(id: invalid_id),
|
||||||
|
),
|
||||||
|
).to eq(false)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "when the embedding def is fresh" do
|
||||||
|
it "creates the indexes" do
|
||||||
|
subject.execute(id: embedding_definition.id)
|
||||||
|
|
||||||
|
expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition)).to eq(true)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "creates them only once" do
|
||||||
|
subject.execute(id: embedding_definition.id)
|
||||||
|
subject.execute(id: embedding_definition.id)
|
||||||
|
|
||||||
|
expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition)).to eq(true)
|
||||||
|
end
|
||||||
|
|
||||||
|
context "when one of the idxs is missing" do
|
||||||
|
it "automatically recovers by creating it" do
|
||||||
|
DB.exec <<~SQL
|
||||||
|
CREATE INDEX IF NOT EXISTS ai_topics_embeddings_#{embedding_definition.id}_1_search_bit ON ai_topics_embeddings
|
||||||
|
USING hnsw ((binary_quantize(embeddings)::bit(#{embedding_definition.dimensions})) bit_hamming_ops)
|
||||||
|
WHERE model_id = #{embedding_definition.id} AND strategy_id = 1;
|
||||||
|
SQL
|
||||||
|
|
||||||
|
expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition)).to eq(
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
|
||||||
|
subject.execute(id: embedding_definition.id)
|
||||||
|
|
||||||
|
expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition)).to eq(
|
||||||
|
true,
|
||||||
|
)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
@ -19,11 +19,11 @@ RSpec.describe Jobs::EmbeddingsBackfill do
|
|||||||
topic
|
topic
|
||||||
end
|
end
|
||||||
|
|
||||||
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
|
fab!(:vector_def) { Fabricate(:embedding_definition) }
|
||||||
|
|
||||||
before do
|
before do
|
||||||
|
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||||
SiteSetting.ai_embeddings_enabled = true
|
SiteSetting.ai_embeddings_enabled = true
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
|
||||||
SiteSetting.ai_embeddings_backfill_batch_size = 1
|
SiteSetting.ai_embeddings_backfill_batch_size = 1
|
||||||
SiteSetting.ai_embeddings_per_post_enabled = true
|
SiteSetting.ai_embeddings_per_post_enabled = true
|
||||||
Jobs.run_immediately!
|
Jobs.run_immediately!
|
||||||
@ -32,10 +32,10 @@ RSpec.describe Jobs::EmbeddingsBackfill do
|
|||||||
it "backfills topics based on bumped_at date" do
|
it "backfills topics based on bumped_at date" do
|
||||||
embedding = Array.new(1024) { 1 }
|
embedding = Array.new(1024) { 1 }
|
||||||
|
|
||||||
WebMock.stub_request(
|
WebMock.stub_request(:post, "https://test.com/embeddings").to_return(
|
||||||
:post,
|
status: 200,
|
||||||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
body: JSON.dump(embedding),
|
||||||
).to_return(status: 200, body: JSON.dump(embedding))
|
)
|
||||||
|
|
||||||
Jobs::EmbeddingsBackfill.new.execute({})
|
Jobs::EmbeddingsBackfill.new.execute({})
|
||||||
|
|
||||||
|
71
spec/jobs/scheduled/remove_orphaned_embeddings_spec.rb
Normal file
71
spec/jobs/scheduled/remove_orphaned_embeddings_spec.rb
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
RSpec.describe Jobs::RemoveOrphanedEmbeddings do
|
||||||
|
describe "#execute" do
|
||||||
|
fab!(:embedding_definition)
|
||||||
|
fab!(:embedding_definition_2) { Fabricate(:embedding_definition) }
|
||||||
|
fab!(:topic)
|
||||||
|
fab!(:post)
|
||||||
|
|
||||||
|
before do
|
||||||
|
DiscourseAi::Embeddings::Schema.prepare_search_indexes(embedding_definition)
|
||||||
|
DiscourseAi::Embeddings::Schema.prepare_search_indexes(embedding_definition_2)
|
||||||
|
|
||||||
|
# Seed embeddings. One of each def x target classes.
|
||||||
|
[embedding_definition, embedding_definition_2].each do |edef|
|
||||||
|
SiteSetting.ai_embeddings_selected_model = edef.id
|
||||||
|
|
||||||
|
[topic, post].each do |target|
|
||||||
|
schema = DiscourseAi::Embeddings::Schema.for(target.class)
|
||||||
|
schema.store(target, [1] * edef.dimensions, "test")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
embedding_definition.destroy!
|
||||||
|
end
|
||||||
|
|
||||||
|
def find_all_embeddings_of(target, table, target_column)
|
||||||
|
DB.query_single("SELECT model_id FROM #{table} WHERE #{target_column} = #{target.id}")
|
||||||
|
end
|
||||||
|
|
||||||
|
it "delete embeddings without an existing embedding definition" do
|
||||||
|
expect(find_all_embeddings_of(post, "ai_posts_embeddings", "post_id")).to contain_exactly(
|
||||||
|
embedding_definition.id,
|
||||||
|
embedding_definition_2.id,
|
||||||
|
)
|
||||||
|
expect(find_all_embeddings_of(topic, "ai_topics_embeddings", "topic_id")).to contain_exactly(
|
||||||
|
embedding_definition.id,
|
||||||
|
embedding_definition_2.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
subject.execute({})
|
||||||
|
|
||||||
|
expect(find_all_embeddings_of(topic, "ai_topics_embeddings", "topic_id")).to contain_exactly(
|
||||||
|
embedding_definition_2.id,
|
||||||
|
)
|
||||||
|
expect(find_all_embeddings_of(post, "ai_posts_embeddings", "post_id")).to contain_exactly(
|
||||||
|
embedding_definition_2.id,
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "deletes orphaned indexes" do
|
||||||
|
expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition)).to eq(true)
|
||||||
|
expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition_2)).to eq(true)
|
||||||
|
|
||||||
|
subject.execute({})
|
||||||
|
|
||||||
|
index_names =
|
||||||
|
DiscourseAi::Embeddings::Schema::EMBEDDING_TARGETS.map do |t|
|
||||||
|
"ai_#{t}_embeddings_#{embedding_definition.id}_1_search_bit"
|
||||||
|
end
|
||||||
|
indexnames =
|
||||||
|
DB.query_single(
|
||||||
|
"SELECT indexname FROM pg_indexes WHERE indexname IN (:names)",
|
||||||
|
names: index_names,
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(indexnames).to be_empty
|
||||||
|
expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition_2)).to eq(true)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
@ -1,15 +0,0 @@
|
|||||||
# frozen_string_literal: true
|
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Configuration::EmbeddingsModuleValidator do
|
|
||||||
let(:validator) { described_class.new }
|
|
||||||
|
|
||||||
describe "#can_generate_embeddings?" do
|
|
||||||
it "returns true if embeddings can be generated" do
|
|
||||||
stub_request(
|
|
||||||
:post,
|
|
||||||
"https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent?key=",
|
|
||||||
).to_return(status: 200, body: { embedding: { values: [1, 2, 3] } }.to_json)
|
|
||||||
expect(validator.can_generate_embeddings?("gemini")).to eq(true)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
@ -4,7 +4,7 @@ require "rails_helper"
|
|||||||
require "webmock/rspec"
|
require "webmock/rspec"
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Inference::CloudflareWorkersAi do
|
RSpec.describe DiscourseAi::Inference::CloudflareWorkersAi do
|
||||||
subject { described_class.new(account_id, api_token, model) }
|
subject { described_class.new(endpoint, api_token) }
|
||||||
|
|
||||||
let(:account_id) { "test_account_id" }
|
let(:account_id) { "test_account_id" }
|
||||||
let(:api_token) { "test_api_token" }
|
let(:api_token) { "test_api_token" }
|
||||||
|
@ -297,9 +297,11 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
|||||||
end
|
end
|
||||||
|
|
||||||
describe "#craft_prompt" do
|
describe "#craft_prompt" do
|
||||||
|
fab!(:vector_def) { Fabricate(:embedding_definition) }
|
||||||
|
|
||||||
before do
|
before do
|
||||||
Group.refresh_automatic_groups!
|
Group.refresh_automatic_groups!
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||||
SiteSetting.ai_embeddings_enabled = true
|
SiteSetting.ai_embeddings_enabled = true
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -326,13 +328,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
|||||||
fab!(:llm_model) { Fabricate(:fake_model) }
|
fab!(:llm_model) { Fabricate(:fake_model) }
|
||||||
|
|
||||||
it "will run the question consolidator" do
|
it "will run the question consolidator" do
|
||||||
vector_def = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
|
||||||
context_embedding = vector_def.dimensions.times.map { rand(-1.0...1.0) }
|
context_embedding = vector_def.dimensions.times.map { rand(-1.0...1.0) }
|
||||||
EmbeddingsGenerationStubs.discourse_service(
|
EmbeddingsGenerationStubs.hugging_face_service(consolidated_question, context_embedding)
|
||||||
SiteSetting.ai_embeddings_model,
|
|
||||||
consolidated_question,
|
|
||||||
context_embedding,
|
|
||||||
)
|
|
||||||
|
|
||||||
custom_ai_persona =
|
custom_ai_persona =
|
||||||
Fabricate(
|
Fabricate(
|
||||||
@ -373,14 +370,11 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
|||||||
end
|
end
|
||||||
|
|
||||||
context "when a persona has RAG uploads" do
|
context "when a persona has RAG uploads" do
|
||||||
let(:vector_def) do
|
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
|
||||||
end
|
|
||||||
let(:embedding_value) { 0.04381 }
|
let(:embedding_value) { 0.04381 }
|
||||||
let(:prompt_cc_embeddings) { [embedding_value] * vector_def.dimensions }
|
let(:prompt_cc_embeddings) { [embedding_value] * vector_def.dimensions }
|
||||||
|
|
||||||
def stub_fragments(fragment_count, persona: ai_persona)
|
def stub_fragments(fragment_count, persona: ai_persona)
|
||||||
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector_def: vector_def)
|
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment)
|
||||||
|
|
||||||
fragment_count.times do |i|
|
fragment_count.times do |i|
|
||||||
fragment =
|
fragment =
|
||||||
@ -403,8 +397,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
|||||||
stored_ai_persona = AiPersona.find(ai_persona.id)
|
stored_ai_persona = AiPersona.find(ai_persona.id)
|
||||||
UploadReference.ensure_exist!(target: stored_ai_persona, upload_ids: [upload.id])
|
UploadReference.ensure_exist!(target: stored_ai_persona, upload_ids: [upload.id])
|
||||||
|
|
||||||
EmbeddingsGenerationStubs.discourse_service(
|
EmbeddingsGenerationStubs.hugging_face_service(
|
||||||
SiteSetting.ai_embeddings_model,
|
|
||||||
with_cc.dig(:conversation_context, 0, :content),
|
with_cc.dig(:conversation_context, 0, :content),
|
||||||
prompt_cc_embeddings,
|
prompt_cc_embeddings,
|
||||||
)
|
)
|
||||||
|
@ -108,16 +108,13 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
|
|||||||
|
|
||||||
it "supports semantic search when enabled" do
|
it "supports semantic search when enabled" do
|
||||||
assign_fake_provider_to(:ai_embeddings_semantic_search_hyde_model)
|
assign_fake_provider_to(:ai_embeddings_semantic_search_hyde_model)
|
||||||
|
vector_def = Fabricate(:embedding_definition)
|
||||||
|
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||||
SiteSetting.ai_embeddings_semantic_search_enabled = true
|
SiteSetting.ai_embeddings_semantic_search_enabled = true
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
|
||||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
|
||||||
hyde_embedding = [0.049382] * vector_rep.dimensions
|
|
||||||
|
|
||||||
EmbeddingsGenerationStubs.discourse_service(
|
hyde_embedding = [0.049382] * vector_def.dimensions
|
||||||
SiteSetting.ai_embeddings_model,
|
|
||||||
query,
|
EmbeddingsGenerationStubs.hugging_face_service(query, hyde_embedding)
|
||||||
hyde_embedding,
|
|
||||||
)
|
|
||||||
|
|
||||||
post1 = Fabricate(:post, topic: topic_with_tags)
|
post1 = Fabricate(:post, topic: topic_with_tags)
|
||||||
search =
|
search =
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# frozen_string_literal: true
|
# frozen_string_literal: true
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do
|
RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do
|
||||||
|
fab!(:vector_def) { Fabricate(:embedding_definition) }
|
||||||
fab!(:user)
|
fab!(:user)
|
||||||
fab!(:muted_category) { Fabricate(:category) }
|
fab!(:muted_category) { Fabricate(:category) }
|
||||||
fab!(:category_mute) do
|
fab!(:category_mute) do
|
||||||
@ -19,14 +20,13 @@ RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do
|
|||||||
let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions }
|
let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions }
|
||||||
|
|
||||||
before do
|
before do
|
||||||
|
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||||
SiteSetting.ai_embeddings_enabled = true
|
SiteSetting.ai_embeddings_enabled = true
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
|
||||||
SiteSetting.ai_embeddings_model = "bge-large-en"
|
|
||||||
|
|
||||||
WebMock.stub_request(
|
WebMock.stub_request(:post, vector_def.url).to_return(
|
||||||
:post,
|
status: 200,
|
||||||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
body: JSON.dump([expected_embedding]),
|
||||||
).to_return(status: 200, body: JSON.dump(expected_embedding))
|
)
|
||||||
|
|
||||||
vector.generate_representation_from(topic)
|
vector.generate_representation_from(topic)
|
||||||
vector.generate_representation_from(muted_topic)
|
vector.generate_representation_from(muted_topic)
|
||||||
|
@ -3,25 +3,26 @@
|
|||||||
RSpec.describe Jobs::GenerateEmbeddings do
|
RSpec.describe Jobs::GenerateEmbeddings do
|
||||||
subject(:job) { described_class.new }
|
subject(:job) { described_class.new }
|
||||||
|
|
||||||
|
fab!(:vector_def) { Fabricate(:embedding_definition) }
|
||||||
|
|
||||||
describe "#execute" do
|
describe "#execute" do
|
||||||
before do
|
before do
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||||
SiteSetting.ai_embeddings_enabled = true
|
SiteSetting.ai_embeddings_enabled = true
|
||||||
end
|
end
|
||||||
|
|
||||||
fab!(:topic)
|
fab!(:topic)
|
||||||
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
|
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
|
||||||
|
|
||||||
let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
|
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic) }
|
||||||
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector_def) }
|
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post) }
|
||||||
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector_def: vector_def) }
|
|
||||||
|
|
||||||
it "works for topics" do
|
it "works for topics" do
|
||||||
expected_embedding = [0.0038493] * vector_def.dimensions
|
expected_embedding = [0.0038493] * vector_def.dimensions
|
||||||
|
|
||||||
text = vector_def.prepare_target_text(topic)
|
text = vector_def.prepare_target_text(topic)
|
||||||
|
|
||||||
EmbeddingsGenerationStubs.discourse_service(vector_def.class.name, text, expected_embedding)
|
EmbeddingsGenerationStubs.hugging_face_service(text, expected_embedding)
|
||||||
|
|
||||||
job.execute(target_id: topic.id, target_type: "Topic")
|
job.execute(target_id: topic.id, target_type: "Topic")
|
||||||
|
|
||||||
@ -32,7 +33,7 @@ RSpec.describe Jobs::GenerateEmbeddings do
|
|||||||
expected_embedding = [0.0038493] * vector_def.dimensions
|
expected_embedding = [0.0038493] * vector_def.dimensions
|
||||||
|
|
||||||
text = vector_def.prepare_target_text(post)
|
text = vector_def.prepare_target_text(post)
|
||||||
EmbeddingsGenerationStubs.discourse_service(vector_def.class.name, text, expected_embedding)
|
EmbeddingsGenerationStubs.hugging_face_service(text, expected_embedding)
|
||||||
|
|
||||||
job.execute(target_id: post.id, target_type: "Post")
|
job.execute(target_id: post.id, target_type: "Post")
|
||||||
|
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
# frozen_string_literal: true
|
# frozen_string_literal: true
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Embeddings::Schema do
|
RSpec.describe DiscourseAi::Embeddings::Schema do
|
||||||
subject(:posts_schema) { described_class.for(Post, vector_def: vector_def) }
|
subject(:posts_schema) { described_class.for(Post) }
|
||||||
|
|
||||||
|
fab!(:vector_def) { Fabricate(:cloudflare_embedding_def) }
|
||||||
let(:embeddings) { [0.0038490295] * vector_def.dimensions }
|
let(:embeddings) { [0.0038490295] * vector_def.dimensions }
|
||||||
fab!(:post) { Fabricate(:post, post_number: 1) }
|
fab!(:post) { Fabricate(:post, post_number: 1) }
|
||||||
let(:digest) { OpenSSL::Digest.hexdigest("SHA1", "test") }
|
let(:digest) { OpenSSL::Digest.hexdigest("SHA1", "test") }
|
||||||
let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new }
|
|
||||||
|
before { SiteSetting.ai_embeddings_selected_model = vector_def.id }
|
||||||
|
|
||||||
before { posts_schema.store(post, embeddings, digest) }
|
before { posts_schema.store(post, embeddings, digest) }
|
||||||
|
|
||||||
|
@ -13,7 +13,13 @@ describe DiscourseAi::Embeddings::SemanticRelated do
|
|||||||
fab!(:secured_category_topic) { Fabricate(:topic, category: secured_category) }
|
fab!(:secured_category_topic) { Fabricate(:topic, category: secured_category) }
|
||||||
fab!(:closed_topic) { Fabricate(:topic, closed: true) }
|
fab!(:closed_topic) { Fabricate(:topic, closed: true) }
|
||||||
|
|
||||||
before { SiteSetting.ai_embeddings_semantic_related_topics_enabled = true }
|
fab!(:vector_def) { Fabricate(:embedding_definition) }
|
||||||
|
|
||||||
|
before do
|
||||||
|
SiteSetting.ai_embeddings_semantic_related_topics_enabled = true
|
||||||
|
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||||
|
SiteSetting.ai_embeddings_enabled = true
|
||||||
|
end
|
||||||
|
|
||||||
describe "#related_topic_ids_for" do
|
describe "#related_topic_ids_for" do
|
||||||
context "when embeddings do not exist" do
|
context "when embeddings do not exist" do
|
||||||
@ -24,21 +30,15 @@ describe DiscourseAi::Embeddings::SemanticRelated do
|
|||||||
topic
|
topic
|
||||||
end
|
end
|
||||||
|
|
||||||
let(:vector_rep) do
|
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
|
||||||
end
|
|
||||||
|
|
||||||
it "properly generates embeddings if missing" do
|
it "properly generates embeddings if missing" do
|
||||||
SiteSetting.ai_embeddings_enabled = true
|
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
|
||||||
Jobs.run_immediately!
|
Jobs.run_immediately!
|
||||||
|
|
||||||
embedding = Array.new(1024) { 1 }
|
embedding = Array.new(1024) { 1 }
|
||||||
|
|
||||||
WebMock.stub_request(
|
WebMock.stub_request(:post, vector_def.url).to_return(
|
||||||
:post,
|
status: 200,
|
||||||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
body: JSON.dump([embedding]),
|
||||||
).to_return(status: 200, body: JSON.dump(embedding))
|
)
|
||||||
|
|
||||||
# miss first
|
# miss first
|
||||||
ids = semantic_related.related_topic_ids_for(topic)
|
ids = semantic_related.related_topic_ids_for(topic)
|
||||||
|
@ -7,22 +7,18 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
|
|||||||
let(:query) { "test_query" }
|
let(:query) { "test_query" }
|
||||||
let(:subject) { described_class.new(Guardian.new(user)) }
|
let(:subject) { described_class.new(Guardian.new(user)) }
|
||||||
|
|
||||||
before { assign_fake_provider_to(:ai_embeddings_semantic_search_hyde_model) }
|
fab!(:vector_def) { Fabricate(:embedding_definition) }
|
||||||
|
|
||||||
|
before do
|
||||||
|
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||||
|
assign_fake_provider_to(:ai_embeddings_semantic_search_hyde_model)
|
||||||
|
end
|
||||||
|
|
||||||
describe "#search_for_topics" do
|
describe "#search_for_topics" do
|
||||||
let(:hypothetical_post) { "This is an hypothetical post generated from the keyword test_query" }
|
let(:hypothetical_post) { "This is an hypothetical post generated from the keyword test_query" }
|
||||||
let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
|
|
||||||
let(:hyde_embedding) { [0.049382] * vector_def.dimensions }
|
let(:hyde_embedding) { [0.049382] * vector_def.dimensions }
|
||||||
|
|
||||||
before do
|
before { EmbeddingsGenerationStubs.hugging_face_service(hypothetical_post, hyde_embedding) }
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
|
||||||
|
|
||||||
EmbeddingsGenerationStubs.discourse_service(
|
|
||||||
SiteSetting.ai_embeddings_model,
|
|
||||||
hypothetical_post,
|
|
||||||
hyde_embedding,
|
|
||||||
)
|
|
||||||
end
|
|
||||||
|
|
||||||
after { described_class.clear_cache_for(query) }
|
after { described_class.clear_cache_for(query) }
|
||||||
|
|
||||||
|
@ -9,6 +9,10 @@ describe DiscourseAi::Embeddings::EntryPoint do
|
|||||||
|
|
||||||
fab!(:target) { Fabricate(:topic) }
|
fab!(:target) { Fabricate(:topic) }
|
||||||
|
|
||||||
|
fab!(:vector_def) { Fabricate(:cloudflare_embedding_def) }
|
||||||
|
|
||||||
|
before { SiteSetting.ai_embeddings_selected_model = vector_def.id }
|
||||||
|
|
||||||
# The Distance gap to target increases for each element of topics.
|
# The Distance gap to target increases for each element of topics.
|
||||||
def seed_embeddings(topics)
|
def seed_embeddings(topics)
|
||||||
schema = DiscourseAi::Embeddings::Schema.for(Topic)
|
schema = DiscourseAi::Embeddings::Schema.for(Topic)
|
||||||
|
@ -19,13 +19,13 @@ RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
|
|||||||
)
|
)
|
||||||
end
|
end
|
||||||
fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }
|
fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }
|
||||||
|
fab!(:open_ai_embedding_def)
|
||||||
let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new }
|
|
||||||
|
|
||||||
it "truncates a topic" do
|
it "truncates a topic" do
|
||||||
prepared_text = truncation.prepare_target_text(topic, vector_def)
|
prepared_text = truncation.prepare_target_text(topic, open_ai_embedding_def)
|
||||||
|
|
||||||
expect(vector_def.tokenizer.size(prepared_text)).to be <= vector_def.max_sequence_length
|
expect(open_ai_embedding_def.tokenizer.size(prepared_text)).to be <=
|
||||||
|
open_ai_embedding_def.max_sequence_length
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -7,8 +7,10 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
|
|||||||
let(:expected_embedding_1) { [0.0038493] * vdef.dimensions }
|
let(:expected_embedding_1) { [0.0038493] * vdef.dimensions }
|
||||||
let(:expected_embedding_2) { [0.0037684] * vdef.dimensions }
|
let(:expected_embedding_2) { [0.0037684] * vdef.dimensions }
|
||||||
|
|
||||||
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vdef) }
|
before { SiteSetting.ai_embeddings_selected_model = vdef.id }
|
||||||
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector_def: vdef) }
|
|
||||||
|
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic) }
|
||||||
|
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post) }
|
||||||
|
|
||||||
fab!(:topic)
|
fab!(:topic)
|
||||||
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
|
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
|
||||||
@ -84,63 +86,16 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
context "with text-embedding-ada-002" do
|
context "with open_ai as the provider" do
|
||||||
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new }
|
fab!(:vdef) { Fabricate(:open_ai_embedding_def) }
|
||||||
|
|
||||||
def stub_vector_mapping(text, expected_embedding)
|
|
||||||
EmbeddingsGenerationStubs.openai_service(vdef.class.name, text, expected_embedding)
|
|
||||||
end
|
|
||||||
|
|
||||||
it_behaves_like "generates and store embeddings using a vector definition"
|
|
||||||
end
|
|
||||||
|
|
||||||
context "with all all-mpnet-base-v2" do
|
|
||||||
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new }
|
|
||||||
|
|
||||||
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
|
|
||||||
|
|
||||||
def stub_vector_mapping(text, expected_embedding)
|
|
||||||
EmbeddingsGenerationStubs.discourse_service(vdef.class.name, text, expected_embedding)
|
|
||||||
end
|
|
||||||
|
|
||||||
it_behaves_like "generates and store embeddings using a vector definition"
|
|
||||||
end
|
|
||||||
|
|
||||||
context "with gemini" do
|
|
||||||
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::Gemini.new }
|
|
||||||
let(:api_key) { "test-123" }
|
|
||||||
|
|
||||||
before { SiteSetting.ai_gemini_api_key = api_key }
|
|
||||||
|
|
||||||
def stub_vector_mapping(text, expected_embedding)
|
|
||||||
EmbeddingsGenerationStubs.gemini_service(api_key, text, expected_embedding)
|
|
||||||
end
|
|
||||||
|
|
||||||
it_behaves_like "generates and store embeddings using a vector definition"
|
|
||||||
end
|
|
||||||
|
|
||||||
context "with multilingual-e5-large" do
|
|
||||||
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large.new }
|
|
||||||
|
|
||||||
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
|
|
||||||
|
|
||||||
def stub_vector_mapping(text, expected_embedding)
|
|
||||||
EmbeddingsGenerationStubs.discourse_service(vdef.class.name, text, expected_embedding)
|
|
||||||
end
|
|
||||||
|
|
||||||
it_behaves_like "generates and store embeddings using a vector definition"
|
|
||||||
end
|
|
||||||
|
|
||||||
context "with text-embedding-3-large" do
|
|
||||||
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large.new }
|
|
||||||
|
|
||||||
def stub_vector_mapping(text, expected_embedding)
|
def stub_vector_mapping(text, expected_embedding)
|
||||||
EmbeddingsGenerationStubs.openai_service(
|
EmbeddingsGenerationStubs.openai_service(
|
||||||
vdef.class.name,
|
vdef.lookup_custom_param("model_name"),
|
||||||
text,
|
text,
|
||||||
expected_embedding,
|
expected_embedding,
|
||||||
extra_args: {
|
extra_args: {
|
||||||
dimensions: 2000,
|
dimensions: vdef.dimensions,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
@ -148,11 +103,31 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
|
|||||||
it_behaves_like "generates and store embeddings using a vector definition"
|
it_behaves_like "generates and store embeddings using a vector definition"
|
||||||
end
|
end
|
||||||
|
|
||||||
context "with text-embedding-3-small" do
|
context "with hugging_face as the provider" do
|
||||||
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small.new }
|
fab!(:vdef) { Fabricate(:embedding_definition) }
|
||||||
|
|
||||||
def stub_vector_mapping(text, expected_embedding)
|
def stub_vector_mapping(text, expected_embedding)
|
||||||
EmbeddingsGenerationStubs.openai_service(vdef.class.name, text, expected_embedding)
|
EmbeddingsGenerationStubs.hugging_face_service(text, expected_embedding)
|
||||||
|
end
|
||||||
|
|
||||||
|
it_behaves_like "generates and store embeddings using a vector definition"
|
||||||
|
end
|
||||||
|
|
||||||
|
context "with google as the provider" do
|
||||||
|
fab!(:vdef) { Fabricate(:gemini_embedding_def) }
|
||||||
|
|
||||||
|
def stub_vector_mapping(text, expected_embedding)
|
||||||
|
EmbeddingsGenerationStubs.gemini_service(vdef.api_key, text, expected_embedding)
|
||||||
|
end
|
||||||
|
|
||||||
|
it_behaves_like "generates and store embeddings using a vector definition"
|
||||||
|
end
|
||||||
|
|
||||||
|
context "with cloudflare as the provider" do
|
||||||
|
fab!(:vdef) { Fabricate(:cloudflare_embedding_def) }
|
||||||
|
|
||||||
|
def stub_vector_mapping(text, expected_embedding)
|
||||||
|
EmbeddingsGenerationStubs.cloudflare_service(text, expected_embedding)
|
||||||
end
|
end
|
||||||
|
|
||||||
it_behaves_like "generates and store embeddings using a vector definition"
|
it_behaves_like "generates and store embeddings using a vector definition"
|
||||||
|
@ -204,12 +204,12 @@ RSpec.describe AiTool do
|
|||||||
end
|
end
|
||||||
|
|
||||||
context "when defining RAG fragments" do
|
context "when defining RAG fragments" do
|
||||||
|
fab!(:cloudflare_embedding_def)
|
||||||
|
|
||||||
before do
|
before do
|
||||||
SiteSetting.authorized_extensions = "txt"
|
SiteSetting.authorized_extensions = "txt"
|
||||||
|
SiteSetting.ai_embeddings_selected_model = cloudflare_embedding_def.id
|
||||||
SiteSetting.ai_embeddings_enabled = true
|
SiteSetting.ai_embeddings_enabled = true
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
|
||||||
SiteSetting.ai_embeddings_model = "bge-large-en"
|
|
||||||
|
|
||||||
Jobs.run_immediately!
|
Jobs.run_immediately!
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -228,9 +228,9 @@ RSpec.describe AiTool do
|
|||||||
# this is a trick, we get ever increasing embeddings, this gives us in turn
|
# this is a trick, we get ever increasing embeddings, this gives us in turn
|
||||||
# 100% consistent search results
|
# 100% consistent search results
|
||||||
@counter = 0
|
@counter = 0
|
||||||
stub_request(:post, "http://test.com/api/v1/classify").to_return(
|
stub_request(:post, cloudflare_embedding_def.url).to_return(
|
||||||
status: 200,
|
status: 200,
|
||||||
body: lambda { |req| ([@counter += 1] * 1024).to_json },
|
body: lambda { |req| { result: { data: [([@counter += 1] * 1024)] } }.to_json },
|
||||||
headers: {
|
headers: {
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@ -4,11 +4,9 @@ RSpec.describe RagDocumentFragment do
|
|||||||
fab!(:persona) { Fabricate(:ai_persona) }
|
fab!(:persona) { Fabricate(:ai_persona) }
|
||||||
fab!(:upload_1) { Fabricate(:upload) }
|
fab!(:upload_1) { Fabricate(:upload) }
|
||||||
fab!(:upload_2) { Fabricate(:upload) }
|
fab!(:upload_2) { Fabricate(:upload) }
|
||||||
|
fab!(:vector_def) { Fabricate(:embedding_definition) }
|
||||||
|
|
||||||
before do
|
before { SiteSetting.ai_embeddings_enabled = true }
|
||||||
SiteSetting.ai_embeddings_enabled = true
|
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
|
||||||
end
|
|
||||||
|
|
||||||
describe ".link_uploads_and_persona" do
|
describe ".link_uploads_and_persona" do
|
||||||
it "does nothing if there is no persona" do
|
it "does nothing if there is no persona" do
|
||||||
@ -76,25 +74,25 @@ RSpec.describe RagDocumentFragment do
|
|||||||
describe ".indexing_status" do
|
describe ".indexing_status" do
|
||||||
let(:vector) { DiscourseAi::Embeddings::Vector.instance }
|
let(:vector) { DiscourseAi::Embeddings::Vector.instance }
|
||||||
|
|
||||||
fab!(:rag_document_fragment_1) do
|
let(:rag_document_fragment_1) do
|
||||||
Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
|
Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
|
||||||
end
|
end
|
||||||
|
|
||||||
fab!(:rag_document_fragment_2) do
|
let(:rag_document_fragment_2) do
|
||||||
Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
|
Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
|
||||||
end
|
end
|
||||||
|
|
||||||
let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions }
|
let(:expected_embedding) { [0.0038493] * vector_def.dimensions }
|
||||||
|
|
||||||
before do
|
before do
|
||||||
SiteSetting.ai_embeddings_enabled = true
|
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
rag_document_fragment_1
|
||||||
SiteSetting.ai_embeddings_model = "bge-large-en"
|
rag_document_fragment_2
|
||||||
|
|
||||||
WebMock.stub_request(
|
WebMock.stub_request(:post, "https://test.com/embeddings").to_return(
|
||||||
:post,
|
status: 200,
|
||||||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
body: JSON.dump(expected_embedding),
|
||||||
).to_return(status: 200, body: JSON.dump(expected_embedding))
|
)
|
||||||
|
|
||||||
vector.generate_representation_from(rag_document_fragment_1)
|
vector.generate_representation_from(rag_document_fragment_1)
|
||||||
end
|
end
|
||||||
@ -106,7 +104,7 @@ RSpec.describe RagDocumentFragment do
|
|||||||
UploadReference.create!(upload_id: upload_2.id, target: persona)
|
UploadReference.create!(upload_id: upload_2.id, target: persona)
|
||||||
|
|
||||||
Sidekiq::Testing.fake! do
|
Sidekiq::Testing.fake! do
|
||||||
SiteSetting.ai_embeddings_model = "all-mpnet-base-v2"
|
SiteSetting.ai_embeddings_selected_model = Fabricate(:open_ai_embedding_def).id
|
||||||
expect(RagDocumentFragment.exists?(old_id)).to eq(false)
|
expect(RagDocumentFragment.exists?(old_id)).to eq(false)
|
||||||
expect(Jobs::DigestRagUpload.jobs.size).to eq(2)
|
expect(Jobs::DigestRagUpload.jobs.size).to eq(2)
|
||||||
end
|
end
|
||||||
|
184
spec/requests/admin/ai_embeddings_controller_spec.rb
Normal file
184
spec/requests/admin/ai_embeddings_controller_spec.rb
Normal file
@ -0,0 +1,184 @@
|
|||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
RSpec.describe DiscourseAi::Admin::AiEmbeddingsController do
|
||||||
|
fab!(:admin)
|
||||||
|
|
||||||
|
before { sign_in(admin) }
|
||||||
|
|
||||||
|
let(:valid_attrs) do
|
||||||
|
{
|
||||||
|
display_name: "Embedding config test",
|
||||||
|
dimensions: 1001,
|
||||||
|
max_sequence_length: 234,
|
||||||
|
pg_function: "<#>",
|
||||||
|
provider: "hugging_face",
|
||||||
|
url: "https://test.com/api/v1/embeddings",
|
||||||
|
api_key: "test",
|
||||||
|
tokenizer_class: "DiscourseAi::Tokenizer::BgeM3Tokenizer",
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "POST #create" do
|
||||||
|
context "with valid attrs" do
|
||||||
|
it "creates a new embedding definition" do
|
||||||
|
post "/admin/plugins/discourse-ai/ai-embeddings.json", params: { ai_embedding: valid_attrs }
|
||||||
|
|
||||||
|
created_def = EmbeddingDefinition.last
|
||||||
|
|
||||||
|
expect(response.status).to eq(201)
|
||||||
|
expect(created_def.display_name).to eq(valid_attrs[:display_name])
|
||||||
|
end
|
||||||
|
|
||||||
|
it "stores provider-specific config params" do
|
||||||
|
post "/admin/plugins/discourse-ai/ai-embeddings.json",
|
||||||
|
params: {
|
||||||
|
ai_embedding:
|
||||||
|
valid_attrs.merge(
|
||||||
|
provider: "open_ai",
|
||||||
|
provider_params: {
|
||||||
|
model_name: "embeddings-v1",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
created_def = EmbeddingDefinition.last
|
||||||
|
|
||||||
|
expect(response.status).to eq(201)
|
||||||
|
expect(created_def.provider_params["model_name"]).to eq("embeddings-v1")
|
||||||
|
end
|
||||||
|
|
||||||
|
it "ignores parameters not associated with that provider" do
|
||||||
|
post "/admin/plugins/discourse-ai/ai-embeddings.json",
|
||||||
|
params: {
|
||||||
|
ai_embedding: valid_attrs.merge(provider_params: { custom: "custom" }),
|
||||||
|
}
|
||||||
|
|
||||||
|
created_def = EmbeddingDefinition.last
|
||||||
|
|
||||||
|
expect(response.status).to eq(201)
|
||||||
|
expect(created_def.lookup_custom_param("custom")).to be_nil
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "with invalid attrs" do
|
||||||
|
it "doesn't create a new embedding defitinion" do
|
||||||
|
post "/admin/plugins/discourse-ai/ai-embeddings.json",
|
||||||
|
params: {
|
||||||
|
ai_embedding: valid_attrs.except(:provider),
|
||||||
|
}
|
||||||
|
|
||||||
|
created_def = EmbeddingDefinition.last
|
||||||
|
|
||||||
|
expect(created_def).to be_nil
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "PUT #update" do
|
||||||
|
fab!(:embedding_definition)
|
||||||
|
|
||||||
|
context "with valid update params" do
|
||||||
|
let(:update_attrs) { { provider: "open_ai" } }
|
||||||
|
|
||||||
|
it "updates the model" do
|
||||||
|
put "/admin/plugins/discourse-ai/ai-embeddings/#{embedding_definition.id}.json",
|
||||||
|
params: {
|
||||||
|
ai_embedding: update_attrs,
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(response.status).to eq(200)
|
||||||
|
expect(embedding_definition.reload.provider).to eq(update_attrs[:provider])
|
||||||
|
end
|
||||||
|
|
||||||
|
it "returns a 404 if there is no model with the given Id" do
|
||||||
|
put "/admin/plugins/discourse-ai/ai-embeddings/9999999.json"
|
||||||
|
|
||||||
|
expect(response.status).to eq(404)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "doesn't allow dimenstions to be updated" do
|
||||||
|
new_dimensions = 200
|
||||||
|
|
||||||
|
put "/admin/plugins/discourse-ai/ai-embeddings/#{embedding_definition.id}.json",
|
||||||
|
params: {
|
||||||
|
ai_embedding: {
|
||||||
|
dimensions: new_dimensions,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(response.status).to eq(200)
|
||||||
|
expect(embedding_definition.reload.dimensions).not_to eq(new_dimensions)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "with invalid update params" do
|
||||||
|
it "doesn't update the model" do
|
||||||
|
put "/admin/plugins/discourse-ai/ai-embeddings/#{embedding_definition.id}.json",
|
||||||
|
params: {
|
||||||
|
ai_embedding: {
|
||||||
|
url: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(response.status).to eq(422)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "DELETE #destroy" do
|
||||||
|
fab!(:embedding_definition)
|
||||||
|
|
||||||
|
it "destroys the embedding defitinion" do
|
||||||
|
expect {
|
||||||
|
delete "/admin/plugins/discourse-ai/ai-embeddings/#{embedding_definition.id}.json"
|
||||||
|
|
||||||
|
expect(response).to have_http_status(:no_content)
|
||||||
|
}.to change(EmbeddingDefinition, :count).by(-1)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "validates the model is not in use" do
|
||||||
|
SiteSetting.ai_embeddings_selected_model = embedding_definition.id
|
||||||
|
|
||||||
|
delete "/admin/plugins/discourse-ai/ai-embeddings/#{embedding_definition.id}.json"
|
||||||
|
|
||||||
|
expect(response.status).to eq(409)
|
||||||
|
expect(embedding_definition.reload).to eq(embedding_definition)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "GET #test" do
|
||||||
|
context "when we can generate an embedding" do
|
||||||
|
it "returns a success true flag" do
|
||||||
|
WebMock.stub_request(:post, valid_attrs[:url]).to_return(status: 200, body: [[1]].to_json)
|
||||||
|
|
||||||
|
get "/admin/plugins/discourse-ai/ai-embeddings/test.json",
|
||||||
|
params: {
|
||||||
|
ai_embedding: valid_attrs,
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(response).to be_successful
|
||||||
|
expect(response.parsed_body["success"]).to eq(true)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "when we cannot generate an embedding" do
|
||||||
|
it "returns a success false flag and the error message" do
|
||||||
|
error_message = { error: "Embedding generation failed." }
|
||||||
|
|
||||||
|
WebMock.stub_request(:post, valid_attrs[:url]).to_return(
|
||||||
|
status: 422,
|
||||||
|
body: error_message.to_json,
|
||||||
|
)
|
||||||
|
|
||||||
|
get "/admin/plugins/discourse-ai/ai-embeddings/test.json",
|
||||||
|
params: {
|
||||||
|
ai_embedding: valid_attrs,
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(response).to be_successful
|
||||||
|
expect(response.parsed_body["success"]).to eq(false)
|
||||||
|
expect(response.parsed_body["error"]).to eq(error_message.to_json)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
@ -8,7 +8,6 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
|||||||
sign_in(admin)
|
sign_in(admin)
|
||||||
|
|
||||||
SiteSetting.ai_embeddings_enabled = true
|
SiteSetting.ai_embeddings_enabled = true
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
|
||||||
end
|
end
|
||||||
|
|
||||||
describe "GET #index" do
|
describe "GET #index" do
|
||||||
|
@ -4,11 +4,12 @@ RSpec.describe DiscourseAi::Admin::RagDocumentFragmentsController do
|
|||||||
fab!(:admin)
|
fab!(:admin)
|
||||||
fab!(:ai_persona)
|
fab!(:ai_persona)
|
||||||
|
|
||||||
|
fab!(:vector_def) { Fabricate(:embedding_definition) }
|
||||||
|
|
||||||
before do
|
before do
|
||||||
sign_in(admin)
|
sign_in(admin)
|
||||||
|
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||||
SiteSetting.ai_embeddings_enabled = true
|
SiteSetting.ai_embeddings_enabled = true
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
|
||||||
end
|
end
|
||||||
|
|
||||||
describe "GET #indexing_status_check" do
|
describe "GET #indexing_status_check" do
|
||||||
|
@ -2,9 +2,11 @@
|
|||||||
|
|
||||||
describe DiscourseAi::Embeddings::EmbeddingsController do
|
describe DiscourseAi::Embeddings::EmbeddingsController do
|
||||||
context "when performing a topic search" do
|
context "when performing a topic search" do
|
||||||
|
fab!(:vector_def) { Fabricate(:open_ai_embedding_def) }
|
||||||
|
|
||||||
before do
|
before do
|
||||||
SiteSetting.min_search_term_length = 3
|
SiteSetting.min_search_term_length = 3
|
||||||
SiteSetting.ai_embeddings_model = "text-embedding-3-small"
|
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||||
DiscourseAi::Embeddings::SemanticSearch.clear_cache_for("test")
|
DiscourseAi::Embeddings::SemanticSearch.clear_cache_for("test")
|
||||||
SearchIndexer.enable
|
SearchIndexer.enable
|
||||||
end
|
end
|
||||||
@ -31,7 +33,15 @@ describe DiscourseAi::Embeddings::EmbeddingsController do
|
|||||||
|
|
||||||
def stub_embedding(query)
|
def stub_embedding(query)
|
||||||
embedding = [0.049382] * 1536
|
embedding = [0.049382] * 1536
|
||||||
EmbeddingsGenerationStubs.openai_service(SiteSetting.ai_embeddings_model, query, embedding)
|
|
||||||
|
EmbeddingsGenerationStubs.openai_service(
|
||||||
|
vector_def.lookup_custom_param("model_name"),
|
||||||
|
query,
|
||||||
|
embedding,
|
||||||
|
extra_args: {
|
||||||
|
dimensions: vector_def.dimensions,
|
||||||
|
},
|
||||||
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
def create_api_key(user)
|
def create_api_key(user)
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
# frozen_string_literal: true
|
# frozen_string_literal: true
|
||||||
|
|
||||||
describe DiscourseAi::Inference::OpenAiEmbeddings do
|
describe DiscourseAi::Inference::OpenAiEmbeddings do
|
||||||
|
let(:api_key) { "123456" }
|
||||||
|
let(:dimensions) { 1000 }
|
||||||
|
let(:model) { "text-embedding-ada-002" }
|
||||||
|
|
||||||
it "supports azure embeddings" do
|
it "supports azure embeddings" do
|
||||||
SiteSetting.ai_openai_embeddings_url =
|
azure_url =
|
||||||
"https://my-company.openai.azure.com/openai/deployments/embeddings-deployment/embeddings?api-version=2023-05-15"
|
"https://my-company.openai.azure.com/openai/deployments/embeddings-deployment/embeddings?api-version=2023-05-15"
|
||||||
SiteSetting.ai_openai_api_key = "123456"
|
|
||||||
|
|
||||||
body_json = {
|
body_json = {
|
||||||
usage: {
|
usage: {
|
||||||
@ -14,28 +17,22 @@ describe DiscourseAi::Inference::OpenAiEmbeddings do
|
|||||||
data: [{ object: "embedding", embedding: [0.0, 0.1] }],
|
data: [{ object: "embedding", embedding: [0.0, 0.1] }],
|
||||||
}.to_json
|
}.to_json
|
||||||
|
|
||||||
stub_request(
|
stub_request(:post, azure_url).with(
|
||||||
:post,
|
|
||||||
"https://my-company.openai.azure.com/openai/deployments/embeddings-deployment/embeddings?api-version=2023-05-15",
|
|
||||||
).with(
|
|
||||||
body: "{\"model\":\"text-embedding-ada-002\",\"input\":\"hello\"}",
|
body: "{\"model\":\"text-embedding-ada-002\",\"input\":\"hello\"}",
|
||||||
headers: {
|
headers: {
|
||||||
"Api-Key" => "123456",
|
"Api-Key" => api_key,
|
||||||
"Content-Type" => "application/json",
|
"Content-Type" => "application/json",
|
||||||
},
|
},
|
||||||
).to_return(status: 200, body: body_json, headers: {})
|
).to_return(status: 200, body: body_json, headers: {})
|
||||||
|
|
||||||
result =
|
result =
|
||||||
DiscourseAi::Inference::OpenAiEmbeddings.instance(model: "text-embedding-ada-002").perform!(
|
DiscourseAi::Inference::OpenAiEmbeddings.new(azure_url, api_key, model, nil).perform!("hello")
|
||||||
"hello",
|
|
||||||
)
|
|
||||||
|
|
||||||
expect(result).to eq([0.0, 0.1])
|
expect(result).to eq([0.0, 0.1])
|
||||||
end
|
end
|
||||||
|
|
||||||
it "supports openai embeddings" do
|
it "supports openai embeddings" do
|
||||||
SiteSetting.ai_openai_api_key = "123456"
|
url = "https://api.openai.com/v1/embeddings"
|
||||||
|
|
||||||
body_json = {
|
body_json = {
|
||||||
usage: {
|
usage: {
|
||||||
prompt_tokens: 1,
|
prompt_tokens: 1,
|
||||||
@ -44,21 +41,20 @@ describe DiscourseAi::Inference::OpenAiEmbeddings do
|
|||||||
data: [{ object: "embedding", embedding: [0.0, 0.1] }],
|
data: [{ object: "embedding", embedding: [0.0, 0.1] }],
|
||||||
}.to_json
|
}.to_json
|
||||||
|
|
||||||
body = { model: "text-embedding-ada-002", input: "hello", dimensions: 1000 }.to_json
|
body = { model: model, input: "hello", dimensions: dimensions }.to_json
|
||||||
|
|
||||||
stub_request(:post, "https://api.openai.com/v1/embeddings").with(
|
stub_request(:post, url).with(
|
||||||
body: body,
|
body: body,
|
||||||
headers: {
|
headers: {
|
||||||
"Authorization" => "Bearer 123456",
|
"Authorization" => "Bearer #{api_key}",
|
||||||
"Content-Type" => "application/json",
|
"Content-Type" => "application/json",
|
||||||
},
|
},
|
||||||
).to_return(status: 200, body: body_json, headers: {})
|
).to_return(status: 200, body: body_json, headers: {})
|
||||||
|
|
||||||
result =
|
result =
|
||||||
DiscourseAi::Inference::OpenAiEmbeddings.instance(
|
DiscourseAi::Inference::OpenAiEmbeddings.new(url, api_key, model, dimensions).perform!(
|
||||||
model: "text-embedding-ada-002",
|
"hello",
|
||||||
dimensions: 1000,
|
)
|
||||||
).perform!("hello")
|
|
||||||
|
|
||||||
expect(result).to eq([0.0, 0.1])
|
expect(result).to eq([0.0, 0.1])
|
||||||
end
|
end
|
||||||
|
@ -2,15 +2,11 @@
|
|||||||
|
|
||||||
class EmbeddingsGenerationStubs
|
class EmbeddingsGenerationStubs
|
||||||
class << self
|
class << self
|
||||||
def discourse_service(model, string, embedding)
|
def hugging_face_service(string, embedding)
|
||||||
model = "bge-large-en-v1.5" if model == "bge-large-en"
|
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(
|
.stub_request(:post, "https://test.com/embeddings")
|
||||||
:post,
|
.with(body: JSON.dump({ inputs: string, truncate: true }))
|
||||||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
.to_return(status: 200, body: JSON.dump([embedding]))
|
||||||
)
|
|
||||||
.with(body: JSON.dump({ model: model, content: string }))
|
|
||||||
.to_return(status: 200, body: JSON.dump(embedding))
|
|
||||||
end
|
end
|
||||||
|
|
||||||
def openai_service(model, string, embedding, extra_args: {})
|
def openai_service(model, string, embedding, extra_args: {})
|
||||||
@ -29,5 +25,12 @@ class EmbeddingsGenerationStubs
|
|||||||
.with(body: JSON.dump({ content: { parts: [{ text: string }] } }))
|
.with(body: JSON.dump({ content: { parts: [{ text: string }] } }))
|
||||||
.to_return(status: 200, body: JSON.dump({ embedding: { values: embedding } }))
|
.to_return(status: 200, body: JSON.dump({ embedding: { values: embedding } }))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def cloudflare_service(string, embedding)
|
||||||
|
WebMock
|
||||||
|
.stub_request(:post, "https://test.com/embeddings")
|
||||||
|
.with(body: JSON.dump({ text: [string] }))
|
||||||
|
.to_return(status: 200, body: JSON.dump({ result: { data: [embedding] } }))
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
87
spec/system/embeddings/ai_embedding_definition_spec.rb
Normal file
87
spec/system/embeddings/ai_embedding_definition_spec.rb
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
RSpec.describe "Managing Embeddings configurations", type: :system, js: true do
|
||||||
|
fab!(:admin)
|
||||||
|
let(:page_header) { PageObjects::Components::DPageHeader.new }
|
||||||
|
|
||||||
|
before { sign_in(admin) }
|
||||||
|
|
||||||
|
it "correctly sets defaults" do
|
||||||
|
preset = "text-embedding-3-small"
|
||||||
|
api_key = "abcd"
|
||||||
|
|
||||||
|
visit "/admin/plugins/discourse-ai/ai-embeddings"
|
||||||
|
|
||||||
|
find(".ai-embeddings-list-editor__new-button").click()
|
||||||
|
select_kit = PageObjects::Components::SelectKit.new(".ai-embedding-editor__presets")
|
||||||
|
select_kit.expand
|
||||||
|
select_kit.select_row_by_value(preset)
|
||||||
|
find(".ai-embedding-editor__next").click
|
||||||
|
find("input.ai-embedding-editor__api-key").fill_in(with: api_key)
|
||||||
|
find(".ai-embedding-editor__save").click()
|
||||||
|
|
||||||
|
expect(page).to have_current_path("/admin/plugins/discourse-ai/ai-embeddings")
|
||||||
|
|
||||||
|
embedding_def = EmbeddingDefinition.order(:id).last
|
||||||
|
expect(embedding_def.api_key).to eq(api_key)
|
||||||
|
|
||||||
|
preset = EmbeddingDefinition.presets.find { |p| p[:preset_id] == preset }
|
||||||
|
|
||||||
|
expect(embedding_def.display_name).to eq(preset[:display_name])
|
||||||
|
expect(embedding_def.url).to eq(preset[:url])
|
||||||
|
expect(embedding_def.tokenizer_class).to eq(preset[:tokenizer_class])
|
||||||
|
expect(embedding_def.dimensions).to eq(preset[:dimensions])
|
||||||
|
expect(embedding_def.max_sequence_length).to eq(preset[:max_sequence_length])
|
||||||
|
expect(embedding_def.pg_function).to eq(preset[:pg_function])
|
||||||
|
expect(embedding_def.provider).to eq(preset[:provider])
|
||||||
|
expect(embedding_def.provider_params.symbolize_keys).to eq(preset[:provider_params])
|
||||||
|
end
|
||||||
|
|
||||||
|
it "supports manual config" do
|
||||||
|
api_key = "abcd"
|
||||||
|
|
||||||
|
visit "/admin/plugins/discourse-ai/ai-embeddings"
|
||||||
|
|
||||||
|
find(".ai-embeddings-list-editor__new-button").click()
|
||||||
|
select_kit = PageObjects::Components::SelectKit.new(".ai-embedding-editor__presets")
|
||||||
|
select_kit.expand
|
||||||
|
select_kit.select_row_by_value("manual")
|
||||||
|
find(".ai-embedding-editor__next").click
|
||||||
|
|
||||||
|
find("input.ai-embedding-editor__display-name").fill_in(with: "OpenAI's text-embedding-3-small")
|
||||||
|
|
||||||
|
select_kit = PageObjects::Components::SelectKit.new(".ai-embedding-editor__provider")
|
||||||
|
select_kit.expand
|
||||||
|
select_kit.select_row_by_value(EmbeddingDefinition::OPEN_AI)
|
||||||
|
|
||||||
|
find("input.ai-embedding-editor__url").fill_in(with: "https://api.openai.com/v1/embeddings")
|
||||||
|
find("input.ai-embedding-editor__api-key").fill_in(with: api_key)
|
||||||
|
|
||||||
|
select_kit = PageObjects::Components::SelectKit.new(".ai-embedding-editor__tokenizer")
|
||||||
|
select_kit.expand
|
||||||
|
select_kit.select_row_by_value("DiscourseAi::Tokenizer::OpenAiTokenizer")
|
||||||
|
|
||||||
|
find("input.ai-embedding-editor__dimensions").fill_in(with: 1536)
|
||||||
|
find("input.ai-embedding-editor__max_sequence_length").fill_in(with: 8191)
|
||||||
|
|
||||||
|
select_kit = PageObjects::Components::SelectKit.new(".ai-embedding-editor__distance_functions")
|
||||||
|
select_kit.expand
|
||||||
|
select_kit.select_row_by_value("<=>")
|
||||||
|
find(".ai-embedding-editor__save").click()
|
||||||
|
|
||||||
|
expect(page).to have_current_path("/admin/plugins/discourse-ai/ai-embeddings")
|
||||||
|
|
||||||
|
embedding_def = EmbeddingDefinition.order(:id).last
|
||||||
|
expect(embedding_def.api_key).to eq(api_key)
|
||||||
|
|
||||||
|
preset = EmbeddingDefinition.presets.find { |p| p[:preset_id] == "text-embedding-3-small" }
|
||||||
|
|
||||||
|
expect(embedding_def.display_name).to eq(preset[:display_name])
|
||||||
|
expect(embedding_def.url).to eq(preset[:url])
|
||||||
|
expect(embedding_def.tokenizer_class).to eq(preset[:tokenizer_class])
|
||||||
|
expect(embedding_def.dimensions).to eq(preset[:dimensions])
|
||||||
|
expect(embedding_def.max_sequence_length).to eq(preset[:max_sequence_length])
|
||||||
|
expect(embedding_def.pg_function).to eq(preset[:pg_function])
|
||||||
|
expect(embedding_def.provider).to eq(preset[:provider])
|
||||||
|
end
|
||||||
|
end
|
@ -10,7 +10,6 @@ RSpec.describe "AI Composer helper", type: :system, js: true do
|
|||||||
fab!(:post) { Fabricate(:post, topic: topic, raw: "Apple pie is a delicious dessert to eat") }
|
fab!(:post) { Fabricate(:post, topic: topic, raw: "Apple pie is a delicious dessert to eat") }
|
||||||
|
|
||||||
before do
|
before do
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
|
||||||
prompt = DiscourseAi::Embeddings::HydeGenerators::OpenAi.new.prompt(query)
|
prompt = DiscourseAi::Embeddings::HydeGenerators::OpenAi.new.prompt(query)
|
||||||
OpenAiCompletionsInferenceStubs.stub_response(
|
OpenAiCompletionsInferenceStubs.stub_response(
|
||||||
prompt,
|
prompt,
|
||||||
@ -21,11 +20,7 @@ RSpec.describe "AI Composer helper", type: :system, js: true do
|
|||||||
)
|
)
|
||||||
|
|
||||||
hyde_embedding = [0.049382, 0.9999]
|
hyde_embedding = [0.049382, 0.9999]
|
||||||
EmbeddingsGenerationStubs.discourse_service(
|
EmbeddingsGenerationStubs.hugging_face_service(hypothetical_post, hyde_embedding)
|
||||||
SiteSetting.ai_embeddings_model,
|
|
||||||
hypothetical_post,
|
|
||||||
hyde_embedding,
|
|
||||||
)
|
|
||||||
|
|
||||||
SearchIndexer.enable
|
SearchIndexer.enable
|
||||||
SearchIndexer.index(topic, force: true)
|
SearchIndexer.index(topic, force: true)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user