FEATURE: configurable embeddings (#1049)

* Use AR model for embeddings features

* endpoints

* Embeddings CRUD UI

* Add presets. Hide a couple more settings

* system specs

* Seed embedding definition from old settings

* Generate search bit index on the fly. cleanup orphaned data

* support for seeded models

* Fix run test for new embedding

* fix selected model not set correctly
This commit is contained in:
Roman Rizzi 2025-01-21 12:23:19 -03:00 committed by GitHub
parent fad4b65d4f
commit f5cf1019fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
78 changed files with 2131 additions and 1008 deletions

View File

@ -0,0 +1,21 @@
import DiscourseRoute from "discourse/routes/discourse";
export default class AdminPluginsShowDiscourseAiEmbeddingsEdit extends DiscourseRoute {
  // Resolves the embedding being edited from the already-loaded parent list.
  async model(params) {
    const embeddings = this.modelFor(
      "adminPlugins.show.discourse-ai-embeddings"
    );
    const embeddingId = parseInt(params.id, 10);
    const embedding = embeddings.findBy("id", embeddingId);

    // The editor binds into provider_params, so it must always be an object.
    embedding.provider_params = embedding.provider_params || {};
    return embedding;
  }

  // Exposes the full list alongside the edited record for the list editor.
  setupController(controller, model) {
    super.setupController(controller, model);
    controller.set(
      "allEmbeddings",
      this.modelFor("adminPlugins.show.discourse-ai-embeddings")
    );
  }
}

View File

@ -0,0 +1,17 @@
import DiscourseRoute from "discourse/routes/discourse";
export default class AdminPluginsShowDiscourseAiEmbeddingsNew extends DiscourseRoute {
  // Builds a fresh, unsaved embedding record for the "new" form.
  async model() {
    const embedding = this.store.createRecord("ai-embedding");
    // Start with an empty object so the editor can bind provider fields.
    embedding.provider_params = {};
    return embedding;
  }

  // Exposes the full list alongside the new record for the list editor.
  setupController(controller, model) {
    super.setupController(controller, model);
    controller.set(
      "allEmbeddings",
      this.modelFor("adminPlugins.show.discourse-ai-embeddings")
    );
  }
}

View File

@ -0,0 +1,7 @@
import DiscourseRoute from "discourse/routes/discourse";
export default class DiscourseAiAiEmbeddingsRoute extends DiscourseRoute {
  // Loads every configured embedding definition for the admin list view.
  model() {
    return this.store.findAll("ai-embedding");
  }
}

View File

@ -0,0 +1,4 @@
{{! Edit view: renders the list editor with the record currently being edited open }}
<AiEmbeddingsListEditor
  @embeddings={{this.allEmbeddings}}
  @currentEmbedding={{this.model}}
/>

View File

@ -0,0 +1 @@
{{! Index view: list of all embedding configurations, no editor open }}
<AiEmbeddingsListEditor @embeddings={{this.model}} />

View File

@ -0,0 +1,4 @@
{{! New view: renders the list editor with a fresh, unsaved record open }}
<AiEmbeddingsListEditor
  @embeddings={{this.allEmbeddings}}
  @currentEmbedding={{this.model}}
/>

View File

@ -0,0 +1,130 @@
# frozen_string_literal: true
module DiscourseAi
  module Admin
    # Admin CRUD + connectivity-test endpoints for embedding definitions.
    class AiEmbeddingsController < ::Admin::AdminController
      requires_plugin ::DiscourseAi::PLUGIN_NAME

      # Lists all definitions plus the metadata the editor UI needs:
      # provider-specific extra params, provider names, pgvector distance
      # functions, tokenizer options, and creation presets.
      def index
        embedding_defs = EmbeddingDefinition.all.order(:display_name)

        render json: {
          ai_embeddings:
            ActiveModel::ArraySerializer.new(
              embedding_defs,
              each_serializer: AiEmbeddingDefinitionSerializer,
              root: false,
            ).as_json,
          meta: {
            provider_params: EmbeddingDefinition.provider_params,
            providers: EmbeddingDefinition.provider_names,
            distance_functions: EmbeddingDefinition.distance_functions,
            tokenizers:
              # Expose short class names for display, full names as ids.
              EmbeddingDefinition.tokenizer_names.map { |tn|
                { id: tn, name: tn.split("::").last }
              },
            presets: EmbeddingDefinition.presets,
          },
        }
      end

      def new
      end

      def edit
        embedding_def = EmbeddingDefinition.find(params[:id])
        render json: AiEmbeddingDefinitionSerializer.new(embedding_def)
      end

      def create
        embedding_def = EmbeddingDefinition.new(ai_embeddings_params)

        if embedding_def.save
          render json: AiEmbeddingDefinitionSerializer.new(embedding_def), status: :created
        else
          render_json_error embedding_def
        end
      end

      def update
        embedding_def = EmbeddingDefinition.find(params[:id])

        # Seeded (built-in) definitions are read-only in the admin UI.
        if embedding_def.seeded?
          return(
            render_json_error(I18n.t("discourse_ai.embeddings.cannot_edit_builtin"), status: 403)
          )
        end

        # Dimensions are immutable after creation (existing vectors would
        # become incompatible), so they are stripped from updates.
        if embedding_def.update(ai_embeddings_params.except(:dimensions))
          render json: AiEmbeddingDefinitionSerializer.new(embedding_def)
        else
          render_json_error embedding_def
        end
      end

      def destroy
        embedding_def = EmbeddingDefinition.find(params[:id])

        if embedding_def.seeded?
          return(
            render_json_error(I18n.t("discourse_ai.embeddings.cannot_edit_builtin"), status: 403)
          )
        end

        # Refuse to delete the definition currently selected in site settings.
        if embedding_def.id == SiteSetting.ai_embeddings_selected_model.to_i
          return render_json_error(I18n.t("discourse_ai.embeddings.delete_failed"), status: 409)
        end

        if embedding_def.destroy
          head :no_content
        else
          render_json_error embedding_def
        end
      end

      # Generates a throwaway embedding with the submitted (unsaved) config to
      # verify credentials/URL. Rate limited to 3 attempts per minute per admin.
      def test
        RateLimiter.new(
          current_user,
          "ai_embeddings_test_#{current_user.id}",
          3,
          1.minute,
        ).performed!

        embedding_def = EmbeddingDefinition.new(ai_embeddings_params)
        DiscourseAi::Embeddings::Vector.new(embedding_def).vector_from("this is a test")

        render json: { success: true }
      rescue Net::HTTPBadResponse => e
        render json: { success: false, error: e.message }
      end

      private

      # Strong params; provider_params is only permitted when the selected
      # provider declares extra fields, and is sliced to those keys.
      def ai_embeddings_params
        permitted =
          params.require(:ai_embedding).permit(
            :display_name,
            :dimensions,
            :max_sequence_length,
            :pg_function,
            :provider,
            :url,
            :api_key,
            :tokenizer_class,
          )

        extra_field_names = EmbeddingDefinition.provider_params.dig(permitted[:provider]&.to_sym)
        if extra_field_names.present?
          received_prov_params =
            params.dig(:ai_embedding, :provider_params)&.slice(*extra_field_names.keys)

          if received_prov_params.present?
            permitted[:provider_params] = received_prov_params.permit!
          end
        end

        permitted
      end
    end
  end
end

View File

@ -18,7 +18,7 @@ module ::Jobs
target = target_type.constantize.find_by(id: target_id)
return if !target
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector_rep = DiscourseAi::Embeddings::Vector.instance
tokenizer = vector_rep.tokenizer
chunk_tokens = target.rag_chunk_tokens

View File

@ -0,0 +1,13 @@
# frozen_string_literal: true
module ::Jobs
  # Builds the search indexes for a single embedding definition, skipping
  # the work when the definition is gone or already correctly indexed.
  class ManageEmbeddingDefSearchIndex < ::Jobs::Base
    def execute(args)
      embedding_def = EmbeddingDefinition.find_by(id: args[:id])
      return unless embedding_def
      return if DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_def)

      DiscourseAi::Embeddings::Schema.prepare_search_indexes(embedding_def)
    end
  end
end

View File

@ -0,0 +1,11 @@
# frozen_string_literal: true
module Jobs
  # Weekly scheduled cleanup that delegates to the embeddings schema helper
  # to remove orphaned embedding data (e.g. left behind after an embedding
  # configuration was deleted).
  class RemoveOrphanedEmbeddings < ::Jobs::Scheduled
    every 1.week

    def execute(_args)
      DiscourseAi::Embeddings::Schema.remove_orphaned_data
    end
  end
end

View File

@ -0,0 +1,231 @@
# frozen_string_literal: true
# Admin-configurable embedding model definition: provider, endpoint,
# credentials, tokenizer, vector dimensions and pgvector distance function.
class EmbeddingDefinition < ActiveRecord::Base
  # Supported provider identifiers (stored in the `provider` column).
  CLOUDFLARE = "cloudflare"
  HUGGING_FACE = "hugging_face"
  OPEN_AI = "open_ai"
  GOOGLE = "google"

  class << self
    def provider_names
      [CLOUDFLARE, HUGGING_FACE, OPEN_AI, GOOGLE]
    end

    # pgvector distance operators: "<#>" negative inner product,
    # "<=>" cosine distance.
    def distance_functions
      %w[<#> <=>]
    end

    # Tokenizer classes offered in the admin UI.
    # FIX: the original listed OpenAiTokenizer twice, producing a duplicate
    # entry in the tokenizer dropdown.
    def tokenizer_names
      [
        DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer,
        DiscourseAi::Tokenizer::BgeLargeEnTokenizer,
        DiscourseAi::Tokenizer::BgeM3Tokenizer,
        DiscourseAi::Tokenizer::OpenAiTokenizer,
        DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer,
      ].map(&:name)
    end

    # Provider-specific extra fields, keyed by provider symbol. The value
    # maps field name to the input type the editor should render.
    def provider_params
      { open_ai: { model_name: :text } }
    end

    # Pre-filled configurations offered when creating a new definition.
    # Memoized; entries mirror the attributes of EmbeddingDefinition.
    def presets
      @presets ||=
        begin
          [
            {
              preset_id: "bge-large-en",
              display_name: "bge-large-en",
              dimensions: 1024,
              max_sequence_length: 512,
              pg_function: "<#>",
              tokenizer_class: "DiscourseAi::Tokenizer::BgeLargeEnTokenizer",
              provider: HUGGING_FACE,
            },
            {
              preset_id: "bge-m3",
              display_name: "bge-m3",
              dimensions: 1024,
              max_sequence_length: 8192,
              pg_function: "<#>",
              tokenizer_class: "DiscourseAi::Tokenizer::BgeM3Tokenizer",
              provider: HUGGING_FACE,
            },
            {
              preset_id: "gemini-embedding-001",
              display_name: "Gemini's embedding-001",
              dimensions: 768,
              max_sequence_length: 1536,
              pg_function: "<=>",
              url:
                "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent",
              tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
              provider: GOOGLE,
            },
            {
              preset_id: "multilingual-e5-large",
              display_name: "multilingual-e5-large",
              dimensions: 1024,
              max_sequence_length: 512,
              pg_function: "<=>",
              tokenizer_class: "DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer",
              provider: HUGGING_FACE,
            },
            {
              preset_id: "text-embedding-3-large",
              display_name: "OpenAI's text-embedding-3-large",
              dimensions: 2000,
              max_sequence_length: 8191,
              pg_function: "<=>",
              tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
              url: "https://api.openai.com/v1/embeddings",
              provider: OPEN_AI,
              provider_params: {
                model_name: "text-embedding-3-large",
              },
            },
            {
              preset_id: "text-embedding-3-small",
              display_name: "OpenAI's text-embedding-3-small",
              dimensions: 1536,
              max_sequence_length: 8191,
              pg_function: "<=>",
              tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
              url: "https://api.openai.com/v1/embeddings",
              provider: OPEN_AI,
              provider_params: {
                model_name: "text-embedding-3-small",
              },
            },
            {
              preset_id: "text-embedding-ada-002",
              display_name: "OpenAI's text-embedding-ada-002",
              dimensions: 1536,
              max_sequence_length: 8191,
              pg_function: "<=>",
              tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
              url: "https://api.openai.com/v1/embeddings",
              provider: OPEN_AI,
              provider_params: {
                model_name: "text-embedding-ada-002",
              },
            },
          ]
        end
    end
  end

  validates :provider, presence: true, inclusion: provider_names
  validates :display_name, presence: true, length: { maximum: 100 }
  validates :tokenizer_class, presence: true, inclusion: tokenizer_names
  validates_presence_of :url, :api_key, :dimensions, :max_sequence_length, :pg_function

  after_create :create_indexes

  # Enqueues async creation of the search indexes for this definition.
  def create_indexes
    Jobs.enqueue(:manage_embedding_def_search_index, id: self.id)
  end

  # The tokenizer class (constantized from the stored class name).
  def tokenizer
    tokenizer_class.constantize
  end

  # Returns the provider-specific inference client for this definition.
  # @raise [RuntimeError] when the provider is not one of the known names.
  def inference_client
    case provider
    when CLOUDFLARE
      cloudflare_client
    when HUGGING_FACE
      hugging_face_client
    when OPEN_AI
      open_ai_client
    when GOOGLE
      gemini_client
    else
      # FIX: original message had a typo ("Uknown").
      raise "Unknown embeddings provider"
    end
  end

  # Reads a provider-specific extra param (see .provider_params).
  def lookup_custom_param(key)
    provider_params&.dig(key)
  end

  # Resolves srv:// URLs through DNS SRV; plain URLs pass through unchanged.
  def endpoint_url
    return url if !url.starts_with?("srv://")

    service = DiscourseAi::Utils::DnsSrv.lookup(url.sub("srv://", ""))
    "https://#{service.target}:#{service.port}"
  end

  # NOTE: "asymetric" (sic) is the keyword callers already use; renaming it
  # would break the public interface.
  def prepare_query_text(text, asymetric: false)
    strategy.prepare_query_text(text, self, asymetric: asymetric)
  end

  def prepare_target_text(target)
    strategy.prepare_target_text(target, self)
  end

  def strategy_id
    strategy.id
  end

  def strategy_version
    strategy.version
  end

  # Seeded (built-in) definitions can have their key supplied via ENV,
  # falling back to the stored column value.
  def api_key
    if seeded?
      env_key = "DISCOURSE_AI_SEEDED_EMBEDDING_API_KEY"
      ENV[env_key] || self[:api_key]
    else
      self[:api_key]
    end
  end

  private

  # Truncation is currently the only text-preparation strategy.
  def strategy
    @strategy ||= DiscourseAi::Embeddings::Strategies::Truncation.new
  end

  def cloudflare_client
    DiscourseAi::Inference::CloudflareWorkersAi.new(endpoint_url, api_key)
  end

  def hugging_face_client
    DiscourseAi::Inference::HuggingFaceTextEmbeddings.new(endpoint_url, api_key)
  end

  def open_ai_client
    DiscourseAi::Inference::OpenAiEmbeddings.new(
      endpoint_url,
      api_key,
      lookup_custom_param("model_name"),
      dimensions,
    )
  end

  def gemini_client
    DiscourseAi::Inference::GeminiEmbeddings.new(endpoint_url, api_key)
  end
end
# == Schema Information
#
# Table name: embedding_definitions
#
# id :bigint not null, primary key
# display_name :string not null
# dimensions :integer not null
# max_sequence_length :integer not null
# version :integer default(1), not null
# pg_function :string not null
# provider :string not null
# tokenizer_class :string not null
# url :string not null
# api_key :string
# seeded :boolean default(FALSE), not null
# provider_params :jsonb
# created_at :datetime not null
# updated_at :datetime not null
#

View File

@ -0,0 +1,29 @@
# frozen_string_literal: true
# Serializes an EmbeddingDefinition for the admin API. Seeded (built-in)
# definitions mask their credentials and report a sentinel provider so the
# values cannot be read or reused from the UI.
class AiEmbeddingDefinitionSerializer < ApplicationSerializer
  root "ai_embedding"

  attributes :id,
             :display_name,
             :dimensions,
             :max_sequence_length,
             :pg_function,
             :provider,
             :url,
             :api_key,
             :seeded,
             :tokenizer_class,
             :provider_params

  def api_key
    return "********" if object.seeded?
    object.api_key
  end

  def url
    return "********" if object.seeded?
    object.url
  end

  def provider
    return "CDCK" if object.seeded?
    object.provider
  end
end

View File

@ -20,5 +20,14 @@ export default {
});
this.route("discourse-ai-spam", { path: "ai-spam" });
this.route("discourse-ai-usage", { path: "ai-usage" });
this.route(
"discourse-ai-embeddings",
{ path: "ai-embeddings" },
function () {
this.route("new");
this.route("edit", { path: "/:id/edit" });
}
);
},
};

View File

@ -0,0 +1,21 @@
import RestAdapter from "discourse/adapters/rest";
// REST adapter for the ai-embedding resource under the plugin admin path.
export default class Adapter extends RestAdapter {
  jsonMode = true;

  basePath() {
    return "/admin/plugins/discourse-ai/";
  }

  apiNameFor() {
    return "ai-embedding";
  }

  // The base implementation underscores the resource name; we keep the
  // dashed form ("ai-embeddings") to match the server route.
  pathFor(store, type, findArgs) {
    const base = this.basePath(store, type, findArgs);
    const resource = store.pluralize(this.apiNameFor(type));
    return this.appendQueryParams(base + resource, findArgs);
  }
}

View File

@ -0,0 +1,38 @@
import { ajax } from "discourse/lib/ajax";
import RestModel from "discourse/models/rest";
// Client-side model for an embedding configuration.
export default class AiEmbedding extends RestModel {
  // Attributes sent to the server on create.
  // FIX: "dimensions" was listed twice in the original.
  createProperties() {
    return this.getProperties(
      "id",
      "display_name",
      "dimensions",
      "provider",
      "tokenizer_class",
      "url",
      "api_key",
      "max_sequence_length",
      "provider_params",
      "pg_function"
    );
  }

  updateProperties() {
    const attrs = this.createProperties();
    attrs.id = this.id;
    return attrs;
  }

  // POSTs the current (possibly unsaved) config to the test endpoint.
  async testConfig() {
    return await ajax(`/admin/plugins/discourse-ai/ai-embeddings/test.json`, {
      data: { ai_embedding: this.createProperties() },
    });
  }

  // Detached copy used by the editor so in-progress edits can be discarded.
  workingCopy() {
    const attrs = this.createProperties();
    return this.store.createRecord("ai-embedding", attrs);
  }
}

View File

@ -0,0 +1,376 @@
import Component from "@glimmer/component";
import { tracked } from "@glimmer/tracking";
import { Input } from "@ember/component";
import { concat, get } from "@ember/helper";
import { on } from "@ember/modifier";
import { action, computed } from "@ember/object";
import didInsert from "@ember/render-modifiers/modifiers/did-insert";
import didUpdate from "@ember/render-modifiers/modifiers/did-update";
import { later } from "@ember/runloop";
import { service } from "@ember/service";
import BackButton from "discourse/components/back-button";
import DButton from "discourse/components/d-button";
import icon from "discourse/helpers/d-icon";
import { popupAjaxError } from "discourse/lib/ajax-error";
import { i18n } from "discourse-i18n";
import ComboBox from "select-kit/components/combo-box";
import DTooltip from "float-kit/components/d-tooltip";
import not from "truth-helpers/helpers/not";
export default class AiEmbeddingEditor extends Component {
@service toasts;
@service router;
@service dialog;
@service store;
@tracked isSaving = false;
@tracked selectedPreset = null;
@tracked testRunning = false;
@tracked testResult = null;
@tracked testError = null;
@tracked apiKeySecret = true;
@tracked editingModel = null;
get selectedProviders() {
const t = (provName) => {
return i18n(`discourse_ai.embeddings.providers.${provName}`);
};
return this.args.embeddings.resultSetMeta.providers.map((prov) => {
return { id: prov, name: t(prov) };
});
}
get distanceFunctions() {
const t = (df) => {
return i18n(`discourse_ai.embeddings.distance_functions.${df}`);
};
return this.args.embeddings.resultSetMeta.distance_functions.map((df) => {
return { id: df, name: t(df) };
});
}
get presets() {
const presets = this.args.embeddings.resultSetMeta.presets.map((preset) => {
return {
name: preset.display_name,
id: preset.preset_id,
};
});
presets.pushObject({
name: i18n("discourse_ai.embeddings.configure_manually"),
id: "manual",
});
return presets;
}
get showPresets() {
return !this.selectedPreset && this.args.model.isNew;
}
@computed("editingModel.provider")
get metaProviderParams() {
return (
this.args.embeddings.resultSetMeta.provider_params[
this.editingModel?.provider
] || {}
);
}
get testErrorMessage() {
return i18n("discourse_ai.llms.tests.failure", { error: this.testError });
}
get displayTestResult() {
return this.testRunning || this.testResult !== null;
}
@action
configurePreset() {
this.selectedPreset =
this.args.embeddings.resultSetMeta.presets.findBy(
"preset_id",
this.presetId
) || {};
this.editingModel = this.store
.createRecord("ai-embedding", this.selectedPreset)
.workingCopy();
}
@action
updateModel() {
this.editingModel = this.args.model.workingCopy();
}
@action
makeApiKeySecret() {
this.apiKeySecret = true;
}
@action
toggleApiKeySecret() {
this.apiKeySecret = !this.apiKeySecret;
}
@action
async save() {
this.isSaving = true;
const isNew = this.args.model.isNew;
try {
await this.editingModel.save();
if (isNew) {
this.args.embeddings.addObject(this.editingModel);
this.router.transitionTo(
"adminPlugins.show.discourse-ai-embeddings.index"
);
} else {
this.toasts.success({
data: { message: i18n("discourse_ai.embeddings.saved") },
duration: 2000,
});
}
} catch (e) {
popupAjaxError(e);
} finally {
later(() => {
this.isSaving = false;
}, 1000);
}
}
@action
async test() {
this.testRunning = true;
try {
const configTestResult = await this.editingModel.testConfig();
this.testResult = configTestResult.success;
if (this.testResult) {
this.testError = null;
} else {
this.testError = configTestResult.error;
}
} catch (e) {
popupAjaxError(e);
} finally {
later(() => {
this.testRunning = false;
}, 1000);
}
}
@action
delete() {
return this.dialog.confirm({
message: i18n("discourse_ai.embeddings.confirm_delete"),
didConfirm: () => {
return this.args.model
.destroyRecord()
.then(() => {
this.args.llms.removeObject(this.args.model);
this.router.transitionTo(
"adminPlugins.show.discourse-ai-embeddings.index"
);
})
.catch(popupAjaxError);
},
});
}
<template>
<BackButton
@route="adminPlugins.show.discourse-ai-embeddings"
@label="discourse_ai.embeddings.back"
/>
<form
{{didInsert this.updateModel @model.id}}
{{didUpdate this.updateModel @model.id}}
class="form-horizontal ai-embedding-editor"
>
{{#if this.showPresets}}
<div class="control-group">
<label>{{i18n "discourse_ai.embeddings.presets"}}</label>
<ComboBox
@value={{this.presetId}}
@content={{this.presets}}
class="ai-embedding-editor__presets"
/>
</div>
<div class="control-group ai-llm-editor__action_panel">
<DButton
@action={{this.configurePreset}}
@label="discourse_ai.tools.next.title"
class="ai-embedding-editor__next"
/>
</div>
{{else}}
<div class="control-group">
<label>{{i18n "discourse_ai.embeddings.display_name"}}</label>
<Input
class="ai-embedding-editor-input ai-embedding-editor__display-name"
@type="text"
@value={{this.editingModel.display_name}}
/>
</div>
<div class="control-group">
<label>{{i18n "discourse_ai.embeddings.provider"}}</label>
<ComboBox
@value={{this.editingModel.provider}}
@content={{this.selectedProviders}}
@class="ai-embedding-editor__provider"
/>
</div>
<div class="control-group">
<label>{{i18n "discourse_ai.embeddings.url"}}</label>
<Input
class="ai-embedding-editor-input ai-embedding-editor__url"
@type="text"
@value={{this.editingModel.url}}
required="true"
/>
</div>
<div class="control-group">
<label>{{i18n "discourse_ai.embeddings.api_key"}}</label>
<div class="ai-embedding-editor__secret-api-key-group">
<Input
@value={{this.editingModel.api_key}}
class="ai-embedding-editor-input ai-embedding-editor__api-key"
@type={{if this.apiKeySecret "password" "text"}}
required="true"
{{on "focusout" this.makeApiKeySecret}}
/>
<DButton
@action={{this.toggleApiKeySecret}}
@icon="far-eye-slash"
/>
</div>
</div>
<div class="control-group">
<label>{{i18n "discourse_ai.embeddings.tokenizer"}}</label>
<ComboBox
@value={{this.editingModel.tokenizer_class}}
@content={{@embeddings.resultSetMeta.tokenizers}}
@class="ai-embedding-editor__tokenizer"
/>
</div>
<div class="control-group">
<label>{{i18n "discourse_ai.embeddings.dimensions"}}</label>
<Input
@type="number"
class="ai-embedding-editor-input ai-embedding-editor__dimensions"
step="any"
min="0"
lang="en"
@value={{this.editingModel.dimensions}}
required="true"
disabled={{not this.editingModel.isNew}}
/>
{{#if this.editingModel.isNew}}
<DTooltip
@icon="circle-exclamation"
@content={{i18n
"discourse_ai.embeddings.hints.dimensions_warning"
}}
/>
{{/if}}
</div>
<div class="control-group">
<label>{{i18n "discourse_ai.embeddings.max_sequence_length"}}</label>
<Input
@type="number"
class="ai-embedding-editor-input ai-embedding-editor__max_sequence_length"
step="any"
min="0"
lang="en"
@value={{this.editingModel.max_sequence_length}}
required="true"
/>
</div>
<div class="control-group">
<label>{{i18n "discourse_ai.embeddings.distance_function"}}</label>
<ComboBox
@value={{this.editingModel.pg_function}}
@content={{this.distanceFunctions}}
@class="ai-embedding-editor__distance_functions"
/>
</div>
{{#each-in this.metaProviderParams as |field type|}}
<div
class="control-group ai-embedding-editor-provider-param__{{type}}"
>
<label>
{{i18n (concat "discourse_ai.embeddings.provider_fields." field)}}
</label>
<Input
@type="text"
class="ai-embedding-editor-input ai-embedding-editor__{{field}}"
@value={{mut (get this.editingModel.provider_params field)}}
/>
</div>
{{/each-in}}
<div class="control-group ai-embedding-editor__action_panel">
<DButton
class="ai-embedding-editor__test"
@action={{this.test}}
@disabled={{this.testRunning}}
@label="discourse_ai.embeddings.tests.title"
/>
<DButton
class="btn-primary ai-embedding-editor__save"
@action={{this.save}}
@disabled={{this.isSaving}}
@label="discourse_ai.embeddings.save"
/>
{{#unless this.editingModel.isNew}}
<DButton
@action={{this.delete}}
class="btn-danger ai-embedding-editor__delete"
@label="discourse_ai.embeddings.delete"
/>
{{/unless}}
<div class="control-group ai-embedding-editor-tests">
{{#if this.displayTestResult}}
{{#if this.testRunning}}
<div class="spinner small"></div>
{{i18n "discourse_ai.embeddings.tests.running"}}
{{else}}
{{#if this.testResult}}
<div class="ai-embedding-editor-tests__success">
{{icon "check"}}
{{i18n "discourse_ai.embeddings.tests.success"}}
</div>
{{else}}
<div class="ai-embedding-editor-tests__failure">
{{icon "xmark"}}
{{this.testErrorMessage}}
</div>
{{/if}}
{{/if}}
{{/if}}
</div>
</div>
{{/if}}
</form>
</template>
}

View File

@ -0,0 +1,114 @@
import Component from "@glimmer/component";
import { concat } from "@ember/helper";
import { service } from "@ember/service";
import DBreadcrumbsItem from "discourse/components/d-breadcrumbs-item";
import DButton from "discourse/components/d-button";
import DPageSubheader from "discourse/components/d-page-subheader";
import { i18n } from "discourse-i18n";
import AdminConfigAreaEmptyList from "admin/components/admin-config-area-empty-list";
import DTooltip from "float-kit/components/d-tooltip";
import AiEmbeddingEditor from "./ai-embedding-editor";
// List view for embedding configurations; when @currentEmbedding is set it
// renders the editor for that record instead of the table.
export default class AiEmbeddingsListEditor extends Component {
  @service adminPluginNavManager;

  get hasEmbeddingElements() {
    return this.args.embeddings.length !== 0;
  }

  <template>
    <DBreadcrumbsItem
      @path="/admin/plugins/{{this.adminPluginNavManager.currentPlugin.name}}/ai-embeddings"
      @label={{i18n "discourse_ai.embeddings.short_title"}}
    />
    <section class="ai-embeddings-list-editor admin-detail">
      {{#if @currentEmbedding}}
        {{! Editor mode: a record is open for editing/creation }}
        <AiEmbeddingEditor
          @model={{@currentEmbedding}}
          @embeddings={{@embeddings}}
        />
      {{else}}
        <DPageSubheader
          @titleLabel={{i18n "discourse_ai.embeddings.short_title"}}
          @descriptionLabel={{i18n "discourse_ai.embeddings.description"}}
          @learnMoreUrl="https://meta.discourse.org/t/discourse-ai-embeddings/259603"
        >
          <:actions as |actions|>
            <actions.Primary
              @label="discourse_ai.embeddings.new"
              @route="adminPlugins.show.discourse-ai-embeddings.new"
              @icon="plus"
              class="ai-embeddings-list-editor__new-button"
            />
          </:actions>
        </DPageSubheader>
        {{#if this.hasEmbeddingElements}}
          <table class="d-admin-table">
            <thead>
              <tr>
                <th>{{i18n "discourse_ai.embeddings.display_name"}}</th>
                <th>{{i18n "discourse_ai.embeddings.provider"}}</th>
                <th></th>
              </tr>
            </thead>
            <tbody>
              {{#each @embeddings as |embedding|}}
                <tr class="ai-embeddings-list__row d-admin-row__content">
                  <td class="d-admin-row__overview">
                    <div class="ai-embeddings-list__name">
                      <strong>
                        {{embedding.display_name}}
                      </strong>
                    </div>
                  </td>
                  <td class="d-admin-row__detail">
                    <div class="d-admin-row__mobile-label">
                      {{i18n "discourse_ai.embeddings.provider"}}
                    </div>
                    {{i18n
                      (concat
                        "discourse_ai.embeddings.providers." embedding.provider
                      )
                    }}
                  </td>
                  <td class="d-admin-row__controls">
                    {{! Seeded (built-in) records cannot be edited }}
                    {{#if embedding.seeded}}
                      <DTooltip
                        class="ai-embeddings-list__edit-disabled-tooltip"
                      >
                        <:trigger>
                          <DButton
                            class="btn btn-default btn-small disabled"
                            @label="discourse_ai.embeddings.edit"
                          />
                        </:trigger>
                        <:content>
                          {{i18n "discourse_ai.embeddings.seeded_warning"}}
                        </:content>
                      </DTooltip>
                    {{else}}
                      <DButton
                        class="btn btn-default btn-small ai-embeddings-list__edit-button"
                        @label="discourse_ai.embeddings.edit"
                        @route="adminPlugins.show.discourse-ai-embeddings.edit"
                        @routeModels={{embedding.id}}
                      />
                    {{/if}}
                  </td>
                </tr>
              {{/each}}
            </tbody>
          </table>
        {{else}}
          <AdminConfigAreaEmptyList
            @ctaLabel="discourse_ai.embeddings.new"
            @ctaRoute="adminPlugins.show.discourse-ai-embeddings.new"
            @ctaClass="ai-embeddings-list-editor__empty-new-button"
            @emptyLabel="discourse_ai.embeddings.empty"
          />
        {{/if}}
      {{/if}}
    </section>
  </template>
}

View File

@ -12,6 +12,10 @@ export default {
withPluginApi("1.1.0", (api) => {
api.addAdminPluginConfigurationNav("discourse-ai", PLUGIN_NAV_MODE_TOP, [
{
label: "discourse_ai.embeddings.short_title",
route: "adminPlugins.show.discourse-ai-embeddings",
},
{
label: "discourse_ai.llms.short_title",
route: "adminPlugins.show.discourse-ai-llms",

View File

@ -0,0 +1,26 @@
// Styles for the embedding create/edit form.
.ai-embedding-editor {
  padding-left: 0.5em;

  .ai-embedding-editor-input {
    width: 350px;
  }

  // Inline test-run feedback states.
  .ai-embedding-editor-tests {
    &__failure {
      color: var(--danger);
    }
    &__success {
      color: var(--success);
    }
  }

  &__api-key {
    margin-right: 0.5em;
  }

  // Keeps the API key input and the reveal toggle on one row.
  &__secret-api-key-group {
    display: flex;
    align-items: center;
  }
}

View File

@ -502,6 +502,49 @@ en:
accuracy: "Accuracy:"
embeddings:
short_title: "Embeddings"
description: "Embeddings are a crucial component of the Discourse AI plugin, enabling features like related topics and semantic search."
new: "New embedding"
back: "Back"
save: "Save"
saved: "Embedding configuration saved"
delete: "Delete"
confirm_delete: Are you sure you want to remove this embedding configuration?
empty: "You haven't setup embeddings yet"
presets: "Select a preset..."
configure_manually: "Configure manually"
edit: "Edit"
seeded_warning: "This is pre-configured on your site and cannot be edited."
tests:
title: "Run test"
running: "Running test..."
success: "Success!"
failure: "Attempting to generate an embedding resulted in: %{error}"
hints:
dimensions_warning: "Once saved, this value can't be changed."
display_name: "Name"
provider: "Provider"
url: "Embeddings service URL"
api_key: "Embeddings service API Key"
tokenizer: "Tokenizer"
dimensions: "Embedding dimensions"
max_sequence_length: "Sequence length"
distance_function: "Distance function"
distance_functions:
<#>: "Negative inner product (<#>)"
<=>: "Cosine distance (<=>)"
providers:
hugging_face: "Hugging Face"
open_ai: "OpenAI"
google: "Google"
cloudflare: "Cloudflare"
CDCK: "CDCK"
provider_fields:
model_name: "Model name"
semantic_search: "Topics (Semantic)"
semantic_search_loading: "Searching for more results using AI"
semantic_search_results:

View File

@ -49,10 +49,7 @@ en:
ai_nsfw_flag_threshold_sexy: "Threshold for an image classified as sexy to be considered NSFW."
ai_nsfw_models: "Models to use for NSFW inference."
ai_openai_embeddings_url: "Custom URL used for the OpenAI embeddings API. (in the case of Azure it can be: https://COMPANY.openai.azure.com/openai/deployments/DEPLOYMENT/embeddings?api-version=2023-05-15)"
ai_openai_api_key: "API key for OpenAI API. ONLY used for embeddings and Dall-E. For GPT use the LLM config tab"
ai_hugging_face_tei_endpoint: URL where the API is running for the Hugging Face text embeddings inference
ai_hugging_face_tei_api_key: API key for Hugging Face text embeddings inference
ai_openai_api_key: "API key for OpenAI API. ONLY used for Dall-E. For GPT use the LLM config tab"
ai_helper_enabled: "Enable the AI helper."
composer_ai_helper_allowed_groups: "Users on these groups will see the AI helper button in the composer."
@ -67,15 +64,11 @@ en:
ai_helper_image_caption_model: "Select the model to use for generating image captions"
ai_auto_image_caption_allowed_groups: "Users on these groups can toggle automatic image captioning."
ai_embeddings_enabled: "Enable the embeddings module."
ai_embeddings_discourse_service_api_endpoint: "URL where the API is running for the embeddings module"
ai_embeddings_discourse_service_api_key: "API key for the embeddings API"
ai_embeddings_model: "Use all-mpnet-base-v2 for local and fast inference in english, text-embedding-ada-002 to use OpenAI API (need API key) and multilingual-e5-large for local multilingual embeddings"
ai_embeddings_selected_model: "Use the selected model for generating embeddings."
ai_embeddings_generate_for_pms: "Generate embeddings for personal messages."
ai_embeddings_semantic_related_topics_enabled: "Use Semantic Search for related topics."
ai_embeddings_semantic_related_topics: "Maximum number of topics to show in related topic section."
ai_embeddings_backfill_batch_size: "Number of embeddings to backfill every 15 minutes."
ai_embeddings_pg_connection_string: "PostgreSQL connection string for the embeddings module. Needs pgvector extension enabled and a series of tables created. See docs for more info."
ai_embeddings_semantic_search_enabled: "Enable full-page semantic search."
ai_embeddings_semantic_quick_search_enabled: "Enable semantic search option in search menu popup."
ai_embeddings_semantic_related_include_closed_topics: "Include closed topics in semantic search results"
@ -437,13 +430,11 @@ en:
cannot_edit_builtin: "You can't edit a built-in model."
embeddings:
delete_failed: "This model is currently in use. Update the `ai embeddings selected model` first."
cannot_edit_builtin: "You can't edit a built-in model."
configuration:
disable_embeddings: "You have to disable 'ai embeddings enabled' first."
choose_model: "Set 'ai embeddings model' first."
model_unreachable: "We failed to generate a test embedding with this model. Check your settings are correct."
hint:
one: "Make sure the `%{settings}` setting was configured."
other: "Make sure the settings of the provider you want were configured. Options are: %{settings}"
choose_model: "Set 'ai embeddings selected model' first."
llm_models:
missing_provider_param: "%{param} can't be blank"

View File

@ -96,6 +96,13 @@ Discourse::Application.routes.draw do
controller: "discourse_ai/admin/ai_llm_quotas",
path: "quotas",
only: %i[index create update destroy]
resources :ai_embeddings,
only: %i[index new create edit update destroy],
path: "ai-embeddings",
controller: "discourse_ai/admin/ai_embeddings" do
collection { get :test }
end
end
end

View File

@ -28,11 +28,14 @@ discourse_ai:
ai_openai_dall_e_3_url: "https://api.openai.com/v1/images/generations"
ai_openai_embeddings_url: "https://api.openai.com/v1/embeddings"
ai_openai_embeddings_url:
hidden: true
default: "https://api.openai.com/v1/embeddings"
ai_openai_organization:
default: ""
hidden: true
ai_openai_api_key:
hidden: true
default: ""
secret: true
ai_stability_api_key:
@ -50,11 +53,14 @@ discourse_ai:
- "stable-diffusion-768-v2-1"
- "stable-diffusion-v1-5"
ai_hugging_face_tei_endpoint:
hidden: true
default: ""
ai_hugging_face_tei_endpoint_srv:
default: ""
hidden: true
ai_hugging_face_tei_api_key: ""
ai_hugging_face_tei_api_key:
default: ""
hidden: true
ai_hugging_face_tei_reranker_endpoint:
default: ""
ai_hugging_face_tei_reranker_endpoint_srv:
@ -69,12 +75,14 @@ discourse_ai:
ai_cloudflare_workers_account_id:
default: ""
secret: true
hidden: true
ai_cloudflare_workers_api_token:
default: ""
secret: true
hidden: true
ai_gemini_api_key:
default: ""
hidden: false
hidden: true
ai_strict_token_counting:
default: false
hidden: true
@ -158,27 +166,12 @@ discourse_ai:
default: false
client: true
validator: "DiscourseAi::Configuration::EmbeddingsModuleValidator"
ai_embeddings_discourse_service_api_endpoint: ""
ai_embeddings_discourse_service_api_endpoint_srv:
default: ""
hidden: true
ai_embeddings_discourse_service_api_key:
default: ""
secret: true
ai_embeddings_model:
ai_embeddings_selected_model:
type: enum
default: "bge-large-en"
default: ""
allow_any: false
choices:
- all-mpnet-base-v2
- text-embedding-ada-002
- text-embedding-3-small
- text-embedding-3-large
- multilingual-e5-large
- bge-large-en
- gemini
- bge-m3
validator: "DiscourseAi::Configuration::EmbeddingsModelValidator"
enum: "DiscourseAi::Configuration::EmbeddingDefsEnumerator"
validator: "DiscourseAi::Configuration::EmbeddingDefsValidator"
ai_embeddings_per_post_enabled:
default: false
hidden: true
@ -191,9 +184,6 @@ discourse_ai:
ai_embeddings_backfill_batch_size:
default: 250
hidden: true
ai_embeddings_pg_connection_string:
default: ""
hidden: true
ai_embeddings_semantic_search_enabled:
default: false
client: true
@ -213,6 +203,35 @@ discourse_ai:
default: false
client: true
hidden: true
ai_embeddings_discourse_service_api_endpoint:
default: ""
hidden: true
ai_embeddings_discourse_service_api_endpoint_srv:
default: ""
hidden: true
ai_embeddings_discourse_service_api_key:
hidden: true
default: ""
secret: true
ai_embeddings_model:
hidden: true
type: enum
default: "bge-large-en"
allow_any: false
choices:
- all-mpnet-base-v2
- text-embedding-ada-002
- text-embedding-3-small
- text-embedding-3-large
- multilingual-e5-large
- bge-large-en
- gemini
- bge-m3
ai_embeddings_pg_connection_string:
default: ""
hidden: true
ai_summarization_enabled:
default: false
client: true

View File

@ -0,0 +1,19 @@
# frozen_string_literal: true

# Creates `embedding_definitions`, the table backing configurable embedding
# models. Each row describes one provider-hosted model: its display name,
# vector dimensions, max sequence length, pgvector distance function,
# provider, tokenizer class, endpoint URL, optional API key, free-form
# provider params, and whether the row was seeded automatically.
class CreateEmbeddingDefinitions < ActiveRecord::Migration[7.2]
  def change
    create_table :embedding_definitions do |t|
      t.string :display_name, null: false
      t.integer :dimensions, null: false
      t.integer :max_sequence_length, null: false
      # Bumped when embeddings generated by a definition become incompatible.
      t.integer :version, null: false, default: 1
      # pgvector operator, e.g. "<#>" (inner product) or "<=>" (cosine).
      t.string :pg_function, null: false
      t.string :provider, null: false
      t.string :tokenizer_class, null: false
      t.string :url, null: false
      t.string :api_key
      # True for definitions created automatically (e.g. by data migrations).
      t.boolean :seeded, null: false, default: false
      t.jsonb :provider_params
      t.timestamps
    end
  end
end

View File

@ -0,0 +1,204 @@
# frozen_string_literal: true

# Data migration: translates the legacy embeddings site settings
# (`ai_embeddings_model` plus the per-provider credential settings) into a row
# in the new `embedding_definitions` table, and points the new
# `ai_embeddings_selected_model` setting at it, so upgrading sites keep their
# embeddings configuration without manual intervention.
class EmbeddingConfigDataMigration < ActiveRecord::Migration[7.0]
  def up
    # The legacy setting defaulted to bge-large-en.
    current_model = fetch_setting("ai_embeddings_model") || "bge-large-en"
    provider = provider_for(current_model)

    if provider.present?
      attrs = creds_for(provider)

      # Missing credentials abort the seeding entirely — nothing is inserted.
      if attrs.present?
        attrs = attrs.merge(model_attrs(current_model))
        attrs[:display_name] = current_model
        attrs[:provider] = provider

        persist_config(attrs)
      end
    end
  end

  # Irreversible data migration; intentionally a no-op.
  def down
  end

  # Utils

  # Reads a raw site setting value straight from the DB (SiteSetting is not
  # available inside migrations), falling back to the conventional
  # DISCOURSE_* environment-variable override.
  def fetch_setting(name)
    DB.query_single(
      "SELECT value FROM site_settings WHERE name = :setting_name",
      setting_name: name,
    ).first || ENV["DISCOURSE_#{name&.upcase}"]
  end

  # Maps the legacy model name to the provider it was served through,
  # mirroring the old selection priority (Cloudflare wins for bge-large-en
  # when its token is set). Returns nil for unknown models.
  def provider_for(model)
    cloudflare_api_token = fetch_setting("ai_cloudflare_workers_api_token")
    return "cloudflare" if model == "bge-large-en" && cloudflare_api_token.present?

    tei_models = %w[bge-large-en bge-m3 multilingual-e5-large]
    return "hugging_face" if tei_models.include?(model)

    return "google" if model == "gemini"

    if %w[text-embedding-3-large text-embedding-3-small text-embedding-ada-002].include?(model)
      return "open_ai"
    end

    nil
  end

  # Collects endpoint/API-key attributes for the provider from the legacy
  # settings. Returns nil when required credentials are missing, which makes
  # `up` skip the insert.
  def creds_for(provider)
    # CF
    if provider == "cloudflare"
      api_key = fetch_setting("ai_cloudflare_workers_api_token")
      account_id = fetch_setting("ai_cloudflare_workers_account_id")

      return if api_key.blank? || account_id.blank?

      {
        url:
          "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/baai/bge-large-en-v1.5",
        api_key: api_key,
      }
      # TEI
    elsif provider == "hugging_face"
      seeded = false
      endpoint = fetch_setting("ai_hugging_face_tei_endpoint")

      if endpoint.blank?
        endpoint = fetch_setting("ai_hugging_face_tei_endpoint_srv")
        if endpoint.present?
          endpoint = "srv://#{endpoint}"
          # NOTE(review): SRV-based endpoints are flagged as seeded —
          # presumably these come from hosting-provided defaults; confirm.
          seeded = true
        end
      end

      api_key = fetch_setting("ai_hugging_face_tei_api_key")

      return if endpoint.blank? || api_key.blank?

      { url: endpoint, api_key: api_key, seeded: seeded }
      # Gemini
    elsif provider == "google"
      api_key = fetch_setting("ai_gemini_api_key")

      return if api_key.blank?

      {
        url: "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent",
        api_key: api_key,
      }
      # Open AI
    elsif provider == "open_ai"
      endpoint = fetch_setting("ai_openai_embeddings_url")
      api_key = fetch_setting("ai_openai_api_key")

      return if endpoint.blank? || api_key.blank?

      { url: endpoint, api_key: api_key }
    else
      nil
    end
  end

  # Static metadata for each legacy model. The hardcoded :id values match the
  # ids the old VectorRepresentations classes used, so embeddings already
  # stored under those model ids remain valid.
  def model_attrs(model_name)
    if model_name == "bge-large-en"
      {
        dimensions: 1024,
        max_sequence_length: 512,
        id: 4,
        pg_function: "<#>",
        tokenizer_class: "DiscourseAi::Tokenizer::BgeLargeEnTokenizer",
      }
    elsif model_name == "bge-m3"
      {
        dimensions: 1024,
        max_sequence_length: 8192,
        id: 8,
        pg_function: "<#>",
        tokenizer_class: "DiscourseAi::Tokenizer::BgeM3Tokenizer",
      }
    elsif model_name == "gemini"
      {
        dimensions: 768,
        max_sequence_length: 1536,
        id: 5,
        pg_function: "<=>",
        tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
      }
    elsif model_name == "multilingual-e5-large"
      {
        dimensions: 1024,
        max_sequence_length: 512,
        id: 3,
        pg_function: "<=>",
        tokenizer_class: "DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer",
      }
    elsif model_name == "text-embedding-3-large"
      {
        dimensions: 2000,
        max_sequence_length: 8191,
        id: 7,
        pg_function: "<=>",
        tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
        provider_params: {
          model_name: "text-embedding-3-large",
        },
      }
    elsif model_name == "text-embedding-3-small"
      {
        dimensions: 1536,
        max_sequence_length: 8191,
        id: 6,
        pg_function: "<=>",
        tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
        provider_params: {
          model_name: "text-embedding-3-small",
        },
      }
    else
      # Fallback: text-embedding-ada-002.
      {
        dimensions: 1536,
        max_sequence_length: 8191,
        id: 2,
        pg_function: "<=>",
        tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
        provider_params: {
          model_name: "text-embedding-ada-002",
        },
      }
    end
  end

  # Inserts the definition under its legacy id and points the new
  # `ai_embeddings_selected_model` site setting at the inserted row.
  def persist_config(attrs)
    DB.exec(
      <<~SQL,
        INSERT INTO embedding_definitions (id, display_name, dimensions, max_sequence_length, version, pg_function, provider, tokenizer_class, url, api_key, provider_params, seeded, created_at, updated_at)
        VALUES (:id, :display_name, :dimensions, :max_sequence_length, 1, :pg_function, :provider, :tokenizer_class, :url, :api_key, :provider_params, :seeded, :now, :now)
      SQL
      id: attrs[:id],
      display_name: attrs[:display_name],
      dimensions: attrs[:dimensions],
      max_sequence_length: attrs[:max_sequence_length],
      pg_function: attrs[:pg_function],
      provider: attrs[:provider],
      tokenizer_class: attrs[:tokenizer_class],
      url: attrs[:url],
      api_key: attrs[:api_key],
      provider_params: attrs[:provider_params],
      seeded: !!attrs[:seeded],
      now: Time.zone.now,
    )

    # We hardcoded the ID to match with already generated embeddings. Let's restart the seq to avoid conflicts.
    DB.exec(
      "ALTER SEQUENCE embedding_definitions_id_seq RESTART WITH :new_seq",
      new_seq: attrs[:id].to_i + 1,
    )

    DB.exec(<<~SQL, new_value: attrs[:id])
      INSERT INTO site_settings(name, data_type, value, created_at, updated_at)
      VALUES ('ai_embeddings_selected_model', 3, :new_value, NOW(), NOW())
    SQL
  end
end

View File

@ -196,7 +196,7 @@ module DiscourseAi
)
plugin.on(:site_setting_changed) do |name, old_value, new_value|
if name == :ai_embeddings_model && SiteSetting.ai_embeddings_enabled? &&
if name == :ai_embeddings_selected_model && SiteSetting.ai_embeddings_enabled? &&
new_value != old_value
RagDocumentFragment.delete_all
UploadReference

View File

@ -327,7 +327,7 @@ module DiscourseAi
rag_conversation_chunks
end
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector_def: vector.vdef)
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment)
candidate_fragment_ids =
schema

View File

@ -93,7 +93,7 @@ module DiscourseAi
def nearest_neighbors(limit: 100)
vector = DiscourseAi::Embeddings::Vector.instance
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef)
schema = DiscourseAi::Embeddings::Schema.for(Topic)
raw_vector = vector.vector_from(@text)

View File

@ -0,0 +1,20 @@
# frozen_string_literal: true

require "enum_site_setting"

module DiscourseAi
  module Configuration
    # Enumerates persisted EmbeddingDefinition records so the
    # `ai_embeddings_selected_model` enum site setting can list them.
    class EmbeddingDefsEnumerator < ::EnumSiteSetting
      class << self
        # Any stored value is acceptable here; real validation lives in
        # EmbeddingDefsValidator.
        def valid_value?(val)
          true
        end

        # Returns `{ name:, value: }` pairs for the settings dropdown.
        def values
          rows = DB.query_hash(<<~SQL)
            SELECT display_name AS name, id AS value
            FROM embedding_definitions
          SQL

          rows.map { |row| row.symbolize_keys }
        end
      end
    end
  end
end

View File

@ -0,0 +1,19 @@
# frozen_string_literal: true

module DiscourseAi
  module Configuration
    # Validates the `ai_embeddings_selected_model` site setting: blank (no
    # selection) or the id of an existing EmbeddingDefinition row.
    class EmbeddingDefsValidator
      def initialize(opts = {})
        @opts = opts
      end

      def valid_value?(val)
        return true if val.blank?

        EmbeddingDefinition.exists?(id: val)
      end

      # No custom message; the setting UI supplies its own feedback.
      def error_message
        ""
      end
    end
  end
end

View File

@ -11,41 +11,11 @@ module DiscourseAi
return true if val == "f"
return true if Rails.env.test?
chosen_model = SiteSetting.ai_embeddings_model
return false if !chosen_model
representation =
DiscourseAi::Embeddings::VectorRepresentations::Base.find_representation(chosen_model)
return false if representation.nil?
if !representation.correctly_configured?
@representation = representation
return false
end
if !can_generate_embeddings?(chosen_model)
@unreachable = true
return false
end
true
SiteSetting.ai_embeddings_selected_model.present?
end
def error_message
return(I18n.t("discourse_ai.embeddings.configuration.model_unreachable")) if @unreachable
@representation&.configuration_hint
end
def can_generate_embeddings?(val)
DiscourseAi::Embeddings::VectorRepresentations::Base
.find_representation(val)
.new
.inference_client
.perform!("this is a test")
.present?
I18n.t("discourse_ai.embeddings.configuration.choose_model")
end
end
end

View File

@ -12,19 +12,80 @@ module DiscourseAi
POSTS_TABLE = "ai_posts_embeddings"
RAG_DOCS_TABLE = "ai_document_fragments_embeddings"
def self.for(
target_klass,
vector_def: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
)
case target_klass&.name
when "Topic"
new(TOPICS_TABLE, "topic_id", vector_def)
when "Post"
new(POSTS_TABLE, "post_id", vector_def)
when "RagDocumentFragment"
new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector_def)
else
raise ArgumentError, "Invalid target type for embeddings"
EMBEDDING_TARGETS = %w[topics posts document_fragments]
EMBEDDING_TABLES = [TOPICS_TABLE, POSTS_TABLE, RAG_DOCS_TABLE]
MissingEmbeddingError = Class.new(StandardError)
class << self
def for(target_klass)
vector_def = EmbeddingDefinition.find_by(id: SiteSetting.ai_embeddings_selected_model)
raise "Invalid embeddings selected model" if vector_def.nil?
case target_klass&.name
when "Topic"
new(TOPICS_TABLE, "topic_id", vector_def)
when "Post"
new(POSTS_TABLE, "post_id", vector_def)
when "RagDocumentFragment"
new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector_def)
else
raise ArgumentError, "Invalid target type for embeddings"
end
end
def search_index_name(table, def_id)
"ai_#{table}_embeddings_#{def_id}_1_search_bit"
end
def prepare_search_indexes(vector_def)
EMBEDDING_TARGETS.each { |target| DB.exec <<~SQL }
CREATE INDEX IF NOT EXISTS #{search_index_name(target, vector_def.id)} ON ai_#{target}_embeddings
USING hnsw ((binary_quantize(embeddings)::bit(#{vector_def.dimensions})) bit_hamming_ops)
WHERE model_id = #{vector_def.id} AND strategy_id = 1;
SQL
end
def correctly_indexed?(vector_def)
index_names = EMBEDDING_TARGETS.map { |t| search_index_name(t, vector_def.id) }
indexdefs =
DB.query_single(
"SELECT indexdef FROM pg_indexes WHERE indexname IN (:names)",
names: index_names,
)
return false if indexdefs.length < index_names.length
indexdefs.all? do |defs|
defs.include? "(binary_quantize(embeddings))::bit(#{vector_def.dimensions})"
end
end
def remove_orphaned_data
removed_defs_ids =
DB.query_single(
"SELECT DISTINCT(model_id) FROM #{TOPICS_TABLE} te LEFT JOIN embedding_definitions ed ON te.model_id = ed.id WHERE ed.id IS NULL",
)
EMBEDDING_TABLES.each do |t|
DB.exec(
"DELETE FROM #{t} WHERE model_id IN (:removed_defs)",
removed_defs: removed_defs_ids,
)
end
drop_index_statement =
EMBEDDING_TARGETS
.reduce([]) do |memo, et|
removed_defs_ids.each do |rdi|
memo << "DROP INDEX IF EXISTS #{search_index_name(et, rdi)};"
end
memo
end
.join("\n")
DB.exec(drop_index_statement)
end
end
@ -117,7 +178,7 @@ module DiscourseAi
offset: offset,
)
rescue PG::Error => e
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
Rails.logger.error("Error #{e} querying embeddings for model #{vector_def.display_name}")
raise MissingEmbeddingError
end
@ -168,7 +229,7 @@ module DiscourseAi
builder.query(vid: vector_def.id, vsid: vector_def.strategy_id, target_id: record.id)
rescue PG::Error => e
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
Rails.logger.error("Error #{e} querying embeddings for model #{vector_def.display_name}")
raise MissingEmbeddingError
end

View File

@ -82,7 +82,7 @@ module DiscourseAi
over_selection_limit = limit * OVER_SELECTION_FACTOR
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef)
schema = DiscourseAi::Embeddings::Schema.for(Topic)
candidate_topic_ids =
schema.asymmetric_similarity_search(
@ -132,7 +132,7 @@ module DiscourseAi
candidate_post_ids =
DiscourseAi::Embeddings::Schema
.for(Post, vector_def: vector.vdef)
.for(Post)
.asymmetric_similarity_search(
search_term_embedding,
limit: max_semantic_results_per_page,

View File

@ -4,13 +4,18 @@ module DiscourseAi
module Embeddings
class Vector
def self.instance
new(DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation)
vector_def = EmbeddingDefinition.find_by(id: SiteSetting.ai_embeddings_selected_model)
raise "Invalid embeddings selected model" if vector_def.nil?
new(vector_def)
end
def initialize(vector_definition)
@vdef = vector_definition
end
delegate :tokenizer, to: :vdef
def gen_bulk_reprensentations(relation)
http_pool_size = 100
pool =
@ -20,7 +25,7 @@ module DiscourseAi
idletime: 30,
)
schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector_def: vdef)
schema = DiscourseAi::Embeddings::Schema.for(relation.first.class)
embedding_gen = vdef.inference_client
promised_embeddings =
@ -53,7 +58,7 @@ module DiscourseAi
text = vdef.prepare_target_text(target)
return if text.blank?
schema = DiscourseAi::Embeddings::Schema.for(target.class, vector_def: vdef)
schema = DiscourseAi::Embeddings::Schema.for(target.class)
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
return if schema.find_by_target(target)&.digest == new_digest

View File

@ -1,56 +0,0 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    module VectorRepresentations
      # all-mpnet-base-v2, served through the Discourse embeddings
      # classification service.
      class AllMpnetBaseV2 < Base
        def self.name
          "all-mpnet-base-v2"
        end

        # Configured when either the SRV record or the direct endpoint of the
        # Discourse embeddings service is set.
        def self.correctly_configured?
          SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
            SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
        end

        # Settings surfaced in the admin hint when misconfigured.
        def self.dependant_setting_names
          %w[
            ai_embeddings_discourse_service_api_key
            ai_embeddings_discourse_service_api_endpoint_srv
            ai_embeddings_discourse_service_api_endpoint
          ]
        end

        def dimensions
          768
        end

        def max_sequence_length
          384
        end

        def id
          1
        end

        def version
          1
        end

        def pg_function
          "<#>"
        end

        def tokenizer
          DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
        end

        def inference_client
          DiscourseAi::Inference::DiscourseClassifier.instance(self.class.name)
        end
      end
    end
  end
end

View File

@ -1,103 +0,0 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    module VectorRepresentations
      # Abstract superclass for every supported embeddings model. A subclass
      # describes one model (dimensions, tokenizer, pgvector operator, stable
      # id) and supplies an inference client to generate vectors.
      class Base
        class << self
          # Finds the representation subclass whose `.name` matches the given
          # model name (the `ai_embeddings_model` setting value), or nil.
          def find_representation(model_name)
            # we are explicit here cause the loader may have not
            # loaded the subclasses yet
            [
              DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2,
              DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn,
              DiscourseAi::Embeddings::VectorRepresentations::BgeM3,
              DiscourseAi::Embeddings::VectorRepresentations::Gemini,
              DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large,
              DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
              DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
              DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
            ].find { _1.name == model_name }
          end

          # Instantiates the representation currently selected in settings.
          # Raises NoMethodError (on nil) if the setting names no known model.
          def current_representation
            find_representation(SiteSetting.ai_embeddings_model).new
          end

          # Subclass contract: true when all required settings are present.
          def correctly_configured?
            raise NotImplementedError
          end

          # Subclass contract: setting names shown in the configuration hint.
          def dependant_setting_names
            raise NotImplementedError
          end

          # Admin-facing message listing the settings this model depends on.
          def configuration_hint
            settings = dependant_setting_names
            I18n.t(
              "discourse_ai.embeddings.configuration.hint",
              settings: settings.join(", "),
              count: settings.length,
            )
          end
        end

        # --- Subclass contract: model metadata ---

        def name
          raise NotImplementedError
        end

        def dimensions
          raise NotImplementedError
        end

        def max_sequence_length
          raise NotImplementedError
        end

        # Stable numeric id persisted alongside stored embeddings.
        def id
          raise NotImplementedError
        end

        # pgvector distance operator, e.g. "<#>" or "<=>".
        def pg_function
          raise NotImplementedError
        end

        def version
          raise NotImplementedError
        end

        def tokenizer
          raise NotImplementedError
        end

        # Prefix prepended to asymmetric search queries; empty by default.
        def asymmetric_query_prefix
          ""
        end

        def strategy_id
          strategy.id
        end

        def strategy_version
          strategy.version
        end

        # Text preparation is delegated to the truncation strategy, which
        # receives this representation to know the token budget.
        def prepare_query_text(text, asymetric: false)
          strategy.prepare_query_text(text, self, asymetric: asymetric)
        end

        def prepare_target_text(target)
          strategy.prepare_target_text(target, self)
        end

        def strategy
          @strategy ||= DiscourseAi::Embeddings::Strategies::Truncation.new
        end

        # Subclass contract: returns the provider HTTP client used to
        # generate embeddings.
        def inference_client
          raise NotImplementedError
        end
      end
    end
  end
end

View File

@ -1,80 +0,0 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    module VectorRepresentations
      # BAAI bge-large-en, reachable through Cloudflare Workers AI, a
      # HuggingFace TEI endpoint, or the Discourse embeddings service.
      class BgeLargeEn < Base
        class << self
          def name
            "bge-large-en"
          end

          # True when at least one of the three supported backends is set up.
          def correctly_configured?
            return true if SiteSetting.ai_cloudflare_workers_api_token.present?
            return true if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?

            SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
              SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
          end

          # Settings surfaced in the admin hint when misconfigured.
          def dependant_setting_names
            %w[
              ai_cloudflare_workers_api_token
              ai_hugging_face_tei_endpoint_srv
              ai_hugging_face_tei_endpoint
              ai_embeddings_discourse_service_api_key
              ai_embeddings_discourse_service_api_endpoint_srv
              ai_embeddings_discourse_service_api_endpoint
            ]
          end
        end

        def dimensions
          1024
        end

        def max_sequence_length
          512
        end

        def id
          4
        end

        def version
          1
        end

        def pg_function
          "<#>"
        end

        def tokenizer
          DiscourseAi::Tokenizer::BgeLargeEnTokenizer
        end

        # Prefix the model authors recommend for asymmetric retrieval queries.
        def asymmetric_query_prefix
          "Represent this sentence for searching relevant passages:"
        end

        # Picks the first configured backend, in priority order:
        # Cloudflare Workers AI, HuggingFace TEI, Discourse service.
        def inference_client
          inference_model_name = "baai/bge-large-en-v1.5"

          if SiteSetting.ai_cloudflare_workers_api_token.present?
            DiscourseAi::Inference::CloudflareWorkersAi.instance(inference_model_name)
          elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
            DiscourseAi::Inference::HuggingFaceTextEmbeddings.instance
          elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
                SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
            DiscourseAi::Inference::DiscourseClassifier.instance(
              inference_model_name.split("/").last,
            )
          else
            raise "No inference endpoint configured"
          end
        end
      end
    end
  end
end

View File

@ -1,51 +0,0 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    module VectorRepresentations
      # BAAI bge-m3 multilingual model, served via a HuggingFace TEI endpoint.
      class BgeM3 < Base
        def self.name
          "bge-m3"
        end

        def self.correctly_configured?
          DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
        end

        def self.dependant_setting_names
          %w[ai_hugging_face_tei_endpoint_srv ai_hugging_face_tei_endpoint]
        end

        def dimensions
          1024
        end

        def max_sequence_length
          8192
        end

        def id
          8
        end

        def version
          1
        end

        def pg_function
          "<#>"
        end

        def tokenizer
          DiscourseAi::Tokenizer::BgeM3Tokenizer
        end

        def inference_client
          DiscourseAi::Inference::HuggingFaceTextEmbeddings.instance
        end
      end
    end
  end
end

View File

@ -1,54 +0,0 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    module VectorRepresentations
      # Google Gemini embedding-001, served through the Gemini embeddings API.
      class Gemini < Base
        def self.name
          "gemini"
        end

        def self.correctly_configured?
          SiteSetting.ai_gemini_api_key.present?
        end

        def self.dependant_setting_names
          %w[ai_gemini_api_key]
        end

        def dimensions
          768
        end

        # Gemini supports sequences up to 2048 tokens, but its API caps
        # request payloads at 10000 bytes, so we truncate earlier.
        def max_sequence_length
          1536
        end

        def id
          5
        end

        def version
          1
        end

        def pg_function
          "<=>"
        end

        # Gemini ships no public tokenizer. Of those we bundle, the OpenAI
        # tokenizer is the closest match and overestimates (~10% more tokens
        # than Gemini's own), which errs on the safe side for truncation.
        def tokenizer
          DiscourseAi::Tokenizer::OpenAiTokenizer
        end

        def inference_client
          DiscourseAi::Inference::GeminiEmbeddings.instance
        end
      end
    end
  end
end

View File

@ -1,88 +0,0 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    module VectorRepresentations
      # multilingual-e5-large, served by a HuggingFace TEI endpoint or the
      # Discourse embeddings classification service.
      class MultilingualE5Large < Base
        class << self
          def name
            "multilingual-e5-large"
          end

          # True when a TEI endpoint or the Discourse service is configured.
          def correctly_configured?
            DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? ||
              (
                SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
                  SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
              )
          end

          # Settings surfaced in the admin hint when misconfigured.
          def dependant_setting_names
            %w[
              ai_hugging_face_tei_endpoint_srv
              ai_hugging_face_tei_endpoint
              ai_embeddings_discourse_service_api_key
              ai_embeddings_discourse_service_api_endpoint_srv
              ai_embeddings_discourse_service_api_endpoint
            ]
          end
        end

        def id
          3
        end

        def version
          1
        end

        def dimensions
          1024
        end

        def max_sequence_length
          512
        end

        def pg_function
          "<=>"
        end

        def tokenizer
          DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
        end

        # Prefers TEI when configured, otherwise the Discourse service.
        def inference_client
          if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
            DiscourseAi::Inference::HuggingFaceTextEmbeddings.instance
          elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
                SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
            DiscourseAi::Inference::DiscourseClassifier.instance(self.class.name)
          else
            raise "No inference endpoint configured"
          end
        end

        # Adds the "query: " prefix when going through the Discourse service.
        # NOTE(review): this calls `super(text, asymetric:)`, but no
        # prepare_text is visible on Base in this file — confirm the
        # superclass chain actually defines it before relying on this path.
        def prepare_text(text, asymetric: false)
          prepared_text = super(text, asymetric: asymetric)

          if prepared_text.present? && inference_client.class.name.include?("DiscourseClassifier")
            return "query: #{prepared_text}"
          end

          prepared_text
        end

        # Same "query: " prefix when indexing targets via the Discourse service.
        def prepare_target_text(target)
          prepared_text = super(target)

          if prepared_text.present? && inference_client.class.name.include?("DiscourseClassifier")
            return "query: #{prepared_text}"
          end

          prepared_text
        end
      end
    end
  end
end

View File

@ -1,56 +0,0 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    module VectorRepresentations
      # OpenAI text-embedding-3-large, via the OpenAI embeddings API.
      class TextEmbedding3Large < Base
        def self.name
          "text-embedding-3-large"
        end

        def self.correctly_configured?
          SiteSetting.ai_openai_api_key.present?
        end

        def self.dependant_setting_names
          %w[ai_openai_api_key]
        end

        # The model natively produces 3072 dimensions, but our indexes only
        # support up to 2000, so we ask the API to downsample to 2000.
        def dimensions
          2000
        end

        def max_sequence_length
          8191
        end

        def id
          7
        end

        def version
          1
        end

        def pg_function
          "<=>"
        end

        def tokenizer
          DiscourseAi::Tokenizer::OpenAiTokenizer
        end

        def inference_client
          DiscourseAi::Inference::OpenAiEmbeddings.instance(
            model: self.class.name,
            dimensions: dimensions,
          )
        end
      end
    end
  end
end

View File

@ -1,51 +0,0 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    module VectorRepresentations
      # OpenAI text-embedding-3-small, via the OpenAI embeddings API.
      class TextEmbedding3Small < Base
        def self.name
          "text-embedding-3-small"
        end

        def self.correctly_configured?
          SiteSetting.ai_openai_api_key.present?
        end

        def self.dependant_setting_names
          %w[ai_openai_api_key]
        end

        def dimensions
          1536
        end

        def max_sequence_length
          8191
        end

        def id
          6
        end

        def version
          1
        end

        def pg_function
          "<=>"
        end

        def tokenizer
          DiscourseAi::Tokenizer::OpenAiTokenizer
        end

        def inference_client
          DiscourseAi::Inference::OpenAiEmbeddings.instance(model: self.class.name)
        end
      end
    end
  end
end

View File

@ -1,51 +0,0 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    module VectorRepresentations
      # OpenAI text-embedding-ada-002, via the OpenAI embeddings API.
      class TextEmbeddingAda002 < Base
        def self.name
          "text-embedding-ada-002"
        end

        def self.correctly_configured?
          SiteSetting.ai_openai_api_key.present?
        end

        def self.dependant_setting_names
          %w[ai_openai_api_key]
        end

        def dimensions
          1536
        end

        def max_sequence_length
          8191
        end

        def id
          2
        end

        def version
          1
        end

        def pg_function
          "<=>"
        end

        def tokenizer
          DiscourseAi::Tokenizer::OpenAiTokenizer
        end

        def inference_client
          DiscourseAi::Inference::OpenAiEmbeddings.instance(model: self.class.name)
        end
      end
    end
  end
end

View File

@ -3,22 +3,13 @@
module ::DiscourseAi
module Inference
class CloudflareWorkersAi
def initialize(account_id, api_token, model, referer = Discourse.base_url)
@account_id = account_id
def initialize(endpoint, api_token, referer = Discourse.base_url)
@endpoint = endpoint
@api_token = api_token
@model = model
@referer = referer
end
def self.instance(model)
new(
SiteSetting.ai_cloudflare_workers_account_id,
SiteSetting.ai_cloudflare_workers_api_token,
model,
)
end
attr_reader :account_id, :api_token, :model, :referer
attr_reader :endpoint, :api_token, :referer
def perform!(content)
headers = {
@ -29,8 +20,6 @@ module ::DiscourseAi
payload = { text: [content] }
endpoint = "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/#{model}"
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
response = conn.post(endpoint, payload.to_json, headers)
@ -43,7 +32,7 @@ module ::DiscourseAi
Rails.logger.warn(
"Cloudflare Workers AI Embeddings failed with status: #{response.status} body: #{response.body}",
)
raise Net::HTTPBadResponse
raise Net::HTTPBadResponse.new(response.body.to_s)
end
end
end

View File

@ -1,47 +0,0 @@
# frozen_string_literal: true

module ::DiscourseAi
  module Inference
    # Thin HTTP client for the Discourse embeddings/classification service.
    class DiscourseClassifier
      attr_reader :endpoint, :api_key, :model, :referer

      def initialize(endpoint, api_key, model, referer = Discourse.base_url)
        @endpoint = endpoint
        @api_key = api_key
        @model = model
        @referer = referer
      end

      # Builds a client from site settings, resolving the SRV record when one
      # is configured instead of a direct endpoint.
      def self.instance(model)
        srv = SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv

        endpoint =
          if srv.present?
            service = DiscourseAi::Utils::DnsSrv.lookup(srv)
            "https://#{service.target}:#{service.port}"
          else
            SiteSetting.ai_embeddings_discourse_service_api_endpoint
          end

        new(
          "#{endpoint}/api/v1/classify",
          SiteSetting.ai_embeddings_discourse_service_api_key,
          model,
        )
      end

      # POSTs the content for classification and returns the parsed JSON body
      # (symbolized keys). Raises Net::HTTPBadResponse on unexpected statuses.
      def perform!(content)
        headers = { "Referer" => referer, "Content-Type" => "application/json" }
        headers["X-API-KEY"] = api_key if api_key.present?

        conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
        response = conn.post(endpoint, { model: model, content: content }.to_json, headers)

        # NOTE(review): 415 is accepted alongside 200 — presumably the service
        # returns a parsable payload in that case; confirm.
        raise Net::HTTPBadResponse unless [200, 415].include?(response.status)

        JSON.parse(response.body, symbolize_names: true)
      end
    end
  end
end

View File

@ -3,21 +3,17 @@
module ::DiscourseAi
module Inference
class GeminiEmbeddings
def self.instance
new(SiteSetting.ai_gemini_api_key)
end
def initialize(api_key, referer = Discourse.base_url)
def initialize(embedding_url, api_key, referer = Discourse.base_url)
@api_key = api_key
@embedding_url = embedding_url
@referer = referer
end
attr_reader :api_key, :referer
attr_reader :embedding_url, :api_key, :referer
def perform!(content)
headers = { "Referer" => referer, "Content-Type" => "application/json" }
url =
"https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent\?key\=#{api_key}"
url = "#{embedding_url}\?key\=#{api_key}"
body = { content: { parts: [{ text: content }] } }
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
@ -32,7 +28,7 @@ module ::DiscourseAi
Rails.logger.warn(
"Google Gemini Embeddings failed with status: #{response.status} body: #{response.body}",
)
raise Net::HTTPBadResponse
raise Net::HTTPBadResponse.new(response.body.to_s)
end
end
end

View File

@ -12,19 +12,6 @@ module ::DiscourseAi
attr_reader :endpoint, :key, :referer
class << self
def instance
endpoint =
if SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
service =
DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_hugging_face_tei_endpoint_srv)
"https://#{service.target}:#{service.port}"
else
SiteSetting.ai_hugging_face_tei_endpoint
end
new(endpoint, SiteSetting.ai_hugging_face_tei_api_key)
end
def configured?
SiteSetting.ai_hugging_face_tei_endpoint.present? ||
SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
@ -100,7 +87,7 @@ module ::DiscourseAi
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
response = conn.post(endpoint, body, headers)
raise Net::HTTPBadResponse if ![200].include?(response.status)
raise Net::HTTPBadResponse.new(response.body.to_s) if ![200].include?(response.status)
JSON.parse(response.body, symbolize_names: true).first
end

View File

@ -12,10 +12,6 @@ module ::DiscourseAi
attr_reader :endpoint, :api_key, :model, :dimensions
def self.instance(model:, dimensions: nil)
new(SiteSetting.ai_openai_embeddings_url, SiteSetting.ai_openai_api_key, model, dimensions)
end
def perform!(content)
headers = { "Content-Type" => "application/json" }
@ -29,7 +25,7 @@ module ::DiscourseAi
payload[:dimensions] = dimensions if dimensions.present?
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
response = conn.post(SiteSetting.ai_openai_embeddings_url, payload.to_json, headers)
response = conn.post(endpoint, payload.to_json, headers)
case response.status
when 200
@ -40,7 +36,7 @@ module ::DiscourseAi
Rails.logger.warn(
"OpenAI Embeddings failed with status: #{response.status} body: #{response.body}",
)
raise Net::HTTPBadResponse
raise Net::HTTPBadResponse.new(response.body.to_s)
end
end
end

View File

@ -35,6 +35,7 @@ register_asset "stylesheets/modules/embeddings/common/semantic-search.scss"
register_asset "stylesheets/modules/sentiment/common/dashboard.scss"
register_asset "stylesheets/modules/llms/common/ai-llms-editor.scss"
register_asset "stylesheets/modules/embeddings/common/ai-embedding-editor.scss"
register_asset "stylesheets/modules/llms/common/usage.scss"
register_asset "stylesheets/modules/llms/common/spam.scss"

View File

@ -1,17 +0,0 @@
# frozen_string_literal: true

require_relative "../support/embeddings_generation_stubs"

RSpec.describe DiscourseAi::Configuration::EmbeddingsModelValidator do
  # The validator reaches the Discourse embeddings service; point it at a
  # stubbed endpoint for the test.
  before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }

  describe "#can_generate_embeddings?" do
    it "works" do
      discourse_model = "all-mpnet-base-v2"
      # Stub the service to return a 1024-dim vector for the probe text.
      EmbeddingsGenerationStubs.discourse_service(discourse_model, "this is a test", [1] * 1024)
      expect(subject.can_generate_embeddings?(discourse_model)).to eq(true)
    end
  end
end

View File

@ -0,0 +1,40 @@
# frozen_string_literal: true

# Base fabricator: a HuggingFace TEI-hosted multilingual-e5-large definition.
Fabricator(:embedding_definition) do
  display_name "Multilingual E5 Large"
  provider "hugging_face"
  tokenizer_class "DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer"
  api_key "123"
  url "https://test.com/embeddings"
  provider_params nil
  pg_function "<=>"
  max_sequence_length 512
  dimensions 1024
end

# bge-large-en via Cloudflare Workers AI; inherits url/api_key/dims from base.
Fabricator(:cloudflare_embedding_def, from: :embedding_definition) do
  display_name "BGE Large EN"
  provider "cloudflare"
  pg_function "<#>"
  tokenizer_class "DiscourseAi::Tokenizer::BgeLargeEnTokenizer"
  provider_params nil
end

# OpenAI text-embedding-ada-002; model name travels in provider_params.
Fabricator(:open_ai_embedding_def, from: :embedding_definition) do
  display_name "ADA 002"
  provider "open_ai"
  url "https://api.openai.com/v1/embeddings"
  tokenizer_class "DiscourseAi::Tokenizer::OpenAiTokenizer"
  provider_params { { model_name: "text-embedding-ada-002" } }
  max_sequence_length 8191
  dimensions 1536
end

# Google Gemini embedding-001 definition.
Fabricator(:gemini_embedding_def, from: :embedding_definition) do
  display_name "Gemini's embedding-001"
  provider "google"
  dimensions 768
  max_sequence_length 1536
  tokenizer_class "DiscourseAi::Tokenizer::OpenAiTokenizer"
  url "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent"
end

View File

@ -6,9 +6,8 @@ RSpec.describe Jobs::DigestRagUpload do
let(:document_file) { StringIO.new("some text" * 200) }
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }
fab!(:cloudflare_embedding_def)
let(:expected_embedding) { [0.0038493] * cloudflare_embedding_def.dimensions }
let(:document_with_metadata) { plugin_file_from_fixtures("doc_with_metadata.txt", "rag") }
@ -21,15 +20,14 @@ RSpec.describe Jobs::DigestRagUpload do
end
before do
SiteSetting.ai_embeddings_selected_model = cloudflare_embedding_def.id
SiteSetting.ai_embeddings_enabled = true
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
SiteSetting.ai_embeddings_model = "bge-large-en"
SiteSetting.authorized_extensions = "txt"
WebMock.stub_request(
:post,
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
).to_return(status: 200, body: JSON.dump(expected_embedding))
WebMock.stub_request(:post, cloudflare_embedding_def.url).to_return(
status: 200,
body: JSON.dump(expected_embedding),
)
end
describe "#execute" do

View File

@ -2,23 +2,26 @@
RSpec.describe Jobs::GenerateRagEmbeddings do
describe "#execute" do
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
fab!(:vector_def) { Fabricate(:embedding_definition) }
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }
let(:expected_embedding) { [0.0038493] * vector_def.dimensions }
fab!(:ai_persona)
fab!(:rag_document_fragment_1) { Fabricate(:rag_document_fragment, target: ai_persona) }
fab!(:rag_document_fragment_2) { Fabricate(:rag_document_fragment, target: ai_persona) }
let(:rag_document_fragment_1) { Fabricate(:rag_document_fragment, target: ai_persona) }
let(:rag_document_fragment_2) { Fabricate(:rag_document_fragment, target: ai_persona) }
before do
SiteSetting.ai_embeddings_selected_model = vector_def.id
SiteSetting.ai_embeddings_enabled = true
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
WebMock.stub_request(
:post,
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
).to_return(status: 200, body: JSON.dump(expected_embedding))
rag_document_fragment_1
rag_document_fragment_2
WebMock.stub_request(:post, vector_def.url).to_return(
status: 200,
body: JSON.dump(expected_embedding),
)
end
it "generates a new vector for each fragment" do

View File

@ -0,0 +1,56 @@
# frozen_string_literal: true
# Specs for the job that (re)creates the HNSW "search bit" indexes backing
# an EmbeddingDefinition's per-target embeddings tables.
RSpec.describe Jobs::ManageEmbeddingDefSearchIndex do
fab!(:embedding_definition)
describe "#execute" do
context "when there is no embedding def" do
it "does nothing" do
invalid_id = 999_999_999
subject.execute(id: invalid_id)
# A definition that doesn't exist can never be reported as indexed.
expect(
DiscourseAi::Embeddings::Schema.correctly_indexed?(
EmbeddingDefinition.new(id: invalid_id),
),
).to eq(false)
end
end
context "when the embedding def is fresh" do
it "creates the indexes" do
subject.execute(id: embedding_definition.id)
expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition)).to eq(true)
end
# Running the job twice must be idempotent — no duplicate-index errors.
it "creates them only once" do
subject.execute(id: embedding_definition.id)
subject.execute(id: embedding_definition.id)
expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition)).to eq(true)
end
context "when one of the idxs is missing" do
it "automatically recovers by creating it" do
# Pre-create only the topics-side index by hand; presumably
# correctly_indexed? is then false because the remaining target
# index is still missing — TODO confirm against Schema internals.
DB.exec <<~SQL
CREATE INDEX IF NOT EXISTS ai_topics_embeddings_#{embedding_definition.id}_1_search_bit ON ai_topics_embeddings
USING hnsw ((binary_quantize(embeddings)::bit(#{embedding_definition.dimensions})) bit_hamming_ops)
WHERE model_id = #{embedding_definition.id} AND strategy_id = 1;
SQL
expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition)).to eq(
false,
)
subject.execute(id: embedding_definition.id)
expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition)).to eq(
true,
)
end
end
end
end
end

View File

@ -19,11 +19,11 @@ RSpec.describe Jobs::EmbeddingsBackfill do
topic
end
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
fab!(:vector_def) { Fabricate(:embedding_definition) }
before do
SiteSetting.ai_embeddings_selected_model = vector_def.id
SiteSetting.ai_embeddings_enabled = true
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
SiteSetting.ai_embeddings_backfill_batch_size = 1
SiteSetting.ai_embeddings_per_post_enabled = true
Jobs.run_immediately!
@ -32,10 +32,10 @@ RSpec.describe Jobs::EmbeddingsBackfill do
it "backfills topics based on bumped_at date" do
embedding = Array.new(1024) { 1 }
WebMock.stub_request(
:post,
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
).to_return(status: 200, body: JSON.dump(embedding))
WebMock.stub_request(:post, "https://test.com/embeddings").to_return(
status: 200,
body: JSON.dump(embedding),
)
Jobs::EmbeddingsBackfill.new.execute({})

View File

@ -0,0 +1,71 @@
# frozen_string_literal: true

# Specs for the job that removes embeddings rows and search indexes whose
# owning EmbeddingDefinition has been deleted.
RSpec.describe Jobs::RemoveOrphanedEmbeddings do
  describe "#execute" do
    fab!(:embedding_definition)
    fab!(:embedding_definition_2) { Fabricate(:embedding_definition) }
    fab!(:topic)
    fab!(:post)

    before do
      DiscourseAi::Embeddings::Schema.prepare_search_indexes(embedding_definition)
      DiscourseAi::Embeddings::Schema.prepare_search_indexes(embedding_definition_2)

      # Seed embeddings. One of each def x target classes.
      [embedding_definition, embedding_definition_2].each do |edef|
        SiteSetting.ai_embeddings_selected_model = edef.id

        [topic, post].each do |target|
          schema = DiscourseAi::Embeddings::Schema.for(target.class)
          schema.store(target, [1] * edef.dimensions, "test")
        end
      end

      # Orphan the rows and indexes belonging to the first definition.
      embedding_definition.destroy!
    end

    # Returns the model_ids of every embeddings row stored for +target+ in
    # +table+, keyed by +target_column+ (e.g. topic_id / post_id).
    def find_all_embeddings_of(target, table, target_column)
      DB.query_single("SELECT model_id FROM #{table} WHERE #{target_column} = #{target.id}")
    end

    it "deletes embeddings without an existing embedding definition" do
      # Both definitions have rows before the job runs...
      expect(find_all_embeddings_of(post, "ai_posts_embeddings", "post_id")).to contain_exactly(
        embedding_definition.id,
        embedding_definition_2.id,
      )
      expect(find_all_embeddings_of(topic, "ai_topics_embeddings", "topic_id")).to contain_exactly(
        embedding_definition.id,
        embedding_definition_2.id,
      )

      subject.execute({})

      # ...and only the surviving definition's rows remain afterwards.
      expect(find_all_embeddings_of(topic, "ai_topics_embeddings", "topic_id")).to contain_exactly(
        embedding_definition_2.id,
      )
      expect(find_all_embeddings_of(post, "ai_posts_embeddings", "post_id")).to contain_exactly(
        embedding_definition_2.id,
      )
    end

    it "deletes orphaned indexes" do
      expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition)).to eq(true)
      expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition_2)).to eq(true)

      subject.execute({})

      # Every per-target search-bit index of the destroyed definition is gone.
      index_names =
        DiscourseAi::Embeddings::Schema::EMBEDDING_TARGETS.map do |t|
          "ai_#{t}_embeddings_#{embedding_definition.id}_1_search_bit"
        end

      indexnames =
        DB.query_single(
          "SELECT indexname FROM pg_indexes WHERE indexname IN (:names)",
          names: index_names,
        )

      expect(indexnames).to be_empty
      # The surviving definition's indexes are untouched.
      expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition_2)).to eq(true)
    end
  end
end

View File

@ -1,15 +0,0 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Configuration::EmbeddingsModuleValidator do
let(:validator) { described_class.new }
describe "#can_generate_embeddings?" do
it "returns true if embeddings can be generated" do
stub_request(
:post,
"https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent?key=",
).to_return(status: 200, body: { embedding: { values: [1, 2, 3] } }.to_json)
expect(validator.can_generate_embeddings?("gemini")).to eq(true)
end
end
end

View File

@ -4,7 +4,7 @@ require "rails_helper"
require "webmock/rspec"
RSpec.describe DiscourseAi::Inference::CloudflareWorkersAi do
subject { described_class.new(account_id, api_token, model) }
subject { described_class.new(endpoint, api_token) }
let(:account_id) { "test_account_id" }
let(:api_token) { "test_api_token" }

View File

@ -297,9 +297,11 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
end
describe "#craft_prompt" do
fab!(:vector_def) { Fabricate(:embedding_definition) }
before do
Group.refresh_automatic_groups!
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
SiteSetting.ai_embeddings_selected_model = vector_def.id
SiteSetting.ai_embeddings_enabled = true
end
@ -326,13 +328,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
fab!(:llm_model) { Fabricate(:fake_model) }
it "will run the question consolidator" do
vector_def = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
context_embedding = vector_def.dimensions.times.map { rand(-1.0...1.0) }
EmbeddingsGenerationStubs.discourse_service(
SiteSetting.ai_embeddings_model,
consolidated_question,
context_embedding,
)
EmbeddingsGenerationStubs.hugging_face_service(consolidated_question, context_embedding)
custom_ai_persona =
Fabricate(
@ -373,14 +370,11 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
end
context "when a persona has RAG uploads" do
let(:vector_def) do
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
end
let(:embedding_value) { 0.04381 }
let(:prompt_cc_embeddings) { [embedding_value] * vector_def.dimensions }
def stub_fragments(fragment_count, persona: ai_persona)
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector_def: vector_def)
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment)
fragment_count.times do |i|
fragment =
@ -403,8 +397,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
stored_ai_persona = AiPersona.find(ai_persona.id)
UploadReference.ensure_exist!(target: stored_ai_persona, upload_ids: [upload.id])
EmbeddingsGenerationStubs.discourse_service(
SiteSetting.ai_embeddings_model,
EmbeddingsGenerationStubs.hugging_face_service(
with_cc.dig(:conversation_context, 0, :content),
prompt_cc_embeddings,
)

View File

@ -108,16 +108,13 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
it "supports semantic search when enabled" do
assign_fake_provider_to(:ai_embeddings_semantic_search_hyde_model)
vector_def = Fabricate(:embedding_definition)
SiteSetting.ai_embeddings_selected_model = vector_def.id
SiteSetting.ai_embeddings_semantic_search_enabled = true
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
hyde_embedding = [0.049382] * vector_rep.dimensions
EmbeddingsGenerationStubs.discourse_service(
SiteSetting.ai_embeddings_model,
query,
hyde_embedding,
)
hyde_embedding = [0.049382] * vector_def.dimensions
EmbeddingsGenerationStubs.hugging_face_service(query, hyde_embedding)
post1 = Fabricate(:post, topic: topic_with_tags)
search =

View File

@ -1,6 +1,7 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do
fab!(:vector_def) { Fabricate(:embedding_definition) }
fab!(:user)
fab!(:muted_category) { Fabricate(:category) }
fab!(:category_mute) do
@ -19,14 +20,13 @@ RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do
let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions }
before do
SiteSetting.ai_embeddings_selected_model = vector_def.id
SiteSetting.ai_embeddings_enabled = true
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
SiteSetting.ai_embeddings_model = "bge-large-en"
WebMock.stub_request(
:post,
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
).to_return(status: 200, body: JSON.dump(expected_embedding))
WebMock.stub_request(:post, vector_def.url).to_return(
status: 200,
body: JSON.dump([expected_embedding]),
)
vector.generate_representation_from(topic)
vector.generate_representation_from(muted_topic)

View File

@ -3,25 +3,26 @@
RSpec.describe Jobs::GenerateEmbeddings do
subject(:job) { described_class.new }
fab!(:vector_def) { Fabricate(:embedding_definition) }
describe "#execute" do
before do
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
SiteSetting.ai_embeddings_selected_model = vector_def.id
SiteSetting.ai_embeddings_enabled = true
end
fab!(:topic)
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector_def) }
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector_def: vector_def) }
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic) }
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post) }
it "works for topics" do
expected_embedding = [0.0038493] * vector_def.dimensions
text = vector_def.prepare_target_text(topic)
EmbeddingsGenerationStubs.discourse_service(vector_def.class.name, text, expected_embedding)
EmbeddingsGenerationStubs.hugging_face_service(text, expected_embedding)
job.execute(target_id: topic.id, target_type: "Topic")
@ -32,7 +33,7 @@ RSpec.describe Jobs::GenerateEmbeddings do
expected_embedding = [0.0038493] * vector_def.dimensions
text = vector_def.prepare_target_text(post)
EmbeddingsGenerationStubs.discourse_service(vector_def.class.name, text, expected_embedding)
EmbeddingsGenerationStubs.hugging_face_service(text, expected_embedding)
job.execute(target_id: post.id, target_type: "Post")

View File

@ -1,12 +1,14 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Embeddings::Schema do
subject(:posts_schema) { described_class.for(Post, vector_def: vector_def) }
subject(:posts_schema) { described_class.for(Post) }
fab!(:vector_def) { Fabricate(:cloudflare_embedding_def) }
let(:embeddings) { [0.0038490295] * vector_def.dimensions }
fab!(:post) { Fabricate(:post, post_number: 1) }
let(:digest) { OpenSSL::Digest.hexdigest("SHA1", "test") }
let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new }
before { SiteSetting.ai_embeddings_selected_model = vector_def.id }
before { posts_schema.store(post, embeddings, digest) }

View File

@ -13,7 +13,13 @@ describe DiscourseAi::Embeddings::SemanticRelated do
fab!(:secured_category_topic) { Fabricate(:topic, category: secured_category) }
fab!(:closed_topic) { Fabricate(:topic, closed: true) }
before { SiteSetting.ai_embeddings_semantic_related_topics_enabled = true }
fab!(:vector_def) { Fabricate(:embedding_definition) }
before do
SiteSetting.ai_embeddings_semantic_related_topics_enabled = true
SiteSetting.ai_embeddings_selected_model = vector_def.id
SiteSetting.ai_embeddings_enabled = true
end
describe "#related_topic_ids_for" do
context "when embeddings do not exist" do
@ -24,21 +30,15 @@ describe DiscourseAi::Embeddings::SemanticRelated do
topic
end
let(:vector_rep) do
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
end
it "properly generates embeddings if missing" do
SiteSetting.ai_embeddings_enabled = true
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
Jobs.run_immediately!
embedding = Array.new(1024) { 1 }
WebMock.stub_request(
:post,
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
).to_return(status: 200, body: JSON.dump(embedding))
WebMock.stub_request(:post, vector_def.url).to_return(
status: 200,
body: JSON.dump([embedding]),
)
# miss first
ids = semantic_related.related_topic_ids_for(topic)

View File

@ -7,22 +7,18 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
let(:query) { "test_query" }
let(:subject) { described_class.new(Guardian.new(user)) }
before { assign_fake_provider_to(:ai_embeddings_semantic_search_hyde_model) }
fab!(:vector_def) { Fabricate(:embedding_definition) }
before do
SiteSetting.ai_embeddings_selected_model = vector_def.id
assign_fake_provider_to(:ai_embeddings_semantic_search_hyde_model)
end
describe "#search_for_topics" do
let(:hypothetical_post) { "This is an hypothetical post generated from the keyword test_query" }
let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
let(:hyde_embedding) { [0.049382] * vector_def.dimensions }
before do
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
EmbeddingsGenerationStubs.discourse_service(
SiteSetting.ai_embeddings_model,
hypothetical_post,
hyde_embedding,
)
end
before { EmbeddingsGenerationStubs.hugging_face_service(hypothetical_post, hyde_embedding) }
after { described_class.clear_cache_for(query) }

View File

@ -9,6 +9,10 @@ describe DiscourseAi::Embeddings::EntryPoint do
fab!(:target) { Fabricate(:topic) }
fab!(:vector_def) { Fabricate(:cloudflare_embedding_def) }
before { SiteSetting.ai_embeddings_selected_model = vector_def.id }
# The Distance gap to target increases for each element of topics.
def seed_embeddings(topics)
schema = DiscourseAi::Embeddings::Schema.for(Topic)

View File

@ -19,13 +19,13 @@ RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
)
end
fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }
let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new }
fab!(:open_ai_embedding_def)
it "truncates a topic" do
prepared_text = truncation.prepare_target_text(topic, vector_def)
prepared_text = truncation.prepare_target_text(topic, open_ai_embedding_def)
expect(vector_def.tokenizer.size(prepared_text)).to be <= vector_def.max_sequence_length
expect(open_ai_embedding_def.tokenizer.size(prepared_text)).to be <=
open_ai_embedding_def.max_sequence_length
end
end
end

View File

@ -7,8 +7,10 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
let(:expected_embedding_1) { [0.0038493] * vdef.dimensions }
let(:expected_embedding_2) { [0.0037684] * vdef.dimensions }
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vdef) }
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector_def: vdef) }
before { SiteSetting.ai_embeddings_selected_model = vdef.id }
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic) }
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post) }
fab!(:topic)
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
@ -84,63 +86,16 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
end
end
context "with text-embedding-ada-002" do
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.openai_service(vdef.class.name, text, expected_embedding)
end
it_behaves_like "generates and store embeddings using a vector definition"
end
context "with all all-mpnet-base-v2" do
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new }
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service(vdef.class.name, text, expected_embedding)
end
it_behaves_like "generates and store embeddings using a vector definition"
end
context "with gemini" do
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::Gemini.new }
let(:api_key) { "test-123" }
before { SiteSetting.ai_gemini_api_key = api_key }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.gemini_service(api_key, text, expected_embedding)
end
it_behaves_like "generates and store embeddings using a vector definition"
end
context "with multilingual-e5-large" do
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large.new }
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service(vdef.class.name, text, expected_embedding)
end
it_behaves_like "generates and store embeddings using a vector definition"
end
context "with text-embedding-3-large" do
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large.new }
context "with open_ai as the provider" do
fab!(:vdef) { Fabricate(:open_ai_embedding_def) }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.openai_service(
vdef.class.name,
vdef.lookup_custom_param("model_name"),
text,
expected_embedding,
extra_args: {
dimensions: 2000,
dimensions: vdef.dimensions,
},
)
end
@ -148,11 +103,31 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
it_behaves_like "generates and store embeddings using a vector definition"
end
context "with text-embedding-3-small" do
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small.new }
context "with hugging_face as the provider" do
fab!(:vdef) { Fabricate(:embedding_definition) }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.openai_service(vdef.class.name, text, expected_embedding)
EmbeddingsGenerationStubs.hugging_face_service(text, expected_embedding)
end
it_behaves_like "generates and store embeddings using a vector definition"
end
context "with google as the provider" do
fab!(:vdef) { Fabricate(:gemini_embedding_def) }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.gemini_service(vdef.api_key, text, expected_embedding)
end
it_behaves_like "generates and store embeddings using a vector definition"
end
context "with cloudflare as the provider" do
fab!(:vdef) { Fabricate(:cloudflare_embedding_def) }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.cloudflare_service(text, expected_embedding)
end
it_behaves_like "generates and store embeddings using a vector definition"

View File

@ -204,12 +204,12 @@ RSpec.describe AiTool do
end
context "when defining RAG fragments" do
fab!(:cloudflare_embedding_def)
before do
SiteSetting.authorized_extensions = "txt"
SiteSetting.ai_embeddings_selected_model = cloudflare_embedding_def.id
SiteSetting.ai_embeddings_enabled = true
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
SiteSetting.ai_embeddings_model = "bge-large-en"
Jobs.run_immediately!
end
@ -228,9 +228,9 @@ RSpec.describe AiTool do
# this is a trick, we get ever increasing embeddings, this gives us in turn
# 100% consistent search results
@counter = 0
stub_request(:post, "http://test.com/api/v1/classify").to_return(
stub_request(:post, cloudflare_embedding_def.url).to_return(
status: 200,
body: lambda { |req| ([@counter += 1] * 1024).to_json },
body: lambda { |req| { result: { data: [([@counter += 1] * 1024)] } }.to_json },
headers: {
},
)

View File

@ -4,11 +4,9 @@ RSpec.describe RagDocumentFragment do
fab!(:persona) { Fabricate(:ai_persona) }
fab!(:upload_1) { Fabricate(:upload) }
fab!(:upload_2) { Fabricate(:upload) }
fab!(:vector_def) { Fabricate(:embedding_definition) }
before do
SiteSetting.ai_embeddings_enabled = true
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
end
before { SiteSetting.ai_embeddings_enabled = true }
describe ".link_uploads_and_persona" do
it "does nothing if there is no persona" do
@ -76,25 +74,25 @@ RSpec.describe RagDocumentFragment do
describe ".indexing_status" do
let(:vector) { DiscourseAi::Embeddings::Vector.instance }
fab!(:rag_document_fragment_1) do
let(:rag_document_fragment_1) do
Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
end
fab!(:rag_document_fragment_2) do
let(:rag_document_fragment_2) do
Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
end
let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions }
let(:expected_embedding) { [0.0038493] * vector_def.dimensions }
before do
SiteSetting.ai_embeddings_enabled = true
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
SiteSetting.ai_embeddings_model = "bge-large-en"
SiteSetting.ai_embeddings_selected_model = vector_def.id
rag_document_fragment_1
rag_document_fragment_2
WebMock.stub_request(
:post,
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
).to_return(status: 200, body: JSON.dump(expected_embedding))
WebMock.stub_request(:post, "https://test.com/embeddings").to_return(
status: 200,
body: JSON.dump(expected_embedding),
)
vector.generate_representation_from(rag_document_fragment_1)
end
@ -106,7 +104,7 @@ RSpec.describe RagDocumentFragment do
UploadReference.create!(upload_id: upload_2.id, target: persona)
Sidekiq::Testing.fake! do
SiteSetting.ai_embeddings_model = "all-mpnet-base-v2"
SiteSetting.ai_embeddings_selected_model = Fabricate(:open_ai_embedding_def).id
expect(RagDocumentFragment.exists?(old_id)).to eq(false)
expect(Jobs::DigestRagUpload.jobs.size).to eq(2)
end

View File

@ -0,0 +1,184 @@
# frozen_string_literal: true

# Request specs for the admin CRUD + connectivity-test endpoints that manage
# EmbeddingDefinition records (configurable embeddings).
RSpec.describe DiscourseAi::Admin::AiEmbeddingsController do
  fab!(:admin)

  before { sign_in(admin) }

  let(:valid_attrs) do
    {
      display_name: "Embedding config test",
      dimensions: 1001,
      max_sequence_length: 234,
      pg_function: "<#>",
      provider: "hugging_face",
      url: "https://test.com/api/v1/embeddings",
      api_key: "test",
      tokenizer_class: "DiscourseAi::Tokenizer::BgeM3Tokenizer",
    }
  end

  describe "POST #create" do
    context "with valid attrs" do
      it "creates a new embedding definition" do
        post "/admin/plugins/discourse-ai/ai-embeddings.json", params: { ai_embedding: valid_attrs }

        created_def = EmbeddingDefinition.last

        expect(response.status).to eq(201)
        expect(created_def.display_name).to eq(valid_attrs[:display_name])
      end

      it "stores provider-specific config params" do
        post "/admin/plugins/discourse-ai/ai-embeddings.json",
             params: {
               ai_embedding:
                 valid_attrs.merge(
                   provider: "open_ai",
                   provider_params: {
                     model_name: "embeddings-v1",
                   },
                 ),
             }

        created_def = EmbeddingDefinition.last

        expect(response.status).to eq(201)
        expect(created_def.provider_params["model_name"]).to eq("embeddings-v1")
      end

      # Params not declared for the chosen provider are silently dropped.
      it "ignores parameters not associated with that provider" do
        post "/admin/plugins/discourse-ai/ai-embeddings.json",
             params: {
               ai_embedding: valid_attrs.merge(provider_params: { custom: "custom" }),
             }

        created_def = EmbeddingDefinition.last

        expect(response.status).to eq(201)
        expect(created_def.lookup_custom_param("custom")).to be_nil
      end
    end

    context "with invalid attrs" do
      it "doesn't create a new embedding definition" do
        post "/admin/plugins/discourse-ai/ai-embeddings.json",
             params: {
               ai_embedding: valid_attrs.except(:provider),
             }

        created_def = EmbeddingDefinition.last

        expect(created_def).to be_nil
      end
    end
  end

  describe "PUT #update" do
    fab!(:embedding_definition)

    context "with valid update params" do
      let(:update_attrs) { { provider: "open_ai" } }

      it "updates the model" do
        put "/admin/plugins/discourse-ai/ai-embeddings/#{embedding_definition.id}.json",
            params: {
              ai_embedding: update_attrs,
            }

        expect(response.status).to eq(200)
        expect(embedding_definition.reload.provider).to eq(update_attrs[:provider])
      end

      it "returns a 404 if there is no model with the given Id" do
        put "/admin/plugins/discourse-ai/ai-embeddings/9999999.json"

        expect(response.status).to eq(404)
      end

      # Dimensions are immutable after creation: stored vectors and indexes
      # depend on them, so the endpoint ignores the attribute on update.
      it "doesn't allow dimensions to be updated" do
        new_dimensions = 200

        put "/admin/plugins/discourse-ai/ai-embeddings/#{embedding_definition.id}.json",
            params: {
              ai_embedding: {
                dimensions: new_dimensions,
              },
            }

        expect(response.status).to eq(200)
        expect(embedding_definition.reload.dimensions).not_to eq(new_dimensions)
      end
    end

    context "with invalid update params" do
      it "doesn't update the model" do
        put "/admin/plugins/discourse-ai/ai-embeddings/#{embedding_definition.id}.json",
            params: {
              ai_embedding: {
                url: "",
              },
            }

        expect(response.status).to eq(422)
      end
    end
  end

  describe "DELETE #destroy" do
    fab!(:embedding_definition)

    it "destroys the embedding definition" do
      expect {
        delete "/admin/plugins/discourse-ai/ai-embeddings/#{embedding_definition.id}.json"

        expect(response).to have_http_status(:no_content)
      }.to change(EmbeddingDefinition, :count).by(-1)
    end

    # A definition selected via ai_embeddings_selected_model is in use and
    # must not be deletable.
    it "validates the model is not in use" do
      SiteSetting.ai_embeddings_selected_model = embedding_definition.id

      delete "/admin/plugins/discourse-ai/ai-embeddings/#{embedding_definition.id}.json"

      expect(response.status).to eq(409)
      expect(embedding_definition.reload).to eq(embedding_definition)
    end
  end

  describe "GET #test" do
    context "when we can generate an embedding" do
      it "returns a success true flag" do
        WebMock.stub_request(:post, valid_attrs[:url]).to_return(status: 200, body: [[1]].to_json)

        get "/admin/plugins/discourse-ai/ai-embeddings/test.json",
            params: {
              ai_embedding: valid_attrs,
            }

        expect(response).to be_successful
        expect(response.parsed_body["success"]).to eq(true)
      end
    end

    context "when we cannot generate an embedding" do
      it "returns a success false flag and the error message" do
        error_message = { error: "Embedding generation failed." }

        WebMock.stub_request(:post, valid_attrs[:url]).to_return(
          status: 422,
          body: error_message.to_json,
        )

        get "/admin/plugins/discourse-ai/ai-embeddings/test.json",
            params: {
              ai_embedding: valid_attrs,
            }

        expect(response).to be_successful
        expect(response.parsed_body["success"]).to eq(false)
        expect(response.parsed_body["error"]).to eq(error_message.to_json)
      end
    end
  end
end

View File

@ -8,7 +8,6 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
sign_in(admin)
SiteSetting.ai_embeddings_enabled = true
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
end
describe "GET #index" do

View File

@ -4,11 +4,12 @@ RSpec.describe DiscourseAi::Admin::RagDocumentFragmentsController do
fab!(:admin)
fab!(:ai_persona)
fab!(:vector_def) { Fabricate(:embedding_definition) }
before do
sign_in(admin)
SiteSetting.ai_embeddings_selected_model = vector_def.id
SiteSetting.ai_embeddings_enabled = true
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
end
describe "GET #indexing_status_check" do

View File

@ -2,9 +2,11 @@
describe DiscourseAi::Embeddings::EmbeddingsController do
context "when performing a topic search" do
fab!(:vector_def) { Fabricate(:open_ai_embedding_def) }
before do
SiteSetting.min_search_term_length = 3
SiteSetting.ai_embeddings_model = "text-embedding-3-small"
SiteSetting.ai_embeddings_selected_model = vector_def.id
DiscourseAi::Embeddings::SemanticSearch.clear_cache_for("test")
SearchIndexer.enable
end
@ -31,7 +33,15 @@ describe DiscourseAi::Embeddings::EmbeddingsController do
def stub_embedding(query)
embedding = [0.049382] * 1536
EmbeddingsGenerationStubs.openai_service(SiteSetting.ai_embeddings_model, query, embedding)
EmbeddingsGenerationStubs.openai_service(
vector_def.lookup_custom_param("model_name"),
query,
embedding,
extra_args: {
dimensions: vector_def.dimensions,
},
)
end
def create_api_key(user)

View File

@ -1,10 +1,13 @@
# frozen_string_literal: true
describe DiscourseAi::Inference::OpenAiEmbeddings do
let(:api_key) { "123456" }
let(:dimensions) { 1000 }
let(:model) { "text-embedding-ada-002" }
it "supports azure embeddings" do
SiteSetting.ai_openai_embeddings_url =
azure_url =
"https://my-company.openai.azure.com/openai/deployments/embeddings-deployment/embeddings?api-version=2023-05-15"
SiteSetting.ai_openai_api_key = "123456"
body_json = {
usage: {
@ -14,28 +17,22 @@ describe DiscourseAi::Inference::OpenAiEmbeddings do
data: [{ object: "embedding", embedding: [0.0, 0.1] }],
}.to_json
stub_request(
:post,
"https://my-company.openai.azure.com/openai/deployments/embeddings-deployment/embeddings?api-version=2023-05-15",
).with(
stub_request(:post, azure_url).with(
body: "{\"model\":\"text-embedding-ada-002\",\"input\":\"hello\"}",
headers: {
"Api-Key" => "123456",
"Api-Key" => api_key,
"Content-Type" => "application/json",
},
).to_return(status: 200, body: body_json, headers: {})
result =
DiscourseAi::Inference::OpenAiEmbeddings.instance(model: "text-embedding-ada-002").perform!(
"hello",
)
DiscourseAi::Inference::OpenAiEmbeddings.new(azure_url, api_key, model, nil).perform!("hello")
expect(result).to eq([0.0, 0.1])
end
it "supports openai embeddings" do
SiteSetting.ai_openai_api_key = "123456"
url = "https://api.openai.com/v1/embeddings"
body_json = {
usage: {
prompt_tokens: 1,
@ -44,21 +41,20 @@ describe DiscourseAi::Inference::OpenAiEmbeddings do
data: [{ object: "embedding", embedding: [0.0, 0.1] }],
}.to_json
body = { model: "text-embedding-ada-002", input: "hello", dimensions: 1000 }.to_json
body = { model: model, input: "hello", dimensions: dimensions }.to_json
stub_request(:post, "https://api.openai.com/v1/embeddings").with(
stub_request(:post, url).with(
body: body,
headers: {
"Authorization" => "Bearer 123456",
"Authorization" => "Bearer #{api_key}",
"Content-Type" => "application/json",
},
).to_return(status: 200, body: body_json, headers: {})
result =
DiscourseAi::Inference::OpenAiEmbeddings.instance(
model: "text-embedding-ada-002",
dimensions: 1000,
).perform!("hello")
DiscourseAi::Inference::OpenAiEmbeddings.new(url, api_key, model, dimensions).perform!(
"hello",
)
expect(result).to eq([0.0, 0.1])
end

View File

@ -2,15 +2,11 @@
class EmbeddingsGenerationStubs
class << self
def discourse_service(model, string, embedding)
model = "bge-large-en-v1.5" if model == "bge-large-en"
def hugging_face_service(string, embedding)
WebMock
.stub_request(
:post,
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
)
.with(body: JSON.dump({ model: model, content: string }))
.to_return(status: 200, body: JSON.dump(embedding))
.stub_request(:post, "https://test.com/embeddings")
.with(body: JSON.dump({ inputs: string, truncate: true }))
.to_return(status: 200, body: JSON.dump([embedding]))
end
def openai_service(model, string, embedding, extra_args: {})
@ -29,5 +25,12 @@ class EmbeddingsGenerationStubs
.with(body: JSON.dump({ content: { parts: [{ text: string }] } }))
.to_return(status: 200, body: JSON.dump({ embedding: { values: embedding } }))
end
# Stubs a Cloudflare Workers AI embeddings endpoint.
#
# The fixed "https://test.com/embeddings" URL must match the endpoint
# configured on the embedding definition under test.
#
# @param string [String] the text expected inside the request's `text` array
# @param embedding [Array<Float>] the embedding vector wrapped in
#   Cloudflare's `{ result: { data: [...] } }` envelope
def cloudflare_service(string, embedding)
  request_body = JSON.dump({ text: [string] })
  response_body = JSON.dump({ result: { data: [embedding] } })

  WebMock
    .stub_request(:post, "https://test.com/embeddings")
    .with(body: request_body)
    .to_return(status: 200, body: response_body)
end
end
end

View File

@ -0,0 +1,87 @@
# frozen_string_literal: true
RSpec.describe "Managing Embeddings configurations", type: :system, js: true do
  fab!(:admin)

  let(:page_header) { PageObjects::Components::DPageHeader.new }

  before { sign_in(admin) }

  it "correctly sets defaults" do
    # Use a distinct name for the preset id: the original code reassigned
    # `preset` to the preset hash while still reading the id inside the
    # `find` block, which only worked because Ruby evaluates the RHS before
    # rebinding — an easy-to-break shadowing trap.
    preset_id = "text-embedding-3-small"
    api_key = "abcd"

    visit "/admin/plugins/discourse-ai/ai-embeddings"

    find(".ai-embeddings-list-editor__new-button").click

    select_kit = PageObjects::Components::SelectKit.new(".ai-embedding-editor__presets")
    select_kit.expand
    select_kit.select_row_by_value(preset_id)
    find(".ai-embedding-editor__next").click

    # The preset pre-fills every field except the credentials.
    find("input.ai-embedding-editor__api-key").fill_in(with: api_key)
    find(".ai-embedding-editor__save").click

    expect(page).to have_current_path("/admin/plugins/discourse-ai/ai-embeddings")

    embedding_def = EmbeddingDefinition.order(:id).last
    expect(embedding_def.api_key).to eq(api_key)

    preset = EmbeddingDefinition.presets.find { |p| p[:preset_id] == preset_id }

    expect(embedding_def.display_name).to eq(preset[:display_name])
    expect(embedding_def.url).to eq(preset[:url])
    expect(embedding_def.tokenizer_class).to eq(preset[:tokenizer_class])
    expect(embedding_def.dimensions).to eq(preset[:dimensions])
    expect(embedding_def.max_sequence_length).to eq(preset[:max_sequence_length])
    expect(embedding_def.pg_function).to eq(preset[:pg_function])
    expect(embedding_def.provider).to eq(preset[:provider])
    expect(embedding_def.provider_params.symbolize_keys).to eq(preset[:provider_params])
  end

  it "supports manual config" do
    api_key = "abcd"

    visit "/admin/plugins/discourse-ai/ai-embeddings"

    find(".ai-embeddings-list-editor__new-button").click

    select_kit = PageObjects::Components::SelectKit.new(".ai-embedding-editor__presets")
    select_kit.expand
    select_kit.select_row_by_value("manual")
    find(".ai-embedding-editor__next").click

    # Fill the form with the same values the text-embedding-3-small preset
    # would use, so we can assert against that preset's attributes below.
    find("input.ai-embedding-editor__display-name").fill_in(with: "OpenAI's text-embedding-3-small")

    select_kit = PageObjects::Components::SelectKit.new(".ai-embedding-editor__provider")
    select_kit.expand
    select_kit.select_row_by_value(EmbeddingDefinition::OPEN_AI)

    find("input.ai-embedding-editor__url").fill_in(with: "https://api.openai.com/v1/embeddings")
    find("input.ai-embedding-editor__api-key").fill_in(with: api_key)

    select_kit = PageObjects::Components::SelectKit.new(".ai-embedding-editor__tokenizer")
    select_kit.expand
    select_kit.select_row_by_value("DiscourseAi::Tokenizer::OpenAiTokenizer")

    find("input.ai-embedding-editor__dimensions").fill_in(with: 1536)
    find("input.ai-embedding-editor__max_sequence_length").fill_in(with: 8191)

    select_kit = PageObjects::Components::SelectKit.new(".ai-embedding-editor__distance_functions")
    select_kit.expand
    select_kit.select_row_by_value("<=>")

    find(".ai-embedding-editor__save").click

    expect(page).to have_current_path("/admin/plugins/discourse-ai/ai-embeddings")

    embedding_def = EmbeddingDefinition.order(:id).last
    expect(embedding_def.api_key).to eq(api_key)

    preset = EmbeddingDefinition.presets.find { |p| p[:preset_id] == "text-embedding-3-small" }

    expect(embedding_def.display_name).to eq(preset[:display_name])
    expect(embedding_def.url).to eq(preset[:url])
    expect(embedding_def.tokenizer_class).to eq(preset[:tokenizer_class])
    expect(embedding_def.dimensions).to eq(preset[:dimensions])
    expect(embedding_def.max_sequence_length).to eq(preset[:max_sequence_length])
    expect(embedding_def.pg_function).to eq(preset[:pg_function])
    expect(embedding_def.provider).to eq(preset[:provider])
  end
end

View File

@ -10,7 +10,6 @@ RSpec.describe "AI Composer helper", type: :system, js: true do
fab!(:post) { Fabricate(:post, topic: topic, raw: "Apple pie is a delicious dessert to eat") }
before do
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
prompt = DiscourseAi::Embeddings::HydeGenerators::OpenAi.new.prompt(query)
OpenAiCompletionsInferenceStubs.stub_response(
prompt,
@ -21,11 +20,7 @@ RSpec.describe "AI Composer helper", type: :system, js: true do
)
hyde_embedding = [0.049382, 0.9999]
EmbeddingsGenerationStubs.discourse_service(
SiteSetting.ai_embeddings_model,
hypothetical_post,
hyde_embedding,
)
EmbeddingsGenerationStubs.hugging_face_service(hypothetical_post, hyde_embedding)
SearchIndexer.enable
SearchIndexer.index(topic, force: true)