mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-02-07 12:08:13 +00:00
FEATURE: configurable embeddings (#1049)
* Use AR model for embeddings features * endpoints * Embeddings CRUD UI * Add presets. Hide a couple more settings * system specs * Seed embedding definition from old settings * Generate search bit index on the fly. cleanup orphaned data * support for seeded models * Fix run test for new embedding * fix selected model not set correctly
This commit is contained in:
parent
fad4b65d4f
commit
f5cf1019fb
@ -0,0 +1,21 @@
|
||||
import DiscourseRoute from "discourse/routes/discourse";

// Route for editing an existing embedding configuration. Resolves the
// record out of the parent route's already-loaded collection instead of
// refetching it.
export default class AdminPluginsShowDiscourseAiEmbeddingsEdit extends DiscourseRoute {
  async model(params) {
    const embeddings = this.modelFor(
      "adminPlugins.show.discourse-ai-embeddings"
    );
    const found = embeddings.findBy("id", parseInt(params.id, 10));

    // Guarantee a provider_params object so the editor form can bind into it.
    found.provider_params = found.provider_params || {};
    return found;
  }

  setupController(controller, model) {
    super.setupController(controller, model);

    const embeddings = this.modelFor(
      "adminPlugins.show.discourse-ai-embeddings"
    );
    controller.set("allEmbeddings", embeddings);
  }
}
|
@ -0,0 +1,17 @@
|
||||
import DiscourseRoute from "discourse/routes/discourse";

// Route backing the "new embedding" admin form.
export default class AdminPluginsShowDiscourseAiEmbeddingsNew extends DiscourseRoute {
  async model() {
    const freshRecord = this.store.createRecord("ai-embedding");

    // Start with an empty provider_params hash so the editor can bind into it.
    freshRecord.provider_params = {};
    return freshRecord;
  }

  setupController(controller, model) {
    super.setupController(controller, model);

    const all = this.modelFor("adminPlugins.show.discourse-ai-embeddings");
    controller.set("allEmbeddings", all);
  }
}
|
@ -0,0 +1,7 @@
|
||||
import DiscourseRoute from "discourse/routes/discourse";

// Parent route: loads every configured embedding definition; child routes
// (new/edit) read this collection via modelFor.
export default class DiscourseAiAiEmbeddingsRoute extends DiscourseRoute {
  model() {
    return this.store.findAll("ai-embedding");
  }
}
|
@ -0,0 +1,4 @@
|
||||
{{!-- Edit view: list editor with the selected embedding opened for editing --}}
<AiEmbeddingsListEditor
  @embeddings={{this.allEmbeddings}}
  @currentEmbedding={{this.model}}
/>
|
@ -0,0 +1 @@
|
||||
{{!-- Index view: plain list of all embedding configurations --}}
<AiEmbeddingsListEditor @embeddings={{this.model}} />
|
@ -0,0 +1,4 @@
|
||||
{{!-- New view: list editor with a fresh record opened for editing --}}
<AiEmbeddingsListEditor
  @embeddings={{this.allEmbeddings}}
  @currentEmbedding={{this.model}}
/>
|
130
app/controllers/discourse_ai/admin/ai_embeddings_controller.rb
Normal file
130
app/controllers/discourse_ai/admin/ai_embeddings_controller.rb
Normal file
@ -0,0 +1,130 @@
|
||||
# frozen_string_literal: true

module DiscourseAi
  module Admin
    # Admin CRUD endpoints for embedding definitions, plus a rate-limited
    # connectivity test endpoint.
    class AiEmbeddingsController < ::Admin::AdminController
      requires_plugin ::DiscourseAi::PLUGIN_NAME

      # GET: all embedding definitions plus the metadata the admin UI needs
      # to build the editor (providers, presets, tokenizers, ...).
      def index
        embedding_defs = EmbeddingDefinition.order(:display_name)

        render json: {
                 ai_embeddings:
                   ActiveModel::ArraySerializer.new(
                     embedding_defs,
                     each_serializer: AiEmbeddingDefinitionSerializer,
                     root: false,
                   ).as_json,
                 meta: {
                   provider_params: EmbeddingDefinition.provider_params,
                   providers: EmbeddingDefinition.provider_names,
                   distance_functions: EmbeddingDefinition.distance_functions,
                   tokenizers:
                     EmbeddingDefinition.tokenizer_names.map { |tn|
                       { id: tn, name: tn.split("::").last }
                     },
                   presets: EmbeddingDefinition.presets,
                 },
               }
      end

      def new
      end

      def edit
        embedding_def = EmbeddingDefinition.find(params[:id])
        render json: AiEmbeddingDefinitionSerializer.new(embedding_def)
      end

      def create
        embedding_def = EmbeddingDefinition.new(ai_embeddings_params)

        if embedding_def.save
          render json: AiEmbeddingDefinitionSerializer.new(embedding_def), status: :created
        else
          render_json_error embedding_def
        end
      end

      def update
        embedding_def = EmbeddingDefinition.find(params[:id])
        return render_seeded_error if embedding_def.seeded?

        # Dimensions are baked into the vector indexes, so they can never change.
        if embedding_def.update(ai_embeddings_params.except(:dimensions))
          render json: AiEmbeddingDefinitionSerializer.new(embedding_def)
        else
          render_json_error embedding_def
        end
      end

      def destroy
        embedding_def = EmbeddingDefinition.find(params[:id])
        return render_seeded_error if embedding_def.seeded?

        # Refuse to delete the definition currently selected for use.
        if embedding_def.id == SiteSetting.ai_embeddings_selected_model.to_i
          return render_json_error(I18n.t("discourse_ai.embeddings.delete_failed"), status: 409)
        end

        if embedding_def.destroy
          head :no_content
        else
          render_json_error embedding_def
        end
      end

      # POST: build a throwaway definition from the submitted params and
      # generate one embedding to verify the configuration works.
      def test
        RateLimiter.new(
          current_user,
          "ai_embeddings_test_#{current_user.id}",
          3,
          1.minute,
        ).performed!

        embedding_def = EmbeddingDefinition.new(ai_embeddings_params)
        DiscourseAi::Embeddings::Vector.new(embedding_def).vector_from("this is a test")

        render json: { success: true }
      rescue Net::HTTPBadResponse => e
        render json: { success: false, error: e.message }
      end

      private

      # Seeded (built-in) definitions may not be modified or deleted.
      # Extracted: this guard was duplicated verbatim in update and destroy.
      def render_seeded_error
        render_json_error(I18n.t("discourse_ai.embeddings.cannot_edit_builtin"), status: 403)
      end

      def ai_embeddings_params
        permitted =
          params.require(:ai_embedding).permit(
            :display_name,
            :dimensions,
            :max_sequence_length,
            :pg_function,
            :provider,
            :url,
            :api_key,
            :tokenizer_class,
          )

        # Only allow the provider-specific params the model declares for
        # the chosen provider.
        extra_field_names = EmbeddingDefinition.provider_params.dig(permitted[:provider]&.to_sym)
        if extra_field_names.present?
          received_prov_params =
            params.dig(:ai_embedding, :provider_params)&.slice(*extra_field_names.keys)

          permitted[:provider_params] = received_prov_params.permit! if received_prov_params.present?
        end

        permitted
      end
    end
  end
end
|
@ -18,7 +18,7 @@ module ::Jobs
|
||||
target = target_type.constantize.find_by(id: target_id)
|
||||
return if !target
|
||||
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
vector_rep = DiscourseAi::Embeddings::Vector.instance
|
||||
|
||||
tokenizer = vector_rep.tokenizer
|
||||
chunk_tokens = target.rag_chunk_tokens
|
||||
|
13
app/jobs/regular/manage_embedding_def_search_index.rb
Normal file
13
app/jobs/regular/manage_embedding_def_search_index.rb
Normal file
@ -0,0 +1,13 @@
|
||||
# frozen_string_literal: true

module ::Jobs
  # Builds the search indexes for an embedding definition, skipping the
  # work when the definition is gone or already correctly indexed.
  class ManageEmbeddingDefSearchIndex < ::Jobs::Base
    def execute(args)
      definition = EmbeddingDefinition.find_by(id: args[:id])
      return if definition.nil?
      return if DiscourseAi::Embeddings::Schema.correctly_indexed?(definition)

      DiscourseAi::Embeddings::Schema.prepare_search_indexes(definition)
    end
  end
end
|
11
app/jobs/scheduled/remove_orphaned_embeddings.rb
Normal file
11
app/jobs/scheduled/remove_orphaned_embeddings.rb
Normal file
@ -0,0 +1,11 @@
|
||||
# frozen_string_literal: true

module Jobs
  # Weekly scheduled cleanup; delegates to Schema.remove_orphaned_data to
  # delete embedding data that is no longer attached to a live definition.
  class RemoveOrphanedEmbeddings < ::Jobs::Scheduled
    every 1.week

    def execute(_args)
      DiscourseAi::Embeddings::Schema.remove_orphaned_data
    end
  end
end
|
231
app/models/embedding_definition.rb
Normal file
231
app/models/embedding_definition.rb
Normal file
@ -0,0 +1,231 @@
|
||||
# frozen_string_literal: true

# A configurable embedding model definition: provider, endpoint, tokenizer,
# vector dimensions and the pgvector distance operator to index with.
class EmbeddingDefinition < ActiveRecord::Base
  CLOUDFLARE = "cloudflare"
  HUGGING_FACE = "hugging_face"
  OPEN_AI = "open_ai"
  GOOGLE = "google"

  class << self
    def provider_names
      [CLOUDFLARE, HUGGING_FACE, OPEN_AI, GOOGLE]
    end

    # pgvector operators: <#> negative inner product, <=> cosine distance.
    def distance_functions
      %w[<#> <=>]
    end

    def tokenizer_names
      # BUGFIX: OpenAiTokenizer was listed twice in the original array.
      [
        DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer,
        DiscourseAi::Tokenizer::BgeLargeEnTokenizer,
        DiscourseAi::Tokenizer::BgeM3Tokenizer,
        DiscourseAi::Tokenizer::OpenAiTokenizer,
        DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer,
      ].map(&:name)
    end

    # Extra per-provider fields the admin UI should render (name => type).
    def provider_params
      { open_ai: { model_name: :text } }
    end

    # Well-known configurations the admin UI offers as starting points.
    def presets
      @presets ||= [
        {
          preset_id: "bge-large-en",
          display_name: "bge-large-en",
          dimensions: 1024,
          max_sequence_length: 512,
          pg_function: "<#>",
          tokenizer_class: "DiscourseAi::Tokenizer::BgeLargeEnTokenizer",
          provider: HUGGING_FACE,
        },
        {
          preset_id: "bge-m3",
          display_name: "bge-m3",
          dimensions: 1024,
          max_sequence_length: 8192,
          pg_function: "<#>",
          tokenizer_class: "DiscourseAi::Tokenizer::BgeM3Tokenizer",
          provider: HUGGING_FACE,
        },
        {
          preset_id: "gemini-embedding-001",
          display_name: "Gemini's embedding-001",
          dimensions: 768,
          max_sequence_length: 1536,
          pg_function: "<=>",
          url:
            "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent",
          tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
          provider: GOOGLE,
        },
        {
          preset_id: "multilingual-e5-large",
          display_name: "multilingual-e5-large",
          dimensions: 1024,
          max_sequence_length: 512,
          pg_function: "<=>",
          tokenizer_class: "DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer",
          provider: HUGGING_FACE,
        },
        {
          preset_id: "text-embedding-3-large",
          display_name: "OpenAI's text-embedding-3-large",
          dimensions: 2000,
          max_sequence_length: 8191,
          pg_function: "<=>",
          tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
          url: "https://api.openai.com/v1/embeddings",
          provider: OPEN_AI,
          provider_params: {
            model_name: "text-embedding-3-large",
          },
        },
        {
          preset_id: "text-embedding-3-small",
          display_name: "OpenAI's text-embedding-3-small",
          dimensions: 1536,
          max_sequence_length: 8191,
          pg_function: "<=>",
          tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
          url: "https://api.openai.com/v1/embeddings",
          provider: OPEN_AI,
          provider_params: {
            model_name: "text-embedding-3-small",
          },
        },
        {
          preset_id: "text-embedding-ada-002",
          display_name: "OpenAI's text-embedding-ada-002",
          dimensions: 1536,
          max_sequence_length: 8191,
          pg_function: "<=>",
          tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
          url: "https://api.openai.com/v1/embeddings",
          provider: OPEN_AI,
          provider_params: {
            model_name: "text-embedding-ada-002",
          },
        },
      ]
    end
  end

  validates :provider, presence: true, inclusion: provider_names
  validates :display_name, presence: true, length: { maximum: 100 }
  validates :tokenizer_class, presence: true, inclusion: tokenizer_names
  validates_presence_of :url, :api_key, :dimensions, :max_sequence_length, :pg_function

  after_create :create_indexes

  # Index creation happens async; see ManageEmbeddingDefSearchIndex.
  def create_indexes
    Jobs.enqueue(:manage_embedding_def_search_index, id: self.id)
  end

  def tokenizer
    tokenizer_class.constantize
  end

  # @return an inference client matching this definition's provider.
  # @raise RuntimeError when the provider is not one of the known names
  #   (shouldn't happen for persisted records thanks to the validation).
  def inference_client
    case provider
    when CLOUDFLARE
      cloudflare_client
    when HUGGING_FACE
      hugging_face_client
    when OPEN_AI
      open_ai_client
    when GOOGLE
      gemini_client
    else
      raise "Unknown embeddings provider" # typo fix: was "Uknown"
    end
  end

  def lookup_custom_param(key)
    provider_params&.dig(key)
  end

  # Resolves srv:// URLs through DNS SRV; plain URLs pass through untouched.
  def endpoint_url
    return url if !url.starts_with?("srv://")

    service = DiscourseAi::Utils::DnsSrv.lookup(url.sub("srv://", ""))
    "https://#{service.target}:#{service.port}"
  end

  # NOTE: "asymetric" keyword is misspelled but is part of the public
  # interface — renaming it would break callers.
  def prepare_query_text(text, asymetric: false)
    strategy.prepare_query_text(text, self, asymetric: asymetric)
  end

  def prepare_target_text(target)
    strategy.prepare_target_text(target, self)
  end

  def strategy_id
    strategy.id
  end

  def strategy_version
    strategy.version
  end

  # Seeded definitions read their key from the environment when present,
  # so built-in configs can ship without credentials in the database.
  def api_key
    if seeded?
      env_key = "DISCOURSE_AI_SEEDED_EMBEDDING_API_KEY"
      ENV[env_key] || self[:api_key]
    else
      self[:api_key]
    end
  end

  private

  def strategy
    @strategy ||= DiscourseAi::Embeddings::Strategies::Truncation.new
  end

  def cloudflare_client
    DiscourseAi::Inference::CloudflareWorkersAi.new(endpoint_url, api_key)
  end

  def hugging_face_client
    DiscourseAi::Inference::HuggingFaceTextEmbeddings.new(endpoint_url, api_key)
  end

  def open_ai_client
    DiscourseAi::Inference::OpenAiEmbeddings.new(
      endpoint_url,
      api_key,
      lookup_custom_param("model_name"),
      dimensions,
    )
  end

  def gemini_client
    DiscourseAi::Inference::GeminiEmbeddings.new(endpoint_url, api_key)
  end
end

# == Schema Information
#
# Table name: embedding_definitions
#
#  id                  :bigint           not null, primary key
#  display_name        :string           not null
#  dimensions          :integer          not null
#  max_sequence_length :integer          not null
#  version             :integer          default(1), not null
#  pg_function         :string           not null
#  provider            :string           not null
#  tokenizer_class     :string           not null
#  url                 :string           not null
#  api_key             :string
#  seeded              :boolean          default(FALSE), not null
#  provider_params     :jsonb
#  created_at          :datetime         not null
#  updated_at          :datetime         not null
#
|
29
app/serializers/ai_embedding_definition_serializer.rb
Normal file
29
app/serializers/ai_embedding_definition_serializer.rb
Normal file
@ -0,0 +1,29 @@
|
||||
# frozen_string_literal: true

# Serializes an EmbeddingDefinition for the admin UI. Seeded (built-in)
# records mask their credentials and report "CDCK" as the provider.
class AiEmbeddingDefinitionSerializer < ApplicationSerializer
  root "ai_embedding"

  attributes :id,
             :display_name,
             :dimensions,
             :max_sequence_length,
             :pg_function,
             :provider,
             :url,
             :api_key,
             :seeded,
             :tokenizer_class,
             :provider_params

  def api_key
    masked_if_seeded(object.api_key)
  end

  def url
    masked_if_seeded(object.url)
  end

  def provider
    if object.seeded?
      "CDCK"
    else
      object.provider
    end
  end

  private

  # Built-in definitions never expose their real credentials or endpoint.
  def masked_if_seeded(value)
    object.seeded? ? "********" : value
  end
end
|
@ -20,5 +20,14 @@ export default {
|
||||
});
|
||||
this.route("discourse-ai-spam", { path: "ai-spam" });
|
||||
this.route("discourse-ai-usage", { path: "ai-usage" });
|
||||
|
||||
this.route(
|
||||
"discourse-ai-embeddings",
|
||||
{ path: "ai-embeddings" },
|
||||
function () {
|
||||
this.route("new");
|
||||
this.route("edit", { path: "/:id/edit" });
|
||||
}
|
||||
);
|
||||
},
|
||||
};
|
||||
|
21
assets/javascripts/discourse/admin/adapters/ai-embedding.js
Normal file
21
assets/javascripts/discourse/admin/adapters/ai-embedding.js
Normal file
@ -0,0 +1,21 @@
|
||||
import RestAdapter from "discourse/adapters/rest";

// REST adapter for the ai-embedding admin resource.
export default class Adapter extends RestAdapter {
  jsonMode = true;

  basePath() {
    return "/admin/plugins/discourse-ai/";
  }

  // Unlike the base implementation, keeps the dasherized resource name
  // (the base adapter would underscore it).
  pathFor(store, type, findArgs) {
    const resource = store.pluralize(this.apiNameFor(type));
    const path = `${this.basePath(store, type, findArgs)}${resource}`;

    return this.appendQueryParams(path, findArgs);
  }

  apiNameFor() {
    return "ai-embedding";
  }
}
|
38
assets/javascripts/discourse/admin/models/ai-embedding.js
Normal file
38
assets/javascripts/discourse/admin/models/ai-embedding.js
Normal file
@ -0,0 +1,38 @@
|
||||
import { ajax } from "discourse/lib/ajax";
import RestModel from "discourse/models/rest";

// REST model for embedding definitions managed via the admin UI.
export default class AiEmbedding extends RestModel {
  createProperties() {
    // BUGFIX: "dimensions" was listed twice in the original property list.
    return this.getProperties(
      "id",
      "display_name",
      "dimensions",
      "provider",
      "tokenizer_class",
      "url",
      "api_key",
      "max_sequence_length",
      "provider_params",
      "pg_function"
    );
  }

  updateProperties() {
    const attrs = this.createProperties();
    attrs.id = this.id;

    return attrs;
  }

  // Hits the server-side test endpoint with the current (possibly unsaved)
  // configuration; resolves to { success, error? }.
  async testConfig() {
    return await ajax(`/admin/plugins/discourse-ai/ai-embeddings/test.json`, {
      data: { ai_embedding: this.createProperties() },
    });
  }

  // Detached copy used by the editor so in-progress edits can be discarded.
  workingCopy() {
    const attrs = this.createProperties();
    return this.store.createRecord("ai-embedding", attrs);
  }
}
|
376
assets/javascripts/discourse/components/ai-embedding-editor.gjs
Normal file
376
assets/javascripts/discourse/components/ai-embedding-editor.gjs
Normal file
@ -0,0 +1,376 @@
|
||||
import Component from "@glimmer/component";
|
||||
import { tracked } from "@glimmer/tracking";
|
||||
import { Input } from "@ember/component";
|
||||
import { concat, get } from "@ember/helper";
|
||||
import { on } from "@ember/modifier";
|
||||
import { action, computed } from "@ember/object";
|
||||
import didInsert from "@ember/render-modifiers/modifiers/did-insert";
|
||||
import didUpdate from "@ember/render-modifiers/modifiers/did-update";
|
||||
import { later } from "@ember/runloop";
|
||||
import { service } from "@ember/service";
|
||||
import BackButton from "discourse/components/back-button";
|
||||
import DButton from "discourse/components/d-button";
|
||||
import icon from "discourse/helpers/d-icon";
|
||||
import { popupAjaxError } from "discourse/lib/ajax-error";
|
||||
import { i18n } from "discourse-i18n";
|
||||
import ComboBox from "select-kit/components/combo-box";
|
||||
import DTooltip from "float-kit/components/d-tooltip";
|
||||
import not from "truth-helpers/helpers/not";
|
||||
|
||||
export default class AiEmbeddingEditor extends Component {
|
||||
@service toasts;
|
||||
@service router;
|
||||
@service dialog;
|
||||
@service store;
|
||||
|
||||
@tracked isSaving = false;
|
||||
@tracked selectedPreset = null;
|
||||
|
||||
@tracked testRunning = false;
|
||||
@tracked testResult = null;
|
||||
@tracked testError = null;
|
||||
@tracked apiKeySecret = true;
|
||||
@tracked editingModel = null;
|
||||
|
||||
get selectedProviders() {
|
||||
const t = (provName) => {
|
||||
return i18n(`discourse_ai.embeddings.providers.${provName}`);
|
||||
};
|
||||
|
||||
return this.args.embeddings.resultSetMeta.providers.map((prov) => {
|
||||
return { id: prov, name: t(prov) };
|
||||
});
|
||||
}
|
||||
|
||||
get distanceFunctions() {
|
||||
const t = (df) => {
|
||||
return i18n(`discourse_ai.embeddings.distance_functions.${df}`);
|
||||
};
|
||||
|
||||
return this.args.embeddings.resultSetMeta.distance_functions.map((df) => {
|
||||
return { id: df, name: t(df) };
|
||||
});
|
||||
}
|
||||
|
||||
get presets() {
|
||||
const presets = this.args.embeddings.resultSetMeta.presets.map((preset) => {
|
||||
return {
|
||||
name: preset.display_name,
|
||||
id: preset.preset_id,
|
||||
};
|
||||
});
|
||||
|
||||
presets.pushObject({
|
||||
name: i18n("discourse_ai.embeddings.configure_manually"),
|
||||
id: "manual",
|
||||
});
|
||||
|
||||
return presets;
|
||||
}
|
||||
|
||||
get showPresets() {
|
||||
return !this.selectedPreset && this.args.model.isNew;
|
||||
}
|
||||
|
||||
@computed("editingModel.provider")
|
||||
get metaProviderParams() {
|
||||
return (
|
||||
this.args.embeddings.resultSetMeta.provider_params[
|
||||
this.editingModel?.provider
|
||||
] || {}
|
||||
);
|
||||
}
|
||||
|
||||
get testErrorMessage() {
|
||||
return i18n("discourse_ai.llms.tests.failure", { error: this.testError });
|
||||
}
|
||||
|
||||
get displayTestResult() {
|
||||
return this.testRunning || this.testResult !== null;
|
||||
}
|
||||
|
||||
@action
|
||||
configurePreset() {
|
||||
this.selectedPreset =
|
||||
this.args.embeddings.resultSetMeta.presets.findBy(
|
||||
"preset_id",
|
||||
this.presetId
|
||||
) || {};
|
||||
|
||||
this.editingModel = this.store
|
||||
.createRecord("ai-embedding", this.selectedPreset)
|
||||
.workingCopy();
|
||||
}
|
||||
|
||||
@action
|
||||
updateModel() {
|
||||
this.editingModel = this.args.model.workingCopy();
|
||||
}
|
||||
|
||||
@action
|
||||
makeApiKeySecret() {
|
||||
this.apiKeySecret = true;
|
||||
}
|
||||
|
||||
@action
|
||||
toggleApiKeySecret() {
|
||||
this.apiKeySecret = !this.apiKeySecret;
|
||||
}
|
||||
|
||||
@action
|
||||
async save() {
|
||||
this.isSaving = true;
|
||||
const isNew = this.args.model.isNew;
|
||||
|
||||
try {
|
||||
await this.editingModel.save();
|
||||
|
||||
if (isNew) {
|
||||
this.args.embeddings.addObject(this.editingModel);
|
||||
this.router.transitionTo(
|
||||
"adminPlugins.show.discourse-ai-embeddings.index"
|
||||
);
|
||||
} else {
|
||||
this.toasts.success({
|
||||
data: { message: i18n("discourse_ai.embeddings.saved") },
|
||||
duration: 2000,
|
||||
});
|
||||
}
|
||||
} catch (e) {
|
||||
popupAjaxError(e);
|
||||
} finally {
|
||||
later(() => {
|
||||
this.isSaving = false;
|
||||
}, 1000);
|
||||
}
|
||||
}
|
||||
|
||||
@action
|
||||
async test() {
|
||||
this.testRunning = true;
|
||||
|
||||
try {
|
||||
const configTestResult = await this.editingModel.testConfig();
|
||||
this.testResult = configTestResult.success;
|
||||
|
||||
if (this.testResult) {
|
||||
this.testError = null;
|
||||
} else {
|
||||
this.testError = configTestResult.error;
|
||||
}
|
||||
} catch (e) {
|
||||
popupAjaxError(e);
|
||||
} finally {
|
||||
later(() => {
|
||||
this.testRunning = false;
|
||||
}, 1000);
|
||||
}
|
||||
}
|
||||
|
||||
@action
|
||||
delete() {
|
||||
return this.dialog.confirm({
|
||||
message: i18n("discourse_ai.embeddings.confirm_delete"),
|
||||
didConfirm: () => {
|
||||
return this.args.model
|
||||
.destroyRecord()
|
||||
.then(() => {
|
||||
this.args.llms.removeObject(this.args.model);
|
||||
this.router.transitionTo(
|
||||
"adminPlugins.show.discourse-ai-embeddings.index"
|
||||
);
|
||||
})
|
||||
.catch(popupAjaxError);
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
<template>
|
||||
<BackButton
|
||||
@route="adminPlugins.show.discourse-ai-embeddings"
|
||||
@label="discourse_ai.embeddings.back"
|
||||
/>
|
||||
|
||||
<form
|
||||
{{didInsert this.updateModel @model.id}}
|
||||
{{didUpdate this.updateModel @model.id}}
|
||||
class="form-horizontal ai-embedding-editor"
|
||||
>
|
||||
{{#if this.showPresets}}
|
||||
<div class="control-group">
|
||||
<label>{{i18n "discourse_ai.embeddings.presets"}}</label>
|
||||
<ComboBox
|
||||
@value={{this.presetId}}
|
||||
@content={{this.presets}}
|
||||
class="ai-embedding-editor__presets"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="control-group ai-llm-editor__action_panel">
|
||||
<DButton
|
||||
@action={{this.configurePreset}}
|
||||
@label="discourse_ai.tools.next.title"
|
||||
class="ai-embedding-editor__next"
|
||||
/>
|
||||
</div>
|
||||
{{else}}
|
||||
<div class="control-group">
|
||||
<label>{{i18n "discourse_ai.embeddings.display_name"}}</label>
|
||||
<Input
|
||||
class="ai-embedding-editor-input ai-embedding-editor__display-name"
|
||||
@type="text"
|
||||
@value={{this.editingModel.display_name}}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="control-group">
|
||||
<label>{{i18n "discourse_ai.embeddings.provider"}}</label>
|
||||
<ComboBox
|
||||
@value={{this.editingModel.provider}}
|
||||
@content={{this.selectedProviders}}
|
||||
@class="ai-embedding-editor__provider"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="control-group">
|
||||
<label>{{i18n "discourse_ai.embeddings.url"}}</label>
|
||||
<Input
|
||||
class="ai-embedding-editor-input ai-embedding-editor__url"
|
||||
@type="text"
|
||||
@value={{this.editingModel.url}}
|
||||
required="true"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="control-group">
|
||||
<label>{{i18n "discourse_ai.embeddings.api_key"}}</label>
|
||||
<div class="ai-embedding-editor__secret-api-key-group">
|
||||
<Input
|
||||
@value={{this.editingModel.api_key}}
|
||||
class="ai-embedding-editor-input ai-embedding-editor__api-key"
|
||||
@type={{if this.apiKeySecret "password" "text"}}
|
||||
required="true"
|
||||
{{on "focusout" this.makeApiKeySecret}}
|
||||
/>
|
||||
<DButton
|
||||
@action={{this.toggleApiKeySecret}}
|
||||
@icon="far-eye-slash"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="control-group">
|
||||
<label>{{i18n "discourse_ai.embeddings.tokenizer"}}</label>
|
||||
<ComboBox
|
||||
@value={{this.editingModel.tokenizer_class}}
|
||||
@content={{@embeddings.resultSetMeta.tokenizers}}
|
||||
@class="ai-embedding-editor__tokenizer"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="control-group">
|
||||
<label>{{i18n "discourse_ai.embeddings.dimensions"}}</label>
|
||||
<Input
|
||||
@type="number"
|
||||
class="ai-embedding-editor-input ai-embedding-editor__dimensions"
|
||||
step="any"
|
||||
min="0"
|
||||
lang="en"
|
||||
@value={{this.editingModel.dimensions}}
|
||||
required="true"
|
||||
disabled={{not this.editingModel.isNew}}
|
||||
/>
|
||||
{{#if this.editingModel.isNew}}
|
||||
<DTooltip
|
||||
@icon="circle-exclamation"
|
||||
@content={{i18n
|
||||
"discourse_ai.embeddings.hints.dimensions_warning"
|
||||
}}
|
||||
/>
|
||||
{{/if}}
|
||||
</div>
|
||||
|
||||
<div class="control-group">
|
||||
<label>{{i18n "discourse_ai.embeddings.max_sequence_length"}}</label>
|
||||
<Input
|
||||
@type="number"
|
||||
class="ai-embedding-editor-input ai-embedding-editor__max_sequence_length"
|
||||
step="any"
|
||||
min="0"
|
||||
lang="en"
|
||||
@value={{this.editingModel.max_sequence_length}}
|
||||
required="true"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="control-group">
|
||||
<label>{{i18n "discourse_ai.embeddings.distance_function"}}</label>
|
||||
<ComboBox
|
||||
@value={{this.editingModel.pg_function}}
|
||||
@content={{this.distanceFunctions}}
|
||||
@class="ai-embedding-editor__distance_functions"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{{#each-in this.metaProviderParams as |field type|}}
|
||||
<div
|
||||
class="control-group ai-embedding-editor-provider-param__{{type}}"
|
||||
>
|
||||
<label>
|
||||
{{i18n (concat "discourse_ai.embeddings.provider_fields." field)}}
|
||||
</label>
|
||||
<Input
|
||||
@type="text"
|
||||
class="ai-embedding-editor-input ai-embedding-editor__{{field}}"
|
||||
@value={{mut (get this.editingModel.provider_params field)}}
|
||||
/>
|
||||
</div>
|
||||
{{/each-in}}
|
||||
|
||||
<div class="control-group ai-embedding-editor__action_panel">
|
||||
<DButton
|
||||
class="ai-embedding-editor__test"
|
||||
@action={{this.test}}
|
||||
@disabled={{this.testRunning}}
|
||||
@label="discourse_ai.embeddings.tests.title"
|
||||
/>
|
||||
|
||||
<DButton
|
||||
class="btn-primary ai-embedding-editor__save"
|
||||
@action={{this.save}}
|
||||
@disabled={{this.isSaving}}
|
||||
@label="discourse_ai.embeddings.save"
|
||||
/>
|
||||
{{#unless this.editingModel.isNew}}
|
||||
<DButton
|
||||
@action={{this.delete}}
|
||||
class="btn-danger ai-embedding-editor__delete"
|
||||
@label="discourse_ai.embeddings.delete"
|
||||
/>
|
||||
{{/unless}}
|
||||
|
||||
<div class="control-group ai-embedding-editor-tests">
|
||||
{{#if this.displayTestResult}}
|
||||
{{#if this.testRunning}}
|
||||
<div class="spinner small"></div>
|
||||
{{i18n "discourse_ai.embeddings.tests.running"}}
|
||||
{{else}}
|
||||
{{#if this.testResult}}
|
||||
<div class="ai-embedding-editor-tests__success">
|
||||
{{icon "check"}}
|
||||
{{i18n "discourse_ai.embeddings.tests.success"}}
|
||||
</div>
|
||||
{{else}}
|
||||
<div class="ai-embedding-editor-tests__failure">
|
||||
{{icon "xmark"}}
|
||||
{{this.testErrorMessage}}
|
||||
</div>
|
||||
{{/if}}
|
||||
{{/if}}
|
||||
{{/if}}
|
||||
</div>
|
||||
</div>
|
||||
{{/if}}
|
||||
</form>
|
||||
</template>
|
||||
}
|
@ -0,0 +1,114 @@
|
||||
import Component from "@glimmer/component";
import { concat } from "@ember/helper";
import { service } from "@ember/service";
import DBreadcrumbsItem from "discourse/components/d-breadcrumbs-item";
import DButton from "discourse/components/d-button";
import DPageSubheader from "discourse/components/d-page-subheader";
import { i18n } from "discourse-i18n";
import AdminConfigAreaEmptyList from "admin/components/admin-config-area-empty-list";
import DTooltip from "float-kit/components/d-tooltip";
import AiEmbeddingEditor from "./ai-embedding-editor";

// List + editor shell for embedding configurations. When @currentEmbedding
// is set it renders the editor; otherwise the table of all definitions
// (or an empty-state call to action when there are none).
export default class AiEmbeddingsListEditor extends Component {
  @service adminPluginNavManager;

  // True when at least one embedding definition exists.
  get hasEmbeddingElements() {
    return this.args.embeddings.length !== 0;
  }

  <template>
    <DBreadcrumbsItem
      @path="/admin/plugins/{{this.adminPluginNavManager.currentPlugin.name}}/ai-embeddings"
      @label={{i18n "discourse_ai.embeddings.short_title"}}
    />
    <section class="ai-embeddings-list-editor admin-detail">
      {{#if @currentEmbedding}}
        <AiEmbeddingEditor
          @model={{@currentEmbedding}}
          @embeddings={{@embeddings}}
        />
      {{else}}
        <DPageSubheader
          @titleLabel={{i18n "discourse_ai.embeddings.short_title"}}
          @descriptionLabel={{i18n "discourse_ai.embeddings.description"}}
          @learnMoreUrl="https://meta.discourse.org/t/discourse-ai-embeddings/259603"
        >
          <:actions as |actions|>
            <actions.Primary
              @label="discourse_ai.embeddings.new"
              @route="adminPlugins.show.discourse-ai-embeddings.new"
              @icon="plus"
              class="ai-embeddings-list-editor__new-button"
            />
          </:actions>
        </DPageSubheader>

        {{#if this.hasEmbeddingElements}}
          <table class="d-admin-table">
            <thead>
              <tr>
                <th>{{i18n "discourse_ai.embeddings.display_name"}}</th>
                <th>{{i18n "discourse_ai.embeddings.provider"}}</th>
                <th></th>
              </tr>
            </thead>
            <tbody>
              {{#each @embeddings as |embedding|}}
                <tr class="ai-embeddings-list__row d-admin-row__content">
                  <td class="d-admin-row__overview">
                    <div class="ai-embeddings-list__name">
                      <strong>
                        {{embedding.display_name}}
                      </strong>
                    </div>
                  </td>
                  <td class="d-admin-row__detail">
                    <div class="d-admin-row__mobile-label">
                      {{i18n "discourse_ai.embeddings.provider"}}
                    </div>
                    {{i18n
                      (concat
                        "discourse_ai.embeddings.providers." embedding.provider
                      )
                    }}
                  </td>
                  <td class="d-admin-row__controls">
                    {{! Seeded (built-in) definitions cannot be edited;
                        show a disabled button with an explanatory tooltip }}
                    {{#if embedding.seeded}}
                      <DTooltip
                        class="ai-embeddings-list__edit-disabled-tooltip"
                      >
                        <:trigger>
                          <DButton
                            class="btn btn-default btn-small disabled"
                            @label="discourse_ai.embeddings.edit"
                          />
                        </:trigger>
                        <:content>
                          {{i18n "discourse_ai.embeddings.seeded_warning"}}
                        </:content>
                      </DTooltip>
                    {{else}}
                      <DButton
                        class="btn btn-default btn-small ai-embeddings-list__edit-button"
                        @label="discourse_ai.embeddings.edit"
                        @route="adminPlugins.show.discourse-ai-embeddings.edit"
                        @routeModels={{embedding.id}}
                      />
                    {{/if}}
                  </td>
                </tr>
              {{/each}}
            </tbody>
          </table>
        {{else}}
          <AdminConfigAreaEmptyList
            @ctaLabel="discourse_ai.embeddings.new"
            @ctaRoute="adminPlugins.show.discourse-ai-embeddings.new"
            @ctaClass="ai-embeddings-list-editor__empty-new-button"
            @emptyLabel="discourse_ai.embeddings.empty"
          />
        {{/if}}
      {{/if}}
    </section>
  </template>
}
|
@ -12,6 +12,10 @@ export default {
|
||||
|
||||
withPluginApi("1.1.0", (api) => {
|
||||
api.addAdminPluginConfigurationNav("discourse-ai", PLUGIN_NAV_MODE_TOP, [
|
||||
{
|
||||
label: "discourse_ai.embeddings.short_title",
|
||||
route: "adminPlugins.show.discourse-ai-embeddings",
|
||||
},
|
||||
{
|
||||
label: "discourse_ai.llms.short_title",
|
||||
route: "adminPlugins.show.discourse-ai-llms",
|
||||
|
@ -0,0 +1,26 @@
|
||||
.ai-embedding-editor {
|
||||
padding-left: 0.5em;
|
||||
|
||||
.ai-embedding-editor-input {
|
||||
width: 350px;
|
||||
}
|
||||
|
||||
.ai-embedding-editor-tests {
|
||||
&__failure {
|
||||
color: var(--danger);
|
||||
}
|
||||
|
||||
&__success {
|
||||
color: var(--success);
|
||||
}
|
||||
}
|
||||
|
||||
&__api-key {
|
||||
margin-right: 0.5em;
|
||||
}
|
||||
|
||||
&__secret-api-key-group {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
}
|
@ -502,6 +502,49 @@ en:
|
||||
accuracy: "Accuracy:"
|
||||
|
||||
embeddings:
|
||||
short_title: "Embeddings"
|
||||
description: "Embeddings are a crucial component of the Discourse AI plugin, enabling features like related topics and semantic search."
|
||||
new: "New embedding"
|
||||
back: "Back"
|
||||
save: "Save"
|
||||
saved: "Embedding configuration saved"
|
||||
delete: "Delete"
|
||||
confirm_delete: Are you sure you want to remove this embedding configuration?
|
||||
empty: "You haven't setup embeddings yet"
|
||||
presets: "Select a preset..."
|
||||
configure_manually: "Configure manually"
|
||||
edit: "Edit"
|
||||
seeded_warning: "This is pre-configured on your site and cannot be edited."
|
||||
tests:
|
||||
title: "Run test"
|
||||
running: "Running test..."
|
||||
success: "Success!"
|
||||
failure: "Attempting to generate an embedding resulted in: %{error}"
|
||||
hints:
|
||||
dimensions_warning: "Once saved, this value can't be changed."
|
||||
|
||||
display_name: "Name"
|
||||
provider: "Provider"
|
||||
url: "Embeddings service URL"
|
||||
api_key: "Embeddings service API Key"
|
||||
tokenizer: "Tokenizer"
|
||||
dimensions: "Embedding dimensions"
|
||||
max_sequence_length: "Sequence length"
|
||||
|
||||
distance_function: "Distance function"
|
||||
distance_functions:
|
||||
<#>: "Negative inner product (<#>)"
|
||||
<=>: "Cosine distance (<=>)"
|
||||
providers:
|
||||
hugging_face: "Hugging Face"
|
||||
open_ai: "OpenAI"
|
||||
google: "Google"
|
||||
cloudflare: "Cloudflare"
|
||||
CDCK: "CDCK"
|
||||
provider_fields:
|
||||
model_name: "Model name"
|
||||
|
||||
|
||||
semantic_search: "Topics (Semantic)"
|
||||
semantic_search_loading: "Searching for more results using AI"
|
||||
semantic_search_results:
|
||||
|
@ -49,10 +49,7 @@ en:
|
||||
ai_nsfw_flag_threshold_sexy: "Threshold for an image classified as sexy to be considered NSFW."
|
||||
ai_nsfw_models: "Models to use for NSFW inference."
|
||||
|
||||
ai_openai_embeddings_url: "Custom URL used for the OpenAI embeddings API. (in the case of Azure it can be: https://COMPANY.openai.azure.com/openai/deployments/DEPLOYMENT/embeddings?api-version=2023-05-15)"
|
||||
ai_openai_api_key: "API key for OpenAI API. ONLY used for embeddings and Dall-E. For GPT use the LLM config tab"
|
||||
ai_hugging_face_tei_endpoint: URL where the API is running for the Hugging Face text embeddings inference
|
||||
ai_hugging_face_tei_api_key: API key for Hugging Face text embeddings inference
|
||||
ai_openai_api_key: "API key for OpenAI API. ONLY used for Dall-E. For GPT use the LLM config tab"
|
||||
|
||||
ai_helper_enabled: "Enable the AI helper."
|
||||
composer_ai_helper_allowed_groups: "Users on these groups will see the AI helper button in the composer."
|
||||
@ -67,15 +64,11 @@ en:
|
||||
ai_helper_image_caption_model: "Select the model to use for generating image captions"
|
||||
ai_auto_image_caption_allowed_groups: "Users on these groups can toggle automatic image captioning."
|
||||
|
||||
ai_embeddings_enabled: "Enable the embeddings module."
|
||||
ai_embeddings_discourse_service_api_endpoint: "URL where the API is running for the embeddings module"
|
||||
ai_embeddings_discourse_service_api_key: "API key for the embeddings API"
|
||||
ai_embeddings_model: "Use all-mpnet-base-v2 for local and fast inference in english, text-embedding-ada-002 to use OpenAI API (need API key) and multilingual-e5-large for local multilingual embeddings"
|
||||
ai_embeddings_selected_model: "Use the selected model for generating embeddings."
|
||||
ai_embeddings_generate_for_pms: "Generate embeddings for personal messages."
|
||||
ai_embeddings_semantic_related_topics_enabled: "Use Semantic Search for related topics."
|
||||
ai_embeddings_semantic_related_topics: "Maximum number of topics to show in related topic section."
|
||||
ai_embeddings_backfill_batch_size: "Number of embeddings to backfill every 15 minutes."
|
||||
ai_embeddings_pg_connection_string: "PostgreSQL connection string for the embeddings module. Needs pgvector extension enabled and a series of tables created. See docs for more info."
|
||||
ai_embeddings_semantic_search_enabled: "Enable full-page semantic search."
|
||||
ai_embeddings_semantic_quick_search_enabled: "Enable semantic search option in search menu popup."
|
||||
ai_embeddings_semantic_related_include_closed_topics: "Include closed topics in semantic search results"
|
||||
@ -437,13 +430,11 @@ en:
|
||||
cannot_edit_builtin: "You can't edit a built-in model."
|
||||
|
||||
embeddings:
|
||||
delete_failed: "This model is currently in use. Update the `ai embeddings selected model` first."
|
||||
cannot_edit_builtin: "You can't edit a built-in model."
|
||||
configuration:
|
||||
disable_embeddings: "You have to disable 'ai embeddings enabled' first."
|
||||
choose_model: "Set 'ai embeddings model' first."
|
||||
model_unreachable: "We failed to generate a test embedding with this model. Check your settings are correct."
|
||||
hint:
|
||||
one: "Make sure the `%{settings}` setting was configured."
|
||||
other: "Make sure the settings of the provider you want were configured. Options are: %{settings}"
|
||||
choose_model: "Set 'ai embeddings selected model' first."
|
||||
|
||||
llm_models:
|
||||
missing_provider_param: "%{param} can't be blank"
|
||||
|
@ -96,6 +96,13 @@ Discourse::Application.routes.draw do
|
||||
controller: "discourse_ai/admin/ai_llm_quotas",
|
||||
path: "quotas",
|
||||
only: %i[index create update destroy]
|
||||
|
||||
resources :ai_embeddings,
|
||||
only: %i[index new create edit update destroy],
|
||||
path: "ai-embeddings",
|
||||
controller: "discourse_ai/admin/ai_embeddings" do
|
||||
collection { get :test }
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -28,11 +28,14 @@ discourse_ai:
|
||||
|
||||
|
||||
ai_openai_dall_e_3_url: "https://api.openai.com/v1/images/generations"
|
||||
ai_openai_embeddings_url: "https://api.openai.com/v1/embeddings"
|
||||
ai_openai_embeddings_url:
|
||||
hidden: true
|
||||
default: "https://api.openai.com/v1/embeddings"
|
||||
ai_openai_organization:
|
||||
default: ""
|
||||
hidden: true
|
||||
ai_openai_api_key:
|
||||
hidden: true
|
||||
default: ""
|
||||
secret: true
|
||||
ai_stability_api_key:
|
||||
@ -50,11 +53,14 @@ discourse_ai:
|
||||
- "stable-diffusion-768-v2-1"
|
||||
- "stable-diffusion-v1-5"
|
||||
ai_hugging_face_tei_endpoint:
|
||||
hidden: true
|
||||
default: ""
|
||||
ai_hugging_face_tei_endpoint_srv:
|
||||
default: ""
|
||||
hidden: true
|
||||
ai_hugging_face_tei_api_key: ""
|
||||
ai_hugging_face_tei_api_key:
|
||||
default: ""
|
||||
hidden: true
|
||||
ai_hugging_face_tei_reranker_endpoint:
|
||||
default: ""
|
||||
ai_hugging_face_tei_reranker_endpoint_srv:
|
||||
@ -69,12 +75,14 @@ discourse_ai:
|
||||
ai_cloudflare_workers_account_id:
|
||||
default: ""
|
||||
secret: true
|
||||
hidden: true
|
||||
ai_cloudflare_workers_api_token:
|
||||
default: ""
|
||||
secret: true
|
||||
hidden: true
|
||||
ai_gemini_api_key:
|
||||
default: ""
|
||||
hidden: false
|
||||
hidden: true
|
||||
ai_strict_token_counting:
|
||||
default: false
|
||||
hidden: true
|
||||
@ -158,27 +166,12 @@ discourse_ai:
|
||||
default: false
|
||||
client: true
|
||||
validator: "DiscourseAi::Configuration::EmbeddingsModuleValidator"
|
||||
ai_embeddings_discourse_service_api_endpoint: ""
|
||||
ai_embeddings_discourse_service_api_endpoint_srv:
|
||||
default: ""
|
||||
hidden: true
|
||||
ai_embeddings_discourse_service_api_key:
|
||||
default: ""
|
||||
secret: true
|
||||
ai_embeddings_model:
|
||||
ai_embeddings_selected_model:
|
||||
type: enum
|
||||
default: "bge-large-en"
|
||||
default: ""
|
||||
allow_any: false
|
||||
choices:
|
||||
- all-mpnet-base-v2
|
||||
- text-embedding-ada-002
|
||||
- text-embedding-3-small
|
||||
- text-embedding-3-large
|
||||
- multilingual-e5-large
|
||||
- bge-large-en
|
||||
- gemini
|
||||
- bge-m3
|
||||
validator: "DiscourseAi::Configuration::EmbeddingsModelValidator"
|
||||
enum: "DiscourseAi::Configuration::EmbeddingDefsEnumerator"
|
||||
validator: "DiscourseAi::Configuration::EmbeddingDefsValidator"
|
||||
ai_embeddings_per_post_enabled:
|
||||
default: false
|
||||
hidden: true
|
||||
@ -191,9 +184,6 @@ discourse_ai:
|
||||
ai_embeddings_backfill_batch_size:
|
||||
default: 250
|
||||
hidden: true
|
||||
ai_embeddings_pg_connection_string:
|
||||
default: ""
|
||||
hidden: true
|
||||
ai_embeddings_semantic_search_enabled:
|
||||
default: false
|
||||
client: true
|
||||
@ -213,6 +203,35 @@ discourse_ai:
|
||||
default: false
|
||||
client: true
|
||||
hidden: true
|
||||
|
||||
ai_embeddings_discourse_service_api_endpoint:
|
||||
default: ""
|
||||
hidden: true
|
||||
ai_embeddings_discourse_service_api_endpoint_srv:
|
||||
default: ""
|
||||
hidden: true
|
||||
ai_embeddings_discourse_service_api_key:
|
||||
hidden: true
|
||||
default: ""
|
||||
secret: true
|
||||
ai_embeddings_model:
|
||||
hidden: true
|
||||
type: enum
|
||||
default: "bge-large-en"
|
||||
allow_any: false
|
||||
choices:
|
||||
- all-mpnet-base-v2
|
||||
- text-embedding-ada-002
|
||||
- text-embedding-3-small
|
||||
- text-embedding-3-large
|
||||
- multilingual-e5-large
|
||||
- bge-large-en
|
||||
- gemini
|
||||
- bge-m3
|
||||
ai_embeddings_pg_connection_string:
|
||||
default: ""
|
||||
hidden: true
|
||||
|
||||
ai_summarization_enabled:
|
||||
default: false
|
||||
client: true
|
||||
|
19
db/migrate/20241217164540_create_embedding_definitions.rb
Normal file
19
db/migrate/20241217164540_create_embedding_definitions.rb
Normal file
@ -0,0 +1,19 @@
|
||||
# frozen_string_literal: true
|
||||
class CreateEmbeddingDefinitions < ActiveRecord::Migration[7.2]
|
||||
def change
|
||||
create_table :embedding_definitions do |t|
|
||||
t.string :display_name, null: false
|
||||
t.integer :dimensions, null: false
|
||||
t.integer :max_sequence_length, null: false
|
||||
t.integer :version, null: false, default: 1
|
||||
t.string :pg_function, null: false
|
||||
t.string :provider, null: false
|
||||
t.string :tokenizer_class, null: false
|
||||
t.string :url, null: false
|
||||
t.string :api_key
|
||||
t.boolean :seeded, null: false, default: false
|
||||
t.jsonb :provider_params
|
||||
t.timestamps
|
||||
end
|
||||
end
|
||||
end
|
204
db/migrate/20250110114305_embedding_config_data_migration.rb
Normal file
204
db/migrate/20250110114305_embedding_config_data_migration.rb
Normal file
@ -0,0 +1,204 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
class EmbeddingConfigDataMigration < ActiveRecord::Migration[7.0]
|
||||
def up
|
||||
current_model = fetch_setting("ai_embeddings_model") || "bge-large-en"
|
||||
provider = provider_for(current_model)
|
||||
|
||||
if provider.present?
|
||||
attrs = creds_for(provider)
|
||||
|
||||
if attrs.present?
|
||||
attrs = attrs.merge(model_attrs(current_model))
|
||||
attrs[:display_name] = current_model
|
||||
attrs[:provider] = provider
|
||||
persist_config(attrs)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
def down
|
||||
end
|
||||
|
||||
# Utils
|
||||
|
||||
def fetch_setting(name)
|
||||
DB.query_single(
|
||||
"SELECT value FROM site_settings WHERE name = :setting_name",
|
||||
setting_name: name,
|
||||
).first || ENV["DISCOURSE_#{name&.upcase}"]
|
||||
end
|
||||
|
||||
def provider_for(model)
|
||||
cloudflare_api_token = fetch_setting("ai_cloudflare_workers_api_token")
|
||||
|
||||
return "cloudflare" if model == "bge-large-en" && cloudflare_api_token.present?
|
||||
|
||||
tei_models = %w[bge-large-en bge-m3 multilingual-e5-large]
|
||||
return "hugging_face" if tei_models.include?(model)
|
||||
|
||||
return "google" if model == "gemini"
|
||||
|
||||
if %w[text-embedding-3-large text-embedding-3-small text-embedding-ada-002].include?(model)
|
||||
return "open_ai"
|
||||
end
|
||||
|
||||
nil
|
||||
end
|
||||
|
||||
def creds_for(provider)
|
||||
# CF
|
||||
if provider == "cloudflare"
|
||||
api_key = fetch_setting("ai_cloudflare_workers_api_token")
|
||||
account_id = fetch_setting("ai_cloudflare_workers_account_id")
|
||||
|
||||
return if api_key.blank? || account_id.blank?
|
||||
|
||||
{
|
||||
url:
|
||||
"https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/baai/bge-large-en-v1.5",
|
||||
api_key: api_key,
|
||||
}
|
||||
# TEI
|
||||
elsif provider == "hugging_face"
|
||||
seeded = false
|
||||
endpoint = fetch_setting("ai_hugging_face_tei_endpoint")
|
||||
|
||||
if endpoint.blank?
|
||||
endpoint = fetch_setting("ai_hugging_face_tei_endpoint_srv")
|
||||
if endpoint.present?
|
||||
endpoint = "srv://#{endpoint}"
|
||||
seeded = true
|
||||
end
|
||||
end
|
||||
|
||||
api_key = fetch_setting("ai_hugging_face_tei_api_key")
|
||||
|
||||
return if endpoint.blank? || api_key.blank?
|
||||
|
||||
{ url: endpoint, api_key: api_key, seeded: seeded }
|
||||
# Gemini
|
||||
elsif provider == "google"
|
||||
api_key = fetch_setting("ai_gemini_api_key")
|
||||
|
||||
return if api_key.blank?
|
||||
|
||||
{
|
||||
url: "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent",
|
||||
api_key: api_key,
|
||||
}
|
||||
|
||||
# Open AI
|
||||
elsif provider == "open_ai"
|
||||
endpoint = fetch_setting("ai_openai_embeddings_url")
|
||||
api_key = fetch_setting("ai_openai_api_key")
|
||||
|
||||
return if endpoint.blank? || api_key.blank?
|
||||
|
||||
{ url: endpoint, api_key: api_key }
|
||||
else
|
||||
nil
|
||||
end
|
||||
end
|
||||
|
||||
def model_attrs(model_name)
|
||||
if model_name == "bge-large-en"
|
||||
{
|
||||
dimensions: 1024,
|
||||
max_sequence_length: 512,
|
||||
id: 4,
|
||||
pg_function: "<#>",
|
||||
tokenizer_class: "DiscourseAi::Tokenizer::BgeLargeEnTokenizer",
|
||||
}
|
||||
elsif model_name == "bge-m3"
|
||||
{
|
||||
dimensions: 1024,
|
||||
max_sequence_length: 8192,
|
||||
id: 8,
|
||||
pg_function: "<#>",
|
||||
tokenizer_class: "DiscourseAi::Tokenizer::BgeM3Tokenizer",
|
||||
}
|
||||
elsif model_name == "gemini"
|
||||
{
|
||||
dimensions: 768,
|
||||
max_sequence_length: 1536,
|
||||
id: 5,
|
||||
pg_function: "<=>",
|
||||
tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
|
||||
}
|
||||
elsif model_name == "multilingual-e5-large"
|
||||
{
|
||||
dimensions: 1024,
|
||||
max_sequence_length: 512,
|
||||
id: 3,
|
||||
pg_function: "<=>",
|
||||
tokenizer_class: "DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer",
|
||||
}
|
||||
elsif model_name == "text-embedding-3-large"
|
||||
{
|
||||
dimensions: 2000,
|
||||
max_sequence_length: 8191,
|
||||
id: 7,
|
||||
pg_function: "<=>",
|
||||
tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
|
||||
provider_params: {
|
||||
model_name: "text-embedding-3-large",
|
||||
},
|
||||
}
|
||||
elsif model_name == "text-embedding-3-small"
|
||||
{
|
||||
dimensions: 1536,
|
||||
max_sequence_length: 8191,
|
||||
id: 6,
|
||||
pg_function: "<=>",
|
||||
tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
|
||||
provider_params: {
|
||||
model_name: "text-embedding-3-small",
|
||||
},
|
||||
}
|
||||
else
|
||||
{
|
||||
dimensions: 1536,
|
||||
max_sequence_length: 8191,
|
||||
id: 2,
|
||||
pg_function: "<=>",
|
||||
tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
|
||||
provider_params: {
|
||||
model_name: "text-embedding-ada-002",
|
||||
},
|
||||
}
|
||||
end
|
||||
end
|
||||
|
||||
def persist_config(attrs)
|
||||
DB.exec(
|
||||
<<~SQL,
|
||||
INSERT INTO embedding_definitions (id, display_name, dimensions, max_sequence_length, version, pg_function, provider, tokenizer_class, url, api_key, provider_params, seeded, created_at, updated_at)
|
||||
VALUES (:id, :display_name, :dimensions, :max_sequence_length, 1, :pg_function, :provider, :tokenizer_class, :url, :api_key, :provider_params, :seeded, :now, :now)
|
||||
SQL
|
||||
id: attrs[:id],
|
||||
display_name: attrs[:display_name],
|
||||
dimensions: attrs[:dimensions],
|
||||
max_sequence_length: attrs[:max_sequence_length],
|
||||
pg_function: attrs[:pg_function],
|
||||
provider: attrs[:provider],
|
||||
tokenizer_class: attrs[:tokenizer_class],
|
||||
url: attrs[:url],
|
||||
api_key: attrs[:api_key],
|
||||
provider_params: attrs[:provider_params],
|
||||
seeded: !!attrs[:seeded],
|
||||
now: Time.zone.now,
|
||||
)
|
||||
|
||||
# We hardcoded the ID to match with already generated embeddings. Let's restart the seq to avoid conflicts.
|
||||
DB.exec(
|
||||
"ALTER SEQUENCE embedding_definitions_id_seq RESTART WITH :new_seq",
|
||||
new_seq: attrs[:id].to_i + 1,
|
||||
)
|
||||
|
||||
DB.exec(<<~SQL, new_value: attrs[:id])
|
||||
INSERT INTO site_settings(name, data_type, value, created_at, updated_at)
|
||||
VALUES ('ai_embeddings_selected_model', 3, :new_value, NOW(), NOW())
|
||||
SQL
|
||||
end
|
||||
end
|
@ -196,7 +196,7 @@ module DiscourseAi
|
||||
)
|
||||
|
||||
plugin.on(:site_setting_changed) do |name, old_value, new_value|
|
||||
if name == :ai_embeddings_model && SiteSetting.ai_embeddings_enabled? &&
|
||||
if name == :ai_embeddings_selected_model && SiteSetting.ai_embeddings_enabled? &&
|
||||
new_value != old_value
|
||||
RagDocumentFragment.delete_all
|
||||
UploadReference
|
||||
|
@ -327,7 +327,7 @@ module DiscourseAi
|
||||
rag_conversation_chunks
|
||||
end
|
||||
|
||||
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector_def: vector.vdef)
|
||||
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment)
|
||||
|
||||
candidate_fragment_ids =
|
||||
schema
|
||||
|
@ -93,7 +93,7 @@ module DiscourseAi
|
||||
|
||||
def nearest_neighbors(limit: 100)
|
||||
vector = DiscourseAi::Embeddings::Vector.instance
|
||||
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef)
|
||||
schema = DiscourseAi::Embeddings::Schema.for(Topic)
|
||||
|
||||
raw_vector = vector.vector_from(@text)
|
||||
|
||||
|
20
lib/configuration/embedding_defs_enumerator.rb
Normal file
20
lib/configuration/embedding_defs_enumerator.rb
Normal file
@ -0,0 +1,20 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
require "enum_site_setting"
|
||||
|
||||
module DiscourseAi
|
||||
module Configuration
|
||||
class EmbeddingDefsEnumerator < ::EnumSiteSetting
|
||||
def self.valid_value?(val)
|
||||
true
|
||||
end
|
||||
|
||||
def self.values
|
||||
DB.query_hash(<<~SQL).map(&:symbolize_keys)
|
||||
SELECT display_name AS name, id AS value
|
||||
FROM embedding_definitions
|
||||
SQL
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
19
lib/configuration/embedding_defs_validator.rb
Normal file
19
lib/configuration/embedding_defs_validator.rb
Normal file
@ -0,0 +1,19 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Configuration
|
||||
class EmbeddingDefsValidator
|
||||
def initialize(opts = {})
|
||||
@opts = opts
|
||||
end
|
||||
|
||||
def valid_value?(val)
|
||||
val.blank? || EmbeddingDefinition.exists?(id: val)
|
||||
end
|
||||
|
||||
def error_message
|
||||
""
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
@ -11,41 +11,11 @@ module DiscourseAi
|
||||
return true if val == "f"
|
||||
return true if Rails.env.test?
|
||||
|
||||
chosen_model = SiteSetting.ai_embeddings_model
|
||||
|
||||
return false if !chosen_model
|
||||
|
||||
representation =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.find_representation(chosen_model)
|
||||
|
||||
return false if representation.nil?
|
||||
|
||||
if !representation.correctly_configured?
|
||||
@representation = representation
|
||||
return false
|
||||
end
|
||||
|
||||
if !can_generate_embeddings?(chosen_model)
|
||||
@unreachable = true
|
||||
return false
|
||||
end
|
||||
|
||||
true
|
||||
SiteSetting.ai_embeddings_selected_model.present?
|
||||
end
|
||||
|
||||
def error_message
|
||||
return(I18n.t("discourse_ai.embeddings.configuration.model_unreachable")) if @unreachable
|
||||
|
||||
@representation&.configuration_hint
|
||||
end
|
||||
|
||||
def can_generate_embeddings?(val)
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base
|
||||
.find_representation(val)
|
||||
.new
|
||||
.inference_client
|
||||
.perform!("this is a test")
|
||||
.present?
|
||||
I18n.t("discourse_ai.embeddings.configuration.choose_model")
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -12,19 +12,80 @@ module DiscourseAi
|
||||
POSTS_TABLE = "ai_posts_embeddings"
|
||||
RAG_DOCS_TABLE = "ai_document_fragments_embeddings"
|
||||
|
||||
def self.for(
|
||||
target_klass,
|
||||
vector_def: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
)
|
||||
case target_klass&.name
|
||||
when "Topic"
|
||||
new(TOPICS_TABLE, "topic_id", vector_def)
|
||||
when "Post"
|
||||
new(POSTS_TABLE, "post_id", vector_def)
|
||||
when "RagDocumentFragment"
|
||||
new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector_def)
|
||||
else
|
||||
raise ArgumentError, "Invalid target type for embeddings"
|
||||
EMBEDDING_TARGETS = %w[topics posts document_fragments]
|
||||
EMBEDDING_TABLES = [TOPICS_TABLE, POSTS_TABLE, RAG_DOCS_TABLE]
|
||||
|
||||
MissingEmbeddingError = Class.new(StandardError)
|
||||
|
||||
class << self
|
||||
def for(target_klass)
|
||||
vector_def = EmbeddingDefinition.find_by(id: SiteSetting.ai_embeddings_selected_model)
|
||||
raise "Invalid embeddings selected model" if vector_def.nil?
|
||||
|
||||
case target_klass&.name
|
||||
when "Topic"
|
||||
new(TOPICS_TABLE, "topic_id", vector_def)
|
||||
when "Post"
|
||||
new(POSTS_TABLE, "post_id", vector_def)
|
||||
when "RagDocumentFragment"
|
||||
new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector_def)
|
||||
else
|
||||
raise ArgumentError, "Invalid target type for embeddings"
|
||||
end
|
||||
end
|
||||
|
||||
def search_index_name(table, def_id)
|
||||
"ai_#{table}_embeddings_#{def_id}_1_search_bit"
|
||||
end
|
||||
|
||||
def prepare_search_indexes(vector_def)
|
||||
EMBEDDING_TARGETS.each { |target| DB.exec <<~SQL }
|
||||
CREATE INDEX IF NOT EXISTS #{search_index_name(target, vector_def.id)} ON ai_#{target}_embeddings
|
||||
USING hnsw ((binary_quantize(embeddings)::bit(#{vector_def.dimensions})) bit_hamming_ops)
|
||||
WHERE model_id = #{vector_def.id} AND strategy_id = 1;
|
||||
SQL
|
||||
end
|
||||
|
||||
def correctly_indexed?(vector_def)
|
||||
index_names = EMBEDDING_TARGETS.map { |t| search_index_name(t, vector_def.id) }
|
||||
indexdefs =
|
||||
DB.query_single(
|
||||
"SELECT indexdef FROM pg_indexes WHERE indexname IN (:names)",
|
||||
names: index_names,
|
||||
)
|
||||
|
||||
return false if indexdefs.length < index_names.length
|
||||
|
||||
indexdefs.all? do |defs|
|
||||
defs.include? "(binary_quantize(embeddings))::bit(#{vector_def.dimensions})"
|
||||
end
|
||||
end
|
||||
|
||||
def remove_orphaned_data
|
||||
removed_defs_ids =
|
||||
DB.query_single(
|
||||
"SELECT DISTINCT(model_id) FROM #{TOPICS_TABLE} te LEFT JOIN embedding_definitions ed ON te.model_id = ed.id WHERE ed.id IS NULL",
|
||||
)
|
||||
|
||||
EMBEDDING_TABLES.each do |t|
|
||||
DB.exec(
|
||||
"DELETE FROM #{t} WHERE model_id IN (:removed_defs)",
|
||||
removed_defs: removed_defs_ids,
|
||||
)
|
||||
end
|
||||
|
||||
drop_index_statement =
|
||||
EMBEDDING_TARGETS
|
||||
.reduce([]) do |memo, et|
|
||||
removed_defs_ids.each do |rdi|
|
||||
memo << "DROP INDEX IF EXISTS #{search_index_name(et, rdi)};"
|
||||
end
|
||||
|
||||
memo
|
||||
end
|
||||
.join("\n")
|
||||
|
||||
DB.exec(drop_index_statement)
|
||||
end
|
||||
end
|
||||
|
||||
@ -117,7 +178,7 @@ module DiscourseAi
|
||||
offset: offset,
|
||||
)
|
||||
rescue PG::Error => e
|
||||
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
|
||||
Rails.logger.error("Error #{e} querying embeddings for model #{vector_def.display_name}")
|
||||
raise MissingEmbeddingError
|
||||
end
|
||||
|
||||
@ -168,7 +229,7 @@ module DiscourseAi
|
||||
|
||||
builder.query(vid: vector_def.id, vsid: vector_def.strategy_id, target_id: record.id)
|
||||
rescue PG::Error => e
|
||||
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
|
||||
Rails.logger.error("Error #{e} querying embeddings for model #{vector_def.display_name}")
|
||||
raise MissingEmbeddingError
|
||||
end
|
||||
|
||||
|
@ -82,7 +82,7 @@ module DiscourseAi
|
||||
|
||||
over_selection_limit = limit * OVER_SELECTION_FACTOR
|
||||
|
||||
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef)
|
||||
schema = DiscourseAi::Embeddings::Schema.for(Topic)
|
||||
|
||||
candidate_topic_ids =
|
||||
schema.asymmetric_similarity_search(
|
||||
@ -132,7 +132,7 @@ module DiscourseAi
|
||||
|
||||
candidate_post_ids =
|
||||
DiscourseAi::Embeddings::Schema
|
||||
.for(Post, vector_def: vector.vdef)
|
||||
.for(Post)
|
||||
.asymmetric_similarity_search(
|
||||
search_term_embedding,
|
||||
limit: max_semantic_results_per_page,
|
||||
|
@ -4,13 +4,18 @@ module DiscourseAi
|
||||
module Embeddings
|
||||
class Vector
|
||||
def self.instance
|
||||
new(DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation)
|
||||
vector_def = EmbeddingDefinition.find_by(id: SiteSetting.ai_embeddings_selected_model)
|
||||
raise "Invalid embeddings selected model" if vector_def.nil?
|
||||
|
||||
new(vector_def)
|
||||
end
|
||||
|
||||
def initialize(vector_definition)
|
||||
@vdef = vector_definition
|
||||
end
|
||||
|
||||
delegate :tokenizer, to: :vdef
|
||||
|
||||
def gen_bulk_reprensentations(relation)
|
||||
http_pool_size = 100
|
||||
pool =
|
||||
@ -20,7 +25,7 @@ module DiscourseAi
|
||||
idletime: 30,
|
||||
)
|
||||
|
||||
schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector_def: vdef)
|
||||
schema = DiscourseAi::Embeddings::Schema.for(relation.first.class)
|
||||
|
||||
embedding_gen = vdef.inference_client
|
||||
promised_embeddings =
|
||||
@ -53,7 +58,7 @@ module DiscourseAi
|
||||
text = vdef.prepare_target_text(target)
|
||||
return if text.blank?
|
||||
|
||||
schema = DiscourseAi::Embeddings::Schema.for(target.class, vector_def: vdef)
|
||||
schema = DiscourseAi::Embeddings::Schema.for(target.class)
|
||||
|
||||
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
|
||||
return if schema.find_by_target(target)&.digest == new_digest
|
||||
|
@ -1,56 +0,0 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module VectorRepresentations
|
||||
class AllMpnetBaseV2 < Base
|
||||
class << self
|
||||
def name
|
||||
"all-mpnet-base-v2"
|
||||
end
|
||||
|
||||
def correctly_configured?
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[
|
||||
ai_embeddings_discourse_service_api_key
|
||||
ai_embeddings_discourse_service_api_endpoint_srv
|
||||
ai_embeddings_discourse_service_api_endpoint
|
||||
]
|
||||
end
|
||||
end
|
||||
|
||||
def dimensions
|
||||
768
|
||||
end
|
||||
|
||||
def max_sequence_length
|
||||
384
|
||||
end
|
||||
|
||||
def id
|
||||
1
|
||||
end
|
||||
|
||||
def version
|
||||
1
|
||||
end
|
||||
|
||||
def pg_function
|
||||
"<#>"
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
|
||||
end
|
||||
|
||||
def inference_client
|
||||
DiscourseAi::Inference::DiscourseClassifier.instance(self.class.name)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
@ -1,103 +0,0 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module VectorRepresentations
|
||||
class Base
|
||||
class << self
|
||||
def find_representation(model_name)
|
||||
# we are explicit here cause the loader may have not
|
||||
# loaded the subclasses yet
|
||||
[
|
||||
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2,
|
||||
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn,
|
||||
DiscourseAi::Embeddings::VectorRepresentations::BgeM3,
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Gemini,
|
||||
DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large,
|
||||
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
|
||||
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
|
||||
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
|
||||
].find { _1.name == model_name }
|
||||
end
|
||||
|
||||
def current_representation
|
||||
find_representation(SiteSetting.ai_embeddings_model).new
|
||||
end
|
||||
|
||||
def correctly_configured?
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
||||
def configuration_hint
|
||||
settings = dependant_setting_names
|
||||
I18n.t(
|
||||
"discourse_ai.embeddings.configuration.hint",
|
||||
settings: settings.join(", "),
|
||||
count: settings.length,
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
def name
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
||||
def dimensions
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
||||
def max_sequence_length
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
||||
def id
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
||||
def pg_function
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
||||
def version
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
||||
def asymmetric_query_prefix
|
||||
""
|
||||
end
|
||||
|
||||
def strategy_id
|
||||
strategy.id
|
||||
end
|
||||
|
||||
def strategy_version
|
||||
strategy.version
|
||||
end
|
||||
|
||||
def prepare_query_text(text, asymetric: false)
|
||||
strategy.prepare_query_text(text, self, asymetric: asymetric)
|
||||
end
|
||||
|
||||
def prepare_target_text(target)
|
||||
strategy.prepare_target_text(target, self)
|
||||
end
|
||||
|
||||
def strategy
|
||||
@strategy ||= DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
end
|
||||
|
||||
def inference_client
|
||||
raise NotImplementedError
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
@ -1,80 +0,0 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module VectorRepresentations
|
||||
class BgeLargeEn < Base
|
||||
class << self
|
||||
def name
|
||||
"bge-large-en"
|
||||
end
|
||||
|
||||
def correctly_configured?
|
||||
SiteSetting.ai_cloudflare_workers_api_token.present? ||
|
||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? ||
|
||||
(
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
||||
)
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[
|
||||
ai_cloudflare_workers_api_token
|
||||
ai_hugging_face_tei_endpoint_srv
|
||||
ai_hugging_face_tei_endpoint
|
||||
ai_embeddings_discourse_service_api_key
|
||||
ai_embeddings_discourse_service_api_endpoint_srv
|
||||
ai_embeddings_discourse_service_api_endpoint
|
||||
]
|
||||
end
|
||||
end
|
||||
|
||||
def dimensions
|
||||
1024
|
||||
end
|
||||
|
||||
def max_sequence_length
|
||||
512
|
||||
end
|
||||
|
||||
def id
|
||||
4
|
||||
end
|
||||
|
||||
def version
|
||||
1
|
||||
end
|
||||
|
||||
def pg_function
|
||||
"<#>"
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::BgeLargeEnTokenizer
|
||||
end
|
||||
|
||||
def asymmetric_query_prefix
|
||||
"Represent this sentence for searching relevant passages:"
|
||||
end
|
||||
|
||||
def inference_client
|
||||
inference_model_name = "baai/bge-large-en-v1.5"
|
||||
|
||||
if SiteSetting.ai_cloudflare_workers_api_token.present?
|
||||
DiscourseAi::Inference::CloudflareWorkersAi.instance(inference_model_name)
|
||||
elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
|
||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.instance
|
||||
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
||||
DiscourseAi::Inference::DiscourseClassifier.instance(
|
||||
inference_model_name.split("/").last,
|
||||
)
|
||||
else
|
||||
raise "No inference endpoint configured"
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
@ -1,51 +0,0 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module VectorRepresentations
|
||||
class BgeM3 < Base
|
||||
class << self
|
||||
def name
|
||||
"bge-m3"
|
||||
end
|
||||
|
||||
def correctly_configured?
|
||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[ai_hugging_face_tei_endpoint_srv ai_hugging_face_tei_endpoint]
|
||||
end
|
||||
end
|
||||
|
||||
def dimensions
|
||||
1024
|
||||
end
|
||||
|
||||
def max_sequence_length
|
||||
8192
|
||||
end
|
||||
|
||||
def id
|
||||
8
|
||||
end
|
||||
|
||||
def version
|
||||
1
|
||||
end
|
||||
|
||||
def pg_function
|
||||
"<#>"
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::BgeM3Tokenizer
|
||||
end
|
||||
|
||||
def inference_client
|
||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.instance
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
@ -1,54 +0,0 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module VectorRepresentations
|
||||
class Gemini < Base
|
||||
class << self
|
||||
def name
|
||||
"gemini"
|
||||
end
|
||||
|
||||
def correctly_configured?
|
||||
SiteSetting.ai_gemini_api_key.present?
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[ai_gemini_api_key]
|
||||
end
|
||||
end
|
||||
|
||||
def id
|
||||
5
|
||||
end
|
||||
|
||||
def version
|
||||
1
|
||||
end
|
||||
|
||||
def dimensions
|
||||
768
|
||||
end
|
||||
|
||||
def max_sequence_length
|
||||
1536 # Gemini has a max sequence length of 2048, but the API has a limit of 10000 bytes, hence the lower value
|
||||
end
|
||||
|
||||
def pg_function
|
||||
"<=>"
|
||||
end
|
||||
|
||||
# There is no public tokenizer for Gemini, and from the ones we already ship in the plugin
|
||||
# OpenAI gets the closest results. Gemini Tokenizer results in ~10% less tokens, so it's safe
|
||||
# to use OpenAI tokenizer since it will overestimate the number of tokens.
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||
end
|
||||
|
||||
def inference_client
|
||||
DiscourseAi::Inference::GeminiEmbeddings.instance
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
@ -1,88 +0,0 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module VectorRepresentations
|
||||
class MultilingualE5Large < Base
|
||||
class << self
|
||||
def name
|
||||
"multilingual-e5-large"
|
||||
end
|
||||
|
||||
def correctly_configured?
|
||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? ||
|
||||
(
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
||||
)
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[
|
||||
ai_hugging_face_tei_endpoint_srv
|
||||
ai_hugging_face_tei_endpoint
|
||||
ai_embeddings_discourse_service_api_key
|
||||
ai_embeddings_discourse_service_api_endpoint_srv
|
||||
ai_embeddings_discourse_service_api_endpoint
|
||||
]
|
||||
end
|
||||
end
|
||||
|
||||
def id
|
||||
3
|
||||
end
|
||||
|
||||
def version
|
||||
1
|
||||
end
|
||||
|
||||
def dimensions
|
||||
1024
|
||||
end
|
||||
|
||||
def max_sequence_length
|
||||
512
|
||||
end
|
||||
|
||||
def pg_function
|
||||
"<=>"
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
|
||||
end
|
||||
|
||||
def inference_client
|
||||
if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
|
||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.instance
|
||||
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
||||
DiscourseAi::Inference::DiscourseClassifier.instance(self.class.name)
|
||||
else
|
||||
raise "No inference endpoint configured"
|
||||
end
|
||||
end
|
||||
|
||||
def prepare_text(text, asymetric: false)
|
||||
prepared_text = super(text, asymetric: asymetric)
|
||||
|
||||
if prepared_text.present? && inference_client.class.name.include?("DiscourseClassifier")
|
||||
return "query: #{prepared_text}"
|
||||
end
|
||||
|
||||
prepared_text
|
||||
end
|
||||
|
||||
def prepare_target_text(target)
|
||||
prepared_text = super(target)
|
||||
|
||||
if prepared_text.present? && inference_client.class.name.include?("DiscourseClassifier")
|
||||
return "query: #{prepared_text}"
|
||||
end
|
||||
|
||||
prepared_text
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
@ -1,56 +0,0 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module VectorRepresentations
|
||||
class TextEmbedding3Large < Base
|
||||
class << self
|
||||
def name
|
||||
"text-embedding-3-large"
|
||||
end
|
||||
|
||||
def correctly_configured?
|
||||
SiteSetting.ai_openai_api_key.present?
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[ai_openai_api_key]
|
||||
end
|
||||
end
|
||||
|
||||
def id
|
||||
7
|
||||
end
|
||||
|
||||
def version
|
||||
1
|
||||
end
|
||||
|
||||
def dimensions
|
||||
# real dimentions are 3072, but we only support up to 2000 in the
|
||||
# indexes, so we downsample to 2000 via API
|
||||
2000
|
||||
end
|
||||
|
||||
def max_sequence_length
|
||||
8191
|
||||
end
|
||||
|
||||
def pg_function
|
||||
"<=>"
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||
end
|
||||
|
||||
def inference_client
|
||||
DiscourseAi::Inference::OpenAiEmbeddings.instance(
|
||||
model: self.class.name,
|
||||
dimensions: dimensions,
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
@ -1,51 +0,0 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module VectorRepresentations
|
||||
class TextEmbedding3Small < Base
|
||||
class << self
|
||||
def name
|
||||
"text-embedding-3-small"
|
||||
end
|
||||
|
||||
def correctly_configured?
|
||||
SiteSetting.ai_openai_api_key.present?
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[ai_openai_api_key]
|
||||
end
|
||||
end
|
||||
|
||||
def id
|
||||
6
|
||||
end
|
||||
|
||||
def version
|
||||
1
|
||||
end
|
||||
|
||||
def dimensions
|
||||
1536
|
||||
end
|
||||
|
||||
def max_sequence_length
|
||||
8191
|
||||
end
|
||||
|
||||
def pg_function
|
||||
"<=>"
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||
end
|
||||
|
||||
def inference_client
|
||||
DiscourseAi::Inference::OpenAiEmbeddings.instance(model: self.class.name)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
@ -1,51 +0,0 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module VectorRepresentations
|
||||
class TextEmbeddingAda002 < Base
|
||||
class << self
|
||||
def name
|
||||
"text-embedding-ada-002"
|
||||
end
|
||||
|
||||
def correctly_configured?
|
||||
SiteSetting.ai_openai_api_key.present?
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[ai_openai_api_key]
|
||||
end
|
||||
end
|
||||
|
||||
def id
|
||||
2
|
||||
end
|
||||
|
||||
def version
|
||||
1
|
||||
end
|
||||
|
||||
def dimensions
|
||||
1536
|
||||
end
|
||||
|
||||
def max_sequence_length
|
||||
8191
|
||||
end
|
||||
|
||||
def pg_function
|
||||
"<=>"
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||
end
|
||||
|
||||
def inference_client
|
||||
DiscourseAi::Inference::OpenAiEmbeddings.instance(model: self.class.name)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
@ -3,22 +3,13 @@
|
||||
module ::DiscourseAi
|
||||
module Inference
|
||||
class CloudflareWorkersAi
|
||||
def initialize(account_id, api_token, model, referer = Discourse.base_url)
|
||||
@account_id = account_id
|
||||
def initialize(endpoint, api_token, referer = Discourse.base_url)
|
||||
@endpoint = endpoint
|
||||
@api_token = api_token
|
||||
@model = model
|
||||
@referer = referer
|
||||
end
|
||||
|
||||
def self.instance(model)
|
||||
new(
|
||||
SiteSetting.ai_cloudflare_workers_account_id,
|
||||
SiteSetting.ai_cloudflare_workers_api_token,
|
||||
model,
|
||||
)
|
||||
end
|
||||
|
||||
attr_reader :account_id, :api_token, :model, :referer
|
||||
attr_reader :endpoint, :api_token, :referer
|
||||
|
||||
def perform!(content)
|
||||
headers = {
|
||||
@ -29,8 +20,6 @@ module ::DiscourseAi
|
||||
|
||||
payload = { text: [content] }
|
||||
|
||||
endpoint = "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/#{model}"
|
||||
|
||||
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
||||
response = conn.post(endpoint, payload.to_json, headers)
|
||||
|
||||
@ -43,7 +32,7 @@ module ::DiscourseAi
|
||||
Rails.logger.warn(
|
||||
"Cloudflare Workers AI Embeddings failed with status: #{response.status} body: #{response.body}",
|
||||
)
|
||||
raise Net::HTTPBadResponse
|
||||
raise Net::HTTPBadResponse.new(response.body.to_s)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -1,47 +0,0 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
module ::DiscourseAi
|
||||
module Inference
|
||||
class DiscourseClassifier
|
||||
def initialize(endpoint, api_key, model, referer = Discourse.base_url)
|
||||
@endpoint = endpoint
|
||||
@api_key = api_key
|
||||
@model = model
|
||||
@referer = referer
|
||||
end
|
||||
|
||||
def self.instance(model)
|
||||
endpoint =
|
||||
if SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present?
|
||||
service =
|
||||
DiscourseAi::Utils::DnsSrv.lookup(
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv,
|
||||
)
|
||||
"https://#{service.target}:#{service.port}"
|
||||
else
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint
|
||||
end
|
||||
|
||||
new(
|
||||
"#{endpoint}/api/v1/classify",
|
||||
SiteSetting.ai_embeddings_discourse_service_api_key,
|
||||
model,
|
||||
)
|
||||
end
|
||||
|
||||
attr_reader :endpoint, :api_key, :model, :referer
|
||||
|
||||
def perform!(content)
|
||||
headers = { "Referer" => referer, "Content-Type" => "application/json" }
|
||||
headers["X-API-KEY"] = api_key if api_key.present?
|
||||
|
||||
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
||||
response = conn.post(endpoint, { model: model, content: content }.to_json, headers)
|
||||
|
||||
raise Net::HTTPBadResponse if ![200, 415].include?(response.status)
|
||||
|
||||
JSON.parse(response.body, symbolize_names: true)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
@ -3,21 +3,17 @@
|
||||
module ::DiscourseAi
|
||||
module Inference
|
||||
class GeminiEmbeddings
|
||||
def self.instance
|
||||
new(SiteSetting.ai_gemini_api_key)
|
||||
end
|
||||
|
||||
def initialize(api_key, referer = Discourse.base_url)
|
||||
def initialize(embedding_url, api_key, referer = Discourse.base_url)
|
||||
@api_key = api_key
|
||||
@embedding_url = embedding_url
|
||||
@referer = referer
|
||||
end
|
||||
|
||||
attr_reader :api_key, :referer
|
||||
attr_reader :embedding_url, :api_key, :referer
|
||||
|
||||
def perform!(content)
|
||||
headers = { "Referer" => referer, "Content-Type" => "application/json" }
|
||||
url =
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent\?key\=#{api_key}"
|
||||
url = "#{embedding_url}\?key\=#{api_key}"
|
||||
body = { content: { parts: [{ text: content }] } }
|
||||
|
||||
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
||||
@ -32,7 +28,7 @@ module ::DiscourseAi
|
||||
Rails.logger.warn(
|
||||
"Google Gemini Embeddings failed with status: #{response.status} body: #{response.body}",
|
||||
)
|
||||
raise Net::HTTPBadResponse
|
||||
raise Net::HTTPBadResponse.new(response.body.to_s)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -12,19 +12,6 @@ module ::DiscourseAi
|
||||
attr_reader :endpoint, :key, :referer
|
||||
|
||||
class << self
|
||||
def instance
|
||||
endpoint =
|
||||
if SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
|
||||
service =
|
||||
DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_hugging_face_tei_endpoint_srv)
|
||||
"https://#{service.target}:#{service.port}"
|
||||
else
|
||||
SiteSetting.ai_hugging_face_tei_endpoint
|
||||
end
|
||||
|
||||
new(endpoint, SiteSetting.ai_hugging_face_tei_api_key)
|
||||
end
|
||||
|
||||
def configured?
|
||||
SiteSetting.ai_hugging_face_tei_endpoint.present? ||
|
||||
SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
|
||||
@ -100,7 +87,7 @@ module ::DiscourseAi
|
||||
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
||||
response = conn.post(endpoint, body, headers)
|
||||
|
||||
raise Net::HTTPBadResponse if ![200].include?(response.status)
|
||||
raise Net::HTTPBadResponse.new(response.body.to_s) if ![200].include?(response.status)
|
||||
|
||||
JSON.parse(response.body, symbolize_names: true).first
|
||||
end
|
||||
|
@ -12,10 +12,6 @@ module ::DiscourseAi
|
||||
|
||||
attr_reader :endpoint, :api_key, :model, :dimensions
|
||||
|
||||
def self.instance(model:, dimensions: nil)
|
||||
new(SiteSetting.ai_openai_embeddings_url, SiteSetting.ai_openai_api_key, model, dimensions)
|
||||
end
|
||||
|
||||
def perform!(content)
|
||||
headers = { "Content-Type" => "application/json" }
|
||||
|
||||
@ -29,7 +25,7 @@ module ::DiscourseAi
|
||||
payload[:dimensions] = dimensions if dimensions.present?
|
||||
|
||||
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
||||
response = conn.post(SiteSetting.ai_openai_embeddings_url, payload.to_json, headers)
|
||||
response = conn.post(endpoint, payload.to_json, headers)
|
||||
|
||||
case response.status
|
||||
when 200
|
||||
@ -40,7 +36,7 @@ module ::DiscourseAi
|
||||
Rails.logger.warn(
|
||||
"OpenAI Embeddings failed with status: #{response.status} body: #{response.body}",
|
||||
)
|
||||
raise Net::HTTPBadResponse
|
||||
raise Net::HTTPBadResponse.new(response.body.to_s)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -35,6 +35,7 @@ register_asset "stylesheets/modules/embeddings/common/semantic-search.scss"
|
||||
register_asset "stylesheets/modules/sentiment/common/dashboard.scss"
|
||||
|
||||
register_asset "stylesheets/modules/llms/common/ai-llms-editor.scss"
|
||||
register_asset "stylesheets/modules/embeddings/common/ai-embedding-editor.scss"
|
||||
|
||||
register_asset "stylesheets/modules/llms/common/usage.scss"
|
||||
register_asset "stylesheets/modules/llms/common/spam.scss"
|
||||
|
@ -1,17 +0,0 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
require_relative "../support/embeddings_generation_stubs"
|
||||
|
||||
RSpec.describe DiscourseAi::Configuration::EmbeddingsModelValidator do
|
||||
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
|
||||
|
||||
describe "#can_generate_embeddings?" do
|
||||
it "works" do
|
||||
discourse_model = "all-mpnet-base-v2"
|
||||
|
||||
EmbeddingsGenerationStubs.discourse_service(discourse_model, "this is a test", [1] * 1024)
|
||||
|
||||
expect(subject.can_generate_embeddings?(discourse_model)).to eq(true)
|
||||
end
|
||||
end
|
||||
end
|
40
spec/fabricators/embedding_definition_fabricator.rb
Normal file
40
spec/fabricators/embedding_definition_fabricator.rb
Normal file
@ -0,0 +1,40 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
Fabricator(:embedding_definition) do
|
||||
display_name "Multilingual E5 Large"
|
||||
provider "hugging_face"
|
||||
tokenizer_class "DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer"
|
||||
api_key "123"
|
||||
url "https://test.com/embeddings"
|
||||
provider_params nil
|
||||
pg_function "<=>"
|
||||
max_sequence_length 512
|
||||
dimensions 1024
|
||||
end
|
||||
|
||||
Fabricator(:cloudflare_embedding_def, from: :embedding_definition) do
|
||||
display_name "BGE Large EN"
|
||||
provider "cloudflare"
|
||||
pg_function "<#>"
|
||||
tokenizer_class "DiscourseAi::Tokenizer::BgeLargeEnTokenizer"
|
||||
provider_params nil
|
||||
end
|
||||
|
||||
Fabricator(:open_ai_embedding_def, from: :embedding_definition) do
|
||||
display_name "ADA 002"
|
||||
provider "open_ai"
|
||||
url "https://api.openai.com/v1/embeddings"
|
||||
tokenizer_class "DiscourseAi::Tokenizer::OpenAiTokenizer"
|
||||
provider_params { { model_name: "text-embedding-ada-002" } }
|
||||
max_sequence_length 8191
|
||||
dimensions 1536
|
||||
end
|
||||
|
||||
Fabricator(:gemini_embedding_def, from: :embedding_definition) do
|
||||
display_name "Gemini's embedding-001"
|
||||
provider "google"
|
||||
dimensions 768
|
||||
max_sequence_length 1536
|
||||
tokenizer_class "DiscourseAi::Tokenizer::OpenAiTokenizer"
|
||||
url "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent"
|
||||
end
|
@ -6,9 +6,8 @@ RSpec.describe Jobs::DigestRagUpload do
|
||||
|
||||
let(:document_file) { StringIO.new("some text" * 200) }
|
||||
|
||||
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
|
||||
|
||||
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }
|
||||
fab!(:cloudflare_embedding_def)
|
||||
let(:expected_embedding) { [0.0038493] * cloudflare_embedding_def.dimensions }
|
||||
|
||||
let(:document_with_metadata) { plugin_file_from_fixtures("doc_with_metadata.txt", "rag") }
|
||||
|
||||
@ -21,15 +20,14 @@ RSpec.describe Jobs::DigestRagUpload do
|
||||
end
|
||||
|
||||
before do
|
||||
SiteSetting.ai_embeddings_selected_model = cloudflare_embedding_def.id
|
||||
SiteSetting.ai_embeddings_enabled = true
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||
SiteSetting.ai_embeddings_model = "bge-large-en"
|
||||
SiteSetting.authorized_extensions = "txt"
|
||||
|
||||
WebMock.stub_request(
|
||||
:post,
|
||||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
||||
).to_return(status: 200, body: JSON.dump(expected_embedding))
|
||||
WebMock.stub_request(:post, cloudflare_embedding_def.url).to_return(
|
||||
status: 200,
|
||||
body: JSON.dump(expected_embedding),
|
||||
)
|
||||
end
|
||||
|
||||
describe "#execute" do
|
||||
|
@ -2,23 +2,26 @@
|
||||
|
||||
RSpec.describe Jobs::GenerateRagEmbeddings do
|
||||
describe "#execute" do
|
||||
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
|
||||
fab!(:vector_def) { Fabricate(:embedding_definition) }
|
||||
|
||||
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }
|
||||
let(:expected_embedding) { [0.0038493] * vector_def.dimensions }
|
||||
|
||||
fab!(:ai_persona)
|
||||
|
||||
fab!(:rag_document_fragment_1) { Fabricate(:rag_document_fragment, target: ai_persona) }
|
||||
fab!(:rag_document_fragment_2) { Fabricate(:rag_document_fragment, target: ai_persona) }
|
||||
let(:rag_document_fragment_1) { Fabricate(:rag_document_fragment, target: ai_persona) }
|
||||
let(:rag_document_fragment_2) { Fabricate(:rag_document_fragment, target: ai_persona) }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||
SiteSetting.ai_embeddings_enabled = true
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||
|
||||
WebMock.stub_request(
|
||||
:post,
|
||||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
||||
).to_return(status: 200, body: JSON.dump(expected_embedding))
|
||||
rag_document_fragment_1
|
||||
rag_document_fragment_2
|
||||
|
||||
WebMock.stub_request(:post, vector_def.url).to_return(
|
||||
status: 200,
|
||||
body: JSON.dump(expected_embedding),
|
||||
)
|
||||
end
|
||||
|
||||
it "generates a new vector for each fragment" do
|
||||
|
56
spec/jobs/regular/manage_embedding_def_search_index_spec.rb
Normal file
56
spec/jobs/regular/manage_embedding_def_search_index_spec.rb
Normal file
@ -0,0 +1,56 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe Jobs::ManageEmbeddingDefSearchIndex do
|
||||
fab!(:embedding_definition)
|
||||
|
||||
describe "#execute" do
|
||||
context "when there is no embedding def" do
|
||||
it "does nothing" do
|
||||
invalid_id = 999_999_999
|
||||
|
||||
subject.execute(id: invalid_id)
|
||||
|
||||
expect(
|
||||
DiscourseAi::Embeddings::Schema.correctly_indexed?(
|
||||
EmbeddingDefinition.new(id: invalid_id),
|
||||
),
|
||||
).to eq(false)
|
||||
end
|
||||
end
|
||||
|
||||
context "when the embedding def is fresh" do
|
||||
it "creates the indexes" do
|
||||
subject.execute(id: embedding_definition.id)
|
||||
|
||||
expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition)).to eq(true)
|
||||
end
|
||||
|
||||
it "creates them only once" do
|
||||
subject.execute(id: embedding_definition.id)
|
||||
subject.execute(id: embedding_definition.id)
|
||||
|
||||
expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition)).to eq(true)
|
||||
end
|
||||
|
||||
context "when one of the idxs is missing" do
|
||||
it "automatically recovers by creating it" do
|
||||
DB.exec <<~SQL
|
||||
CREATE INDEX IF NOT EXISTS ai_topics_embeddings_#{embedding_definition.id}_1_search_bit ON ai_topics_embeddings
|
||||
USING hnsw ((binary_quantize(embeddings)::bit(#{embedding_definition.dimensions})) bit_hamming_ops)
|
||||
WHERE model_id = #{embedding_definition.id} AND strategy_id = 1;
|
||||
SQL
|
||||
|
||||
expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition)).to eq(
|
||||
false,
|
||||
)
|
||||
|
||||
subject.execute(id: embedding_definition.id)
|
||||
|
||||
expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition)).to eq(
|
||||
true,
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
@ -19,11 +19,11 @@ RSpec.describe Jobs::EmbeddingsBackfill do
|
||||
topic
|
||||
end
|
||||
|
||||
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
|
||||
fab!(:vector_def) { Fabricate(:embedding_definition) }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||
SiteSetting.ai_embeddings_enabled = true
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||
SiteSetting.ai_embeddings_backfill_batch_size = 1
|
||||
SiteSetting.ai_embeddings_per_post_enabled = true
|
||||
Jobs.run_immediately!
|
||||
@ -32,10 +32,10 @@ RSpec.describe Jobs::EmbeddingsBackfill do
|
||||
it "backfills topics based on bumped_at date" do
|
||||
embedding = Array.new(1024) { 1 }
|
||||
|
||||
WebMock.stub_request(
|
||||
:post,
|
||||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
||||
).to_return(status: 200, body: JSON.dump(embedding))
|
||||
WebMock.stub_request(:post, "https://test.com/embeddings").to_return(
|
||||
status: 200,
|
||||
body: JSON.dump(embedding),
|
||||
)
|
||||
|
||||
Jobs::EmbeddingsBackfill.new.execute({})
|
||||
|
||||
|
71
spec/jobs/scheduled/remove_orphaned_embeddings_spec.rb
Normal file
71
spec/jobs/scheduled/remove_orphaned_embeddings_spec.rb
Normal file
@ -0,0 +1,71 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe Jobs::RemoveOrphanedEmbeddings do
|
||||
describe "#execute" do
|
||||
fab!(:embedding_definition)
|
||||
fab!(:embedding_definition_2) { Fabricate(:embedding_definition) }
|
||||
fab!(:topic)
|
||||
fab!(:post)
|
||||
|
||||
before do
|
||||
DiscourseAi::Embeddings::Schema.prepare_search_indexes(embedding_definition)
|
||||
DiscourseAi::Embeddings::Schema.prepare_search_indexes(embedding_definition_2)
|
||||
|
||||
# Seed embeddings. One of each def x target classes.
|
||||
[embedding_definition, embedding_definition_2].each do |edef|
|
||||
SiteSetting.ai_embeddings_selected_model = edef.id
|
||||
|
||||
[topic, post].each do |target|
|
||||
schema = DiscourseAi::Embeddings::Schema.for(target.class)
|
||||
schema.store(target, [1] * edef.dimensions, "test")
|
||||
end
|
||||
end
|
||||
|
||||
embedding_definition.destroy!
|
||||
end
|
||||
|
||||
def find_all_embeddings_of(target, table, target_column)
|
||||
DB.query_single("SELECT model_id FROM #{table} WHERE #{target_column} = #{target.id}")
|
||||
end
|
||||
|
||||
it "delete embeddings without an existing embedding definition" do
|
||||
expect(find_all_embeddings_of(post, "ai_posts_embeddings", "post_id")).to contain_exactly(
|
||||
embedding_definition.id,
|
||||
embedding_definition_2.id,
|
||||
)
|
||||
expect(find_all_embeddings_of(topic, "ai_topics_embeddings", "topic_id")).to contain_exactly(
|
||||
embedding_definition.id,
|
||||
embedding_definition_2.id,
|
||||
)
|
||||
|
||||
subject.execute({})
|
||||
|
||||
expect(find_all_embeddings_of(topic, "ai_topics_embeddings", "topic_id")).to contain_exactly(
|
||||
embedding_definition_2.id,
|
||||
)
|
||||
expect(find_all_embeddings_of(post, "ai_posts_embeddings", "post_id")).to contain_exactly(
|
||||
embedding_definition_2.id,
|
||||
)
|
||||
end
|
||||
|
||||
it "deletes orphaned indexes" do
|
||||
expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition)).to eq(true)
|
||||
expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition_2)).to eq(true)
|
||||
|
||||
subject.execute({})
|
||||
|
||||
index_names =
|
||||
DiscourseAi::Embeddings::Schema::EMBEDDING_TARGETS.map do |t|
|
||||
"ai_#{t}_embeddings_#{embedding_definition.id}_1_search_bit"
|
||||
end
|
||||
indexnames =
|
||||
DB.query_single(
|
||||
"SELECT indexname FROM pg_indexes WHERE indexname IN (:names)",
|
||||
names: index_names,
|
||||
)
|
||||
|
||||
expect(indexnames).to be_empty
|
||||
expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition_2)).to eq(true)
|
||||
end
|
||||
end
|
||||
end
|
@ -1,15 +0,0 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::Configuration::EmbeddingsModuleValidator do
|
||||
let(:validator) { described_class.new }
|
||||
|
||||
describe "#can_generate_embeddings?" do
|
||||
it "returns true if embeddings can be generated" do
|
||||
stub_request(
|
||||
:post,
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent?key=",
|
||||
).to_return(status: 200, body: { embedding: { values: [1, 2, 3] } }.to_json)
|
||||
expect(validator.can_generate_embeddings?("gemini")).to eq(true)
|
||||
end
|
||||
end
|
||||
end
|
@ -4,7 +4,7 @@ require "rails_helper"
|
||||
require "webmock/rspec"
|
||||
|
||||
RSpec.describe DiscourseAi::Inference::CloudflareWorkersAi do
|
||||
subject { described_class.new(account_id, api_token, model) }
|
||||
subject { described_class.new(endpoint, api_token) }
|
||||
|
||||
let(:account_id) { "test_account_id" }
|
||||
let(:api_token) { "test_api_token" }
|
||||
|
@ -297,9 +297,11 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
||||
end
|
||||
|
||||
describe "#craft_prompt" do
|
||||
fab!(:vector_def) { Fabricate(:embedding_definition) }
|
||||
|
||||
before do
|
||||
Group.refresh_automatic_groups!
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||
SiteSetting.ai_embeddings_enabled = true
|
||||
end
|
||||
|
||||
@ -326,13 +328,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
||||
fab!(:llm_model) { Fabricate(:fake_model) }
|
||||
|
||||
it "will run the question consolidator" do
|
||||
vector_def = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
context_embedding = vector_def.dimensions.times.map { rand(-1.0...1.0) }
|
||||
EmbeddingsGenerationStubs.discourse_service(
|
||||
SiteSetting.ai_embeddings_model,
|
||||
consolidated_question,
|
||||
context_embedding,
|
||||
)
|
||||
EmbeddingsGenerationStubs.hugging_face_service(consolidated_question, context_embedding)
|
||||
|
||||
custom_ai_persona =
|
||||
Fabricate(
|
||||
@ -373,14 +370,11 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
||||
end
|
||||
|
||||
context "when a persona has RAG uploads" do
|
||||
let(:vector_def) do
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
end
|
||||
let(:embedding_value) { 0.04381 }
|
||||
let(:prompt_cc_embeddings) { [embedding_value] * vector_def.dimensions }
|
||||
|
||||
def stub_fragments(fragment_count, persona: ai_persona)
|
||||
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector_def: vector_def)
|
||||
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment)
|
||||
|
||||
fragment_count.times do |i|
|
||||
fragment =
|
||||
@ -403,8 +397,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
||||
stored_ai_persona = AiPersona.find(ai_persona.id)
|
||||
UploadReference.ensure_exist!(target: stored_ai_persona, upload_ids: [upload.id])
|
||||
|
||||
EmbeddingsGenerationStubs.discourse_service(
|
||||
SiteSetting.ai_embeddings_model,
|
||||
EmbeddingsGenerationStubs.hugging_face_service(
|
||||
with_cc.dig(:conversation_context, 0, :content),
|
||||
prompt_cc_embeddings,
|
||||
)
|
||||
|
@ -108,16 +108,13 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
|
||||
|
||||
it "supports semantic search when enabled" do
|
||||
assign_fake_provider_to(:ai_embeddings_semantic_search_hyde_model)
|
||||
vector_def = Fabricate(:embedding_definition)
|
||||
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||
SiteSetting.ai_embeddings_semantic_search_enabled = true
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
hyde_embedding = [0.049382] * vector_rep.dimensions
|
||||
|
||||
EmbeddingsGenerationStubs.discourse_service(
|
||||
SiteSetting.ai_embeddings_model,
|
||||
query,
|
||||
hyde_embedding,
|
||||
)
|
||||
hyde_embedding = [0.049382] * vector_def.dimensions
|
||||
|
||||
EmbeddingsGenerationStubs.hugging_face_service(query, hyde_embedding)
|
||||
|
||||
post1 = Fabricate(:post, topic: topic_with_tags)
|
||||
search =
|
||||
|
@ -1,6 +1,7 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do
|
||||
fab!(:vector_def) { Fabricate(:embedding_definition) }
|
||||
fab!(:user)
|
||||
fab!(:muted_category) { Fabricate(:category) }
|
||||
fab!(:category_mute) do
|
||||
@ -19,14 +20,13 @@ RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do
|
||||
let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||
SiteSetting.ai_embeddings_enabled = true
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||
SiteSetting.ai_embeddings_model = "bge-large-en"
|
||||
|
||||
WebMock.stub_request(
|
||||
:post,
|
||||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
||||
).to_return(status: 200, body: JSON.dump(expected_embedding))
|
||||
WebMock.stub_request(:post, vector_def.url).to_return(
|
||||
status: 200,
|
||||
body: JSON.dump([expected_embedding]),
|
||||
)
|
||||
|
||||
vector.generate_representation_from(topic)
|
||||
vector.generate_representation_from(muted_topic)
|
||||
|
@ -3,25 +3,26 @@
|
||||
RSpec.describe Jobs::GenerateEmbeddings do
|
||||
subject(:job) { described_class.new }
|
||||
|
||||
fab!(:vector_def) { Fabricate(:embedding_definition) }
|
||||
|
||||
describe "#execute" do
|
||||
before do
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||
SiteSetting.ai_embeddings_enabled = true
|
||||
end
|
||||
|
||||
fab!(:topic)
|
||||
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
|
||||
|
||||
let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
|
||||
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector_def) }
|
||||
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector_def: vector_def) }
|
||||
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic) }
|
||||
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post) }
|
||||
|
||||
it "works for topics" do
|
||||
expected_embedding = [0.0038493] * vector_def.dimensions
|
||||
|
||||
text = vector_def.prepare_target_text(topic)
|
||||
|
||||
EmbeddingsGenerationStubs.discourse_service(vector_def.class.name, text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.hugging_face_service(text, expected_embedding)
|
||||
|
||||
job.execute(target_id: topic.id, target_type: "Topic")
|
||||
|
||||
@ -32,7 +33,7 @@ RSpec.describe Jobs::GenerateEmbeddings do
|
||||
expected_embedding = [0.0038493] * vector_def.dimensions
|
||||
|
||||
text = vector_def.prepare_target_text(post)
|
||||
EmbeddingsGenerationStubs.discourse_service(vector_def.class.name, text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.hugging_face_service(text, expected_embedding)
|
||||
|
||||
job.execute(target_id: post.id, target_type: "Post")
|
||||
|
||||
|
@ -1,12 +1,14 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::Embeddings::Schema do
|
||||
subject(:posts_schema) { described_class.for(Post, vector_def: vector_def) }
|
||||
subject(:posts_schema) { described_class.for(Post) }
|
||||
|
||||
fab!(:vector_def) { Fabricate(:cloudflare_embedding_def) }
|
||||
let(:embeddings) { [0.0038490295] * vector_def.dimensions }
|
||||
fab!(:post) { Fabricate(:post, post_number: 1) }
|
||||
let(:digest) { OpenSSL::Digest.hexdigest("SHA1", "test") }
|
||||
let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new }
|
||||
|
||||
before { SiteSetting.ai_embeddings_selected_model = vector_def.id }
|
||||
|
||||
before { posts_schema.store(post, embeddings, digest) }
|
||||
|
||||
|
@ -13,7 +13,13 @@ describe DiscourseAi::Embeddings::SemanticRelated do
|
||||
fab!(:secured_category_topic) { Fabricate(:topic, category: secured_category) }
|
||||
fab!(:closed_topic) { Fabricate(:topic, closed: true) }
|
||||
|
||||
before { SiteSetting.ai_embeddings_semantic_related_topics_enabled = true }
|
||||
fab!(:vector_def) { Fabricate(:embedding_definition) }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_embeddings_semantic_related_topics_enabled = true
|
||||
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||
SiteSetting.ai_embeddings_enabled = true
|
||||
end
|
||||
|
||||
describe "#related_topic_ids_for" do
|
||||
context "when embeddings do not exist" do
|
||||
@ -24,21 +30,15 @@ describe DiscourseAi::Embeddings::SemanticRelated do
|
||||
topic
|
||||
end
|
||||
|
||||
let(:vector_rep) do
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
end
|
||||
|
||||
it "properly generates embeddings if missing" do
|
||||
SiteSetting.ai_embeddings_enabled = true
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||
Jobs.run_immediately!
|
||||
|
||||
embedding = Array.new(1024) { 1 }
|
||||
|
||||
WebMock.stub_request(
|
||||
:post,
|
||||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
||||
).to_return(status: 200, body: JSON.dump(embedding))
|
||||
WebMock.stub_request(:post, vector_def.url).to_return(
|
||||
status: 200,
|
||||
body: JSON.dump([embedding]),
|
||||
)
|
||||
|
||||
# miss first
|
||||
ids = semantic_related.related_topic_ids_for(topic)
|
||||
|
@ -7,22 +7,18 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
|
||||
let(:query) { "test_query" }
|
||||
let(:subject) { described_class.new(Guardian.new(user)) }
|
||||
|
||||
before { assign_fake_provider_to(:ai_embeddings_semantic_search_hyde_model) }
|
||||
fab!(:vector_def) { Fabricate(:embedding_definition) }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||
assign_fake_provider_to(:ai_embeddings_semantic_search_hyde_model)
|
||||
end
|
||||
|
||||
describe "#search_for_topics" do
|
||||
let(:hypothetical_post) { "This is an hypothetical post generated from the keyword test_query" }
|
||||
let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
|
||||
let(:hyde_embedding) { [0.049382] * vector_def.dimensions }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||
|
||||
EmbeddingsGenerationStubs.discourse_service(
|
||||
SiteSetting.ai_embeddings_model,
|
||||
hypothetical_post,
|
||||
hyde_embedding,
|
||||
)
|
||||
end
|
||||
before { EmbeddingsGenerationStubs.hugging_face_service(hypothetical_post, hyde_embedding) }
|
||||
|
||||
after { described_class.clear_cache_for(query) }
|
||||
|
||||
|
@ -9,6 +9,10 @@ describe DiscourseAi::Embeddings::EntryPoint do
|
||||
|
||||
fab!(:target) { Fabricate(:topic) }
|
||||
|
||||
fab!(:vector_def) { Fabricate(:cloudflare_embedding_def) }
|
||||
|
||||
before { SiteSetting.ai_embeddings_selected_model = vector_def.id }
|
||||
|
||||
# The Distance gap to target increases for each element of topics.
|
||||
def seed_embeddings(topics)
|
||||
schema = DiscourseAi::Embeddings::Schema.for(Topic)
|
||||
|
@ -19,13 +19,13 @@ RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
|
||||
)
|
||||
end
|
||||
fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }
|
||||
|
||||
let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new }
|
||||
fab!(:open_ai_embedding_def)
|
||||
|
||||
it "truncates a topic" do
|
||||
prepared_text = truncation.prepare_target_text(topic, vector_def)
|
||||
prepared_text = truncation.prepare_target_text(topic, open_ai_embedding_def)
|
||||
|
||||
expect(vector_def.tokenizer.size(prepared_text)).to be <= vector_def.max_sequence_length
|
||||
expect(open_ai_embedding_def.tokenizer.size(prepared_text)).to be <=
|
||||
open_ai_embedding_def.max_sequence_length
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -7,8 +7,10 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
|
||||
let(:expected_embedding_1) { [0.0038493] * vdef.dimensions }
|
||||
let(:expected_embedding_2) { [0.0037684] * vdef.dimensions }
|
||||
|
||||
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vdef) }
|
||||
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector_def: vdef) }
|
||||
before { SiteSetting.ai_embeddings_selected_model = vdef.id }
|
||||
|
||||
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic) }
|
||||
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post) }
|
||||
|
||||
fab!(:topic)
|
||||
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
|
||||
@ -84,63 +86,16 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
|
||||
end
|
||||
end
|
||||
|
||||
context "with text-embedding-ada-002" do
|
||||
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.openai_service(vdef.class.name, text, expected_embedding)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embeddings using a vector definition"
|
||||
end
|
||||
|
||||
context "with all all-mpnet-base-v2" do
|
||||
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new }
|
||||
|
||||
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.discourse_service(vdef.class.name, text, expected_embedding)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embeddings using a vector definition"
|
||||
end
|
||||
|
||||
context "with gemini" do
|
||||
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::Gemini.new }
|
||||
let(:api_key) { "test-123" }
|
||||
|
||||
before { SiteSetting.ai_gemini_api_key = api_key }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.gemini_service(api_key, text, expected_embedding)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embeddings using a vector definition"
|
||||
end
|
||||
|
||||
context "with multilingual-e5-large" do
|
||||
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large.new }
|
||||
|
||||
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.discourse_service(vdef.class.name, text, expected_embedding)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embeddings using a vector definition"
|
||||
end
|
||||
|
||||
context "with text-embedding-3-large" do
|
||||
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large.new }
|
||||
context "with open_ai as the provider" do
|
||||
fab!(:vdef) { Fabricate(:open_ai_embedding_def) }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.openai_service(
|
||||
vdef.class.name,
|
||||
vdef.lookup_custom_param("model_name"),
|
||||
text,
|
||||
expected_embedding,
|
||||
extra_args: {
|
||||
dimensions: 2000,
|
||||
dimensions: vdef.dimensions,
|
||||
},
|
||||
)
|
||||
end
|
||||
@ -148,11 +103,31 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
|
||||
it_behaves_like "generates and store embeddings using a vector definition"
|
||||
end
|
||||
|
||||
context "with text-embedding-3-small" do
|
||||
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small.new }
|
||||
context "with hugging_face as the provider" do
|
||||
fab!(:vdef) { Fabricate(:embedding_definition) }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.openai_service(vdef.class.name, text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.hugging_face_service(text, expected_embedding)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embeddings using a vector definition"
|
||||
end
|
||||
|
||||
context "with google as the provider" do
|
||||
fab!(:vdef) { Fabricate(:gemini_embedding_def) }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.gemini_service(vdef.api_key, text, expected_embedding)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embeddings using a vector definition"
|
||||
end
|
||||
|
||||
context "with cloudflare as the provider" do
|
||||
fab!(:vdef) { Fabricate(:cloudflare_embedding_def) }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.cloudflare_service(text, expected_embedding)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embeddings using a vector definition"
|
||||
|
@ -204,12 +204,12 @@ RSpec.describe AiTool do
|
||||
end
|
||||
|
||||
context "when defining RAG fragments" do
|
||||
fab!(:cloudflare_embedding_def)
|
||||
|
||||
before do
|
||||
SiteSetting.authorized_extensions = "txt"
|
||||
SiteSetting.ai_embeddings_selected_model = cloudflare_embedding_def.id
|
||||
SiteSetting.ai_embeddings_enabled = true
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||
SiteSetting.ai_embeddings_model = "bge-large-en"
|
||||
|
||||
Jobs.run_immediately!
|
||||
end
|
||||
|
||||
@ -228,9 +228,9 @@ RSpec.describe AiTool do
|
||||
# this is a trick, we get ever increasing embeddings, this gives us in turn
|
||||
# 100% consistent search results
|
||||
@counter = 0
|
||||
stub_request(:post, "http://test.com/api/v1/classify").to_return(
|
||||
stub_request(:post, cloudflare_embedding_def.url).to_return(
|
||||
status: 200,
|
||||
body: lambda { |req| ([@counter += 1] * 1024).to_json },
|
||||
body: lambda { |req| { result: { data: [([@counter += 1] * 1024)] } }.to_json },
|
||||
headers: {
|
||||
},
|
||||
)
|
||||
|
@ -4,11 +4,9 @@ RSpec.describe RagDocumentFragment do
|
||||
fab!(:persona) { Fabricate(:ai_persona) }
|
||||
fab!(:upload_1) { Fabricate(:upload) }
|
||||
fab!(:upload_2) { Fabricate(:upload) }
|
||||
fab!(:vector_def) { Fabricate(:embedding_definition) }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_embeddings_enabled = true
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||
end
|
||||
before { SiteSetting.ai_embeddings_enabled = true }
|
||||
|
||||
describe ".link_uploads_and_persona" do
|
||||
it "does nothing if there is no persona" do
|
||||
@ -76,25 +74,25 @@ RSpec.describe RagDocumentFragment do
|
||||
describe ".indexing_status" do
|
||||
let(:vector) { DiscourseAi::Embeddings::Vector.instance }
|
||||
|
||||
fab!(:rag_document_fragment_1) do
|
||||
let(:rag_document_fragment_1) do
|
||||
Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
|
||||
end
|
||||
|
||||
fab!(:rag_document_fragment_2) do
|
||||
let(:rag_document_fragment_2) do
|
||||
Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
|
||||
end
|
||||
|
||||
let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions }
|
||||
let(:expected_embedding) { [0.0038493] * vector_def.dimensions }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_embeddings_enabled = true
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||
SiteSetting.ai_embeddings_model = "bge-large-en"
|
||||
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||
rag_document_fragment_1
|
||||
rag_document_fragment_2
|
||||
|
||||
WebMock.stub_request(
|
||||
:post,
|
||||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
||||
).to_return(status: 200, body: JSON.dump(expected_embedding))
|
||||
WebMock.stub_request(:post, "https://test.com/embeddings").to_return(
|
||||
status: 200,
|
||||
body: JSON.dump(expected_embedding),
|
||||
)
|
||||
|
||||
vector.generate_representation_from(rag_document_fragment_1)
|
||||
end
|
||||
@ -106,7 +104,7 @@ RSpec.describe RagDocumentFragment do
|
||||
UploadReference.create!(upload_id: upload_2.id, target: persona)
|
||||
|
||||
Sidekiq::Testing.fake! do
|
||||
SiteSetting.ai_embeddings_model = "all-mpnet-base-v2"
|
||||
SiteSetting.ai_embeddings_selected_model = Fabricate(:open_ai_embedding_def).id
|
||||
expect(RagDocumentFragment.exists?(old_id)).to eq(false)
|
||||
expect(Jobs::DigestRagUpload.jobs.size).to eq(2)
|
||||
end
|
||||
|
184
spec/requests/admin/ai_embeddings_controller_spec.rb
Normal file
184
spec/requests/admin/ai_embeddings_controller_spec.rb
Normal file
@ -0,0 +1,184 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::Admin::AiEmbeddingsController do
|
||||
fab!(:admin)
|
||||
|
||||
before { sign_in(admin) }
|
||||
|
||||
let(:valid_attrs) do
|
||||
{
|
||||
display_name: "Embedding config test",
|
||||
dimensions: 1001,
|
||||
max_sequence_length: 234,
|
||||
pg_function: "<#>",
|
||||
provider: "hugging_face",
|
||||
url: "https://test.com/api/v1/embeddings",
|
||||
api_key: "test",
|
||||
tokenizer_class: "DiscourseAi::Tokenizer::BgeM3Tokenizer",
|
||||
}
|
||||
end
|
||||
|
||||
describe "POST #create" do
|
||||
context "with valid attrs" do
|
||||
it "creates a new embedding definition" do
|
||||
post "/admin/plugins/discourse-ai/ai-embeddings.json", params: { ai_embedding: valid_attrs }
|
||||
|
||||
created_def = EmbeddingDefinition.last
|
||||
|
||||
expect(response.status).to eq(201)
|
||||
expect(created_def.display_name).to eq(valid_attrs[:display_name])
|
||||
end
|
||||
|
||||
it "stores provider-specific config params" do
|
||||
post "/admin/plugins/discourse-ai/ai-embeddings.json",
|
||||
params: {
|
||||
ai_embedding:
|
||||
valid_attrs.merge(
|
||||
provider: "open_ai",
|
||||
provider_params: {
|
||||
model_name: "embeddings-v1",
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
created_def = EmbeddingDefinition.last
|
||||
|
||||
expect(response.status).to eq(201)
|
||||
expect(created_def.provider_params["model_name"]).to eq("embeddings-v1")
|
||||
end
|
||||
|
||||
it "ignores parameters not associated with that provider" do
|
||||
post "/admin/plugins/discourse-ai/ai-embeddings.json",
|
||||
params: {
|
||||
ai_embedding: valid_attrs.merge(provider_params: { custom: "custom" }),
|
||||
}
|
||||
|
||||
created_def = EmbeddingDefinition.last
|
||||
|
||||
expect(response.status).to eq(201)
|
||||
expect(created_def.lookup_custom_param("custom")).to be_nil
|
||||
end
|
||||
end
|
||||
|
||||
context "with invalid attrs" do
|
||||
it "doesn't create a new embedding defitinion" do
|
||||
post "/admin/plugins/discourse-ai/ai-embeddings.json",
|
||||
params: {
|
||||
ai_embedding: valid_attrs.except(:provider),
|
||||
}
|
||||
|
||||
created_def = EmbeddingDefinition.last
|
||||
|
||||
expect(created_def).to be_nil
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "PUT #update" do
|
||||
fab!(:embedding_definition)
|
||||
|
||||
context "with valid update params" do
|
||||
let(:update_attrs) { { provider: "open_ai" } }
|
||||
|
||||
it "updates the model" do
|
||||
put "/admin/plugins/discourse-ai/ai-embeddings/#{embedding_definition.id}.json",
|
||||
params: {
|
||||
ai_embedding: update_attrs,
|
||||
}
|
||||
|
||||
expect(response.status).to eq(200)
|
||||
expect(embedding_definition.reload.provider).to eq(update_attrs[:provider])
|
||||
end
|
||||
|
||||
it "returns a 404 if there is no model with the given Id" do
|
||||
put "/admin/plugins/discourse-ai/ai-embeddings/9999999.json"
|
||||
|
||||
expect(response.status).to eq(404)
|
||||
end
|
||||
|
||||
it "doesn't allow dimenstions to be updated" do
|
||||
new_dimensions = 200
|
||||
|
||||
put "/admin/plugins/discourse-ai/ai-embeddings/#{embedding_definition.id}.json",
|
||||
params: {
|
||||
ai_embedding: {
|
||||
dimensions: new_dimensions,
|
||||
},
|
||||
}
|
||||
|
||||
expect(response.status).to eq(200)
|
||||
expect(embedding_definition.reload.dimensions).not_to eq(new_dimensions)
|
||||
end
|
||||
end
|
||||
|
||||
context "with invalid update params" do
|
||||
it "doesn't update the model" do
|
||||
put "/admin/plugins/discourse-ai/ai-embeddings/#{embedding_definition.id}.json",
|
||||
params: {
|
||||
ai_embedding: {
|
||||
url: "",
|
||||
},
|
||||
}
|
||||
|
||||
expect(response.status).to eq(422)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "DELETE #destroy" do
|
||||
fab!(:embedding_definition)
|
||||
|
||||
it "destroys the embedding defitinion" do
|
||||
expect {
|
||||
delete "/admin/plugins/discourse-ai/ai-embeddings/#{embedding_definition.id}.json"
|
||||
|
||||
expect(response).to have_http_status(:no_content)
|
||||
}.to change(EmbeddingDefinition, :count).by(-1)
|
||||
end
|
||||
|
||||
it "validates the model is not in use" do
|
||||
SiteSetting.ai_embeddings_selected_model = embedding_definition.id
|
||||
|
||||
delete "/admin/plugins/discourse-ai/ai-embeddings/#{embedding_definition.id}.json"
|
||||
|
||||
expect(response.status).to eq(409)
|
||||
expect(embedding_definition.reload).to eq(embedding_definition)
|
||||
end
|
||||
end
|
||||
|
||||
describe "GET #test" do
|
||||
context "when we can generate an embedding" do
|
||||
it "returns a success true flag" do
|
||||
WebMock.stub_request(:post, valid_attrs[:url]).to_return(status: 200, body: [[1]].to_json)
|
||||
|
||||
get "/admin/plugins/discourse-ai/ai-embeddings/test.json",
|
||||
params: {
|
||||
ai_embedding: valid_attrs,
|
||||
}
|
||||
|
||||
expect(response).to be_successful
|
||||
expect(response.parsed_body["success"]).to eq(true)
|
||||
end
|
||||
end
|
||||
|
||||
context "when we cannot generate an embedding" do
|
||||
it "returns a success false flag and the error message" do
|
||||
error_message = { error: "Embedding generation failed." }
|
||||
|
||||
WebMock.stub_request(:post, valid_attrs[:url]).to_return(
|
||||
status: 422,
|
||||
body: error_message.to_json,
|
||||
)
|
||||
|
||||
get "/admin/plugins/discourse-ai/ai-embeddings/test.json",
|
||||
params: {
|
||||
ai_embedding: valid_attrs,
|
||||
}
|
||||
|
||||
expect(response).to be_successful
|
||||
expect(response.parsed_body["success"]).to eq(false)
|
||||
expect(response.parsed_body["error"]).to eq(error_message.to_json)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
@ -8,7 +8,6 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
||||
sign_in(admin)
|
||||
|
||||
SiteSetting.ai_embeddings_enabled = true
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||
end
|
||||
|
||||
describe "GET #index" do
|
||||
|
@ -4,11 +4,12 @@ RSpec.describe DiscourseAi::Admin::RagDocumentFragmentsController do
|
||||
fab!(:admin)
|
||||
fab!(:ai_persona)
|
||||
|
||||
fab!(:vector_def) { Fabricate(:embedding_definition) }
|
||||
|
||||
before do
|
||||
sign_in(admin)
|
||||
|
||||
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||
SiteSetting.ai_embeddings_enabled = true
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||
end
|
||||
|
||||
describe "GET #indexing_status_check" do
|
||||
|
@ -2,9 +2,11 @@
|
||||
|
||||
describe DiscourseAi::Embeddings::EmbeddingsController do
|
||||
context "when performing a topic search" do
|
||||
fab!(:vector_def) { Fabricate(:open_ai_embedding_def) }
|
||||
|
||||
before do
|
||||
SiteSetting.min_search_term_length = 3
|
||||
SiteSetting.ai_embeddings_model = "text-embedding-3-small"
|
||||
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||
DiscourseAi::Embeddings::SemanticSearch.clear_cache_for("test")
|
||||
SearchIndexer.enable
|
||||
end
|
||||
@ -31,7 +33,15 @@ describe DiscourseAi::Embeddings::EmbeddingsController do
|
||||
|
||||
def stub_embedding(query)
|
||||
embedding = [0.049382] * 1536
|
||||
EmbeddingsGenerationStubs.openai_service(SiteSetting.ai_embeddings_model, query, embedding)
|
||||
|
||||
EmbeddingsGenerationStubs.openai_service(
|
||||
vector_def.lookup_custom_param("model_name"),
|
||||
query,
|
||||
embedding,
|
||||
extra_args: {
|
||||
dimensions: vector_def.dimensions,
|
||||
},
|
||||
)
|
||||
end
|
||||
|
||||
def create_api_key(user)
|
||||
|
@ -1,10 +1,13 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
describe DiscourseAi::Inference::OpenAiEmbeddings do
|
||||
let(:api_key) { "123456" }
|
||||
let(:dimensions) { 1000 }
|
||||
let(:model) { "text-embedding-ada-002" }
|
||||
|
||||
it "supports azure embeddings" do
|
||||
SiteSetting.ai_openai_embeddings_url =
|
||||
azure_url =
|
||||
"https://my-company.openai.azure.com/openai/deployments/embeddings-deployment/embeddings?api-version=2023-05-15"
|
||||
SiteSetting.ai_openai_api_key = "123456"
|
||||
|
||||
body_json = {
|
||||
usage: {
|
||||
@ -14,28 +17,22 @@ describe DiscourseAi::Inference::OpenAiEmbeddings do
|
||||
data: [{ object: "embedding", embedding: [0.0, 0.1] }],
|
||||
}.to_json
|
||||
|
||||
stub_request(
|
||||
:post,
|
||||
"https://my-company.openai.azure.com/openai/deployments/embeddings-deployment/embeddings?api-version=2023-05-15",
|
||||
).with(
|
||||
stub_request(:post, azure_url).with(
|
||||
body: "{\"model\":\"text-embedding-ada-002\",\"input\":\"hello\"}",
|
||||
headers: {
|
||||
"Api-Key" => "123456",
|
||||
"Api-Key" => api_key,
|
||||
"Content-Type" => "application/json",
|
||||
},
|
||||
).to_return(status: 200, body: body_json, headers: {})
|
||||
|
||||
result =
|
||||
DiscourseAi::Inference::OpenAiEmbeddings.instance(model: "text-embedding-ada-002").perform!(
|
||||
"hello",
|
||||
)
|
||||
DiscourseAi::Inference::OpenAiEmbeddings.new(azure_url, api_key, model, nil).perform!("hello")
|
||||
|
||||
expect(result).to eq([0.0, 0.1])
|
||||
end
|
||||
|
||||
it "supports openai embeddings" do
|
||||
SiteSetting.ai_openai_api_key = "123456"
|
||||
|
||||
url = "https://api.openai.com/v1/embeddings"
|
||||
body_json = {
|
||||
usage: {
|
||||
prompt_tokens: 1,
|
||||
@ -44,21 +41,20 @@ describe DiscourseAi::Inference::OpenAiEmbeddings do
|
||||
data: [{ object: "embedding", embedding: [0.0, 0.1] }],
|
||||
}.to_json
|
||||
|
||||
body = { model: "text-embedding-ada-002", input: "hello", dimensions: 1000 }.to_json
|
||||
body = { model: model, input: "hello", dimensions: dimensions }.to_json
|
||||
|
||||
stub_request(:post, "https://api.openai.com/v1/embeddings").with(
|
||||
stub_request(:post, url).with(
|
||||
body: body,
|
||||
headers: {
|
||||
"Authorization" => "Bearer 123456",
|
||||
"Authorization" => "Bearer #{api_key}",
|
||||
"Content-Type" => "application/json",
|
||||
},
|
||||
).to_return(status: 200, body: body_json, headers: {})
|
||||
|
||||
result =
|
||||
DiscourseAi::Inference::OpenAiEmbeddings.instance(
|
||||
model: "text-embedding-ada-002",
|
||||
dimensions: 1000,
|
||||
).perform!("hello")
|
||||
DiscourseAi::Inference::OpenAiEmbeddings.new(url, api_key, model, dimensions).perform!(
|
||||
"hello",
|
||||
)
|
||||
|
||||
expect(result).to eq([0.0, 0.1])
|
||||
end
|
||||
|
@ -2,15 +2,11 @@
|
||||
|
||||
class EmbeddingsGenerationStubs
|
||||
class << self
|
||||
def discourse_service(model, string, embedding)
|
||||
model = "bge-large-en-v1.5" if model == "bge-large-en"
|
||||
def hugging_face_service(string, embedding)
|
||||
WebMock
|
||||
.stub_request(
|
||||
:post,
|
||||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
||||
)
|
||||
.with(body: JSON.dump({ model: model, content: string }))
|
||||
.to_return(status: 200, body: JSON.dump(embedding))
|
||||
.stub_request(:post, "https://test.com/embeddings")
|
||||
.with(body: JSON.dump({ inputs: string, truncate: true }))
|
||||
.to_return(status: 200, body: JSON.dump([embedding]))
|
||||
end
|
||||
|
||||
def openai_service(model, string, embedding, extra_args: {})
|
||||
@ -29,5 +25,12 @@ class EmbeddingsGenerationStubs
|
||||
.with(body: JSON.dump({ content: { parts: [{ text: string }] } }))
|
||||
.to_return(status: 200, body: JSON.dump({ embedding: { values: embedding } }))
|
||||
end
|
||||
|
||||
def cloudflare_service(string, embedding)
|
||||
WebMock
|
||||
.stub_request(:post, "https://test.com/embeddings")
|
||||
.with(body: JSON.dump({ text: [string] }))
|
||||
.to_return(status: 200, body: JSON.dump({ result: { data: [embedding] } }))
|
||||
end
|
||||
end
|
||||
end
|
||||
|
87
spec/system/embeddings/ai_embedding_definition_spec.rb
Normal file
87
spec/system/embeddings/ai_embedding_definition_spec.rb
Normal file
@ -0,0 +1,87 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe "Managing Embeddings configurations", type: :system, js: true do
|
||||
fab!(:admin)
|
||||
let(:page_header) { PageObjects::Components::DPageHeader.new }
|
||||
|
||||
before { sign_in(admin) }
|
||||
|
||||
it "correctly sets defaults" do
|
||||
preset = "text-embedding-3-small"
|
||||
api_key = "abcd"
|
||||
|
||||
visit "/admin/plugins/discourse-ai/ai-embeddings"
|
||||
|
||||
find(".ai-embeddings-list-editor__new-button").click()
|
||||
select_kit = PageObjects::Components::SelectKit.new(".ai-embedding-editor__presets")
|
||||
select_kit.expand
|
||||
select_kit.select_row_by_value(preset)
|
||||
find(".ai-embedding-editor__next").click
|
||||
find("input.ai-embedding-editor__api-key").fill_in(with: api_key)
|
||||
find(".ai-embedding-editor__save").click()
|
||||
|
||||
expect(page).to have_current_path("/admin/plugins/discourse-ai/ai-embeddings")
|
||||
|
||||
embedding_def = EmbeddingDefinition.order(:id).last
|
||||
expect(embedding_def.api_key).to eq(api_key)
|
||||
|
||||
preset = EmbeddingDefinition.presets.find { |p| p[:preset_id] == preset }
|
||||
|
||||
expect(embedding_def.display_name).to eq(preset[:display_name])
|
||||
expect(embedding_def.url).to eq(preset[:url])
|
||||
expect(embedding_def.tokenizer_class).to eq(preset[:tokenizer_class])
|
||||
expect(embedding_def.dimensions).to eq(preset[:dimensions])
|
||||
expect(embedding_def.max_sequence_length).to eq(preset[:max_sequence_length])
|
||||
expect(embedding_def.pg_function).to eq(preset[:pg_function])
|
||||
expect(embedding_def.provider).to eq(preset[:provider])
|
||||
expect(embedding_def.provider_params.symbolize_keys).to eq(preset[:provider_params])
|
||||
end
|
||||
|
||||
it "supports manual config" do
|
||||
api_key = "abcd"
|
||||
|
||||
visit "/admin/plugins/discourse-ai/ai-embeddings"
|
||||
|
||||
find(".ai-embeddings-list-editor__new-button").click()
|
||||
select_kit = PageObjects::Components::SelectKit.new(".ai-embedding-editor__presets")
|
||||
select_kit.expand
|
||||
select_kit.select_row_by_value("manual")
|
||||
find(".ai-embedding-editor__next").click
|
||||
|
||||
find("input.ai-embedding-editor__display-name").fill_in(with: "OpenAI's text-embedding-3-small")
|
||||
|
||||
select_kit = PageObjects::Components::SelectKit.new(".ai-embedding-editor__provider")
|
||||
select_kit.expand
|
||||
select_kit.select_row_by_value(EmbeddingDefinition::OPEN_AI)
|
||||
|
||||
find("input.ai-embedding-editor__url").fill_in(with: "https://api.openai.com/v1/embeddings")
|
||||
find("input.ai-embedding-editor__api-key").fill_in(with: api_key)
|
||||
|
||||
select_kit = PageObjects::Components::SelectKit.new(".ai-embedding-editor__tokenizer")
|
||||
select_kit.expand
|
||||
select_kit.select_row_by_value("DiscourseAi::Tokenizer::OpenAiTokenizer")
|
||||
|
||||
find("input.ai-embedding-editor__dimensions").fill_in(with: 1536)
|
||||
find("input.ai-embedding-editor__max_sequence_length").fill_in(with: 8191)
|
||||
|
||||
select_kit = PageObjects::Components::SelectKit.new(".ai-embedding-editor__distance_functions")
|
||||
select_kit.expand
|
||||
select_kit.select_row_by_value("<=>")
|
||||
find(".ai-embedding-editor__save").click()
|
||||
|
||||
expect(page).to have_current_path("/admin/plugins/discourse-ai/ai-embeddings")
|
||||
|
||||
embedding_def = EmbeddingDefinition.order(:id).last
|
||||
expect(embedding_def.api_key).to eq(api_key)
|
||||
|
||||
preset = EmbeddingDefinition.presets.find { |p| p[:preset_id] == "text-embedding-3-small" }
|
||||
|
||||
expect(embedding_def.display_name).to eq(preset[:display_name])
|
||||
expect(embedding_def.url).to eq(preset[:url])
|
||||
expect(embedding_def.tokenizer_class).to eq(preset[:tokenizer_class])
|
||||
expect(embedding_def.dimensions).to eq(preset[:dimensions])
|
||||
expect(embedding_def.max_sequence_length).to eq(preset[:max_sequence_length])
|
||||
expect(embedding_def.pg_function).to eq(preset[:pg_function])
|
||||
expect(embedding_def.provider).to eq(preset[:provider])
|
||||
end
|
||||
end
|
@ -10,7 +10,6 @@ RSpec.describe "AI Composer helper", type: :system, js: true do
|
||||
fab!(:post) { Fabricate(:post, topic: topic, raw: "Apple pie is a delicious dessert to eat") }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||
prompt = DiscourseAi::Embeddings::HydeGenerators::OpenAi.new.prompt(query)
|
||||
OpenAiCompletionsInferenceStubs.stub_response(
|
||||
prompt,
|
||||
@ -21,11 +20,7 @@ RSpec.describe "AI Composer helper", type: :system, js: true do
|
||||
)
|
||||
|
||||
hyde_embedding = [0.049382, 0.9999]
|
||||
EmbeddingsGenerationStubs.discourse_service(
|
||||
SiteSetting.ai_embeddings_model,
|
||||
hypothetical_post,
|
||||
hyde_embedding,
|
||||
)
|
||||
EmbeddingsGenerationStubs.hugging_face_service(hypothetical_post, hyde_embedding)
|
||||
|
||||
SearchIndexer.enable
|
||||
SearchIndexer.index(topic, force: true)
|
||||
|
Loading…
x
Reference in New Issue
Block a user