diff --git a/app/controllers/discourse_ai/embeddings/embeddings_controller.rb b/app/controllers/discourse_ai/embeddings/embeddings_controller.rb
index 71a7c435..a202c709 100644
--- a/app/controllers/discourse_ai/embeddings/embeddings_controller.rb
+++ b/app/controllers/discourse_ai/embeddings/embeddings_controller.rb
@@ -9,7 +9,6 @@ module DiscourseAi
def search
query = params[:q]
- page = (params[:page] || 1).to_i
grouped_results =
Search::GroupedSearchResults.new(
@@ -19,12 +18,19 @@ module DiscourseAi
use_pg_headlines_for_excerpt: false,
)
- DiscourseAi::Embeddings::SemanticSearch
- .new(guardian)
- .search_for_topics(query, page)
- .each { |topic_post| grouped_results.add(topic_post) }
+ semantic_search = DiscourseAi::Embeddings::SemanticSearch.new(guardian)
- render_serialized(grouped_results, GroupedSearchResultSerializer, result: grouped_results)
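+ # Only queries that miss the HyDE cache are rate limited, since only those trigger an LLM call.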
+ if !semantic_search.cached_query?(query)
+ RateLimiter.new(current_user, "semantic-search", 4, 1.minutes).performed!
+ end
+
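+ # hijack frees the web worker while we wait on the LLM and the embeddings API,
+ # since a HyDE-backed search can take several seconds.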
+ hijack do
+ semantic_search
+ .search_for_topics(query)
+ .each { |topic_post| grouped_results.add(topic_post) }
+
+ render_serialized(grouped_results, GroupedSearchResultSerializer, result: grouped_results)
+ end
end
end
end
diff --git a/assets/javascripts/discourse/connectors/full-page-search-below-search-header/semantic-search.hbs b/assets/javascripts/discourse/connectors/full-page-search-below-search-header/semantic-search.hbs
new file mode 100644
index 00000000..9b71af92
--- /dev/null
+++ b/assets/javascripts/discourse/connectors/full-page-search-below-search-header/semantic-search.hbs
@@ -0,0 +1,51 @@
+{{#if this.searchEnabled}}
+ <div class="semantic-search__container">
+ <div class="semantic-search__results" {{did-insert this.setup}} {{will-destroy this.teardown}}>
+ {{#if this.searching}}
+ <div class="semantic-search__searching">
+ <span class="semantic-search__searching-text">
+ {{i18n "discourse_ai.embeddings.semantic_search_loading"}}
+ </span>
+ <span class="semantic-search__indicator-wave">
+ <span class="semantic-search__indicator-dot">.</span>
+ <span class="semantic-search__indicator-dot">.</span>
+ <span class="semantic-search__indicator-dot">.</span>
+ </span>
+ </div>
+ {{else}}
+ {{#if this.results.length}}
+ <a href {{on "click" this.toggleCollapsedResults}}>
+ {{this.collapsedResultsTitle}}
+ </a>
+ {{#unless this.collapsedResults}}
+ <div class="semantic-search__entries">
+ <SearchResultEntries @posts={{this.results}} />
+ </div>
+ {{/unless}}
+ {{else}}
+ <span>
+ {{i18n "discourse_ai.embeddings.semantic_search_results.none"}}
+ </span>
+ {{/if}}
+ {{/if}}
+ </div>
+ </div>
+{{/if}}
\ No newline at end of file
diff --git a/assets/javascripts/discourse/connectors/full-page-search-below-search-header/semantic-search.js b/assets/javascripts/discourse/connectors/full-page-search-below-search-header/semantic-search.js
new file mode 100644
index 00000000..083994c3
--- /dev/null
+++ b/assets/javascripts/discourse/connectors/full-page-search-below-search-header/semantic-search.js
@@ -0,0 +1,82 @@
+import Component from "@glimmer/component";
+import { action, computed } from "@ember/object";
+import I18n from "I18n";
+import { tracked } from "@glimmer/tracking";
+import { ajax } from "discourse/lib/ajax";
+import { translateResults } from "discourse/lib/search";
+import discourseDebounce from "discourse-common/lib/debounce";
+import { inject as service } from "@ember/service";
+import { bind } from "discourse-common/utils/decorators";
+import { SEARCH_TYPE_DEFAULT } from "discourse/controllers/full-page-search";
+
+export default class extends Component {
+ static shouldRender(_args, { siteSettings }) {
+ return siteSettings.ai_embeddings_semantic_search_enabled;
+ }
+
+ @service appEvents;
+
+ @tracked searching = false;
+ @tracked collapsedResults = true;
+ @tracked results = [];
+
+ @computed("args.outletArgs.search")
+ get searchTerm() {
+ return this.args.outletArgs.search;
+ }
+
+ @computed("args.outletArgs.type")
+ get searchEnabled() {
+ return this.args.outletArgs.type === SEARCH_TYPE_DEFAULT;
+ }
+
+ @computed("results")
+ get collapsedResultsTitle() {
+ return I18n.t("discourse_ai.embeddings.semantic_search_results.toggle", {
+ count: this.results.length,
+ });
+ }
+
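+ // Lets the connector template expand or collapse the AI results list.
+ @action
+ toggleCollapsedResults() {
+ this.collapsedResults = !this.collapsedResults;
+ }
+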
+ @action
+ setup() {
+ this.appEvents.on(
+ "full-page-search:trigger-search",
+ this,
+ "debouncedSearch"
+ );
+ }
+
+ @action
+ teardown() {
+ this.appEvents.off(
+ "full-page-search:trigger-search",
+ this,
+ "debouncedSearch"
+ );
+ }
+
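+ // Kicks off the semantic search; it runs alongside the regular full-page search.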
+ @bind
+ performHyDESearch() {
+ if (!this.searchTerm || !this.searchEnabled || this.searching) {
+ return;
+ }
+
+ this.searching = true;
+ this.collapsedResults = true;
+ this.results = [];
+
+ ajax("/discourse-ai/embeddings/semantic-search", {
+ data: { q: this.searchTerm },
+ })
+ .then(async (results) => {
+ const model = (await translateResults(results)) || {};
+ this.results = model.posts;
+ })
+ .finally(() => (this.searching = false));
+ }
+
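+ // Debounce so rapid triggers from the search page collapse into one HyDE request.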
+ @action
+ debouncedSearch() {
+ discourseDebounce(this, this.performHyDESearch, 500);
+ }
+}
diff --git a/assets/javascripts/initializers/semantic-full-page-search.js b/assets/javascripts/initializers/semantic-full-page-search.js
deleted file mode 100644
index 4db7f979..00000000
--- a/assets/javascripts/initializers/semantic-full-page-search.js
+++ /dev/null
@@ -1,63 +0,0 @@
-import { withPluginApi } from "discourse/lib/plugin-api";
-import { translateResults, updateRecentSearches } from "discourse/lib/search";
-import { ajax } from "discourse/lib/ajax";
-
-const SEMANTIC_SEARCH = "semantic_search";
-
-function initializeSemanticSearch(api) {
- api.addFullPageSearchType(
- "discourse_ai.embeddings.semantic_search",
- SEMANTIC_SEARCH,
- (searchController, args) => {
- if (searchController.currentUser) {
- updateRecentSearches(searchController.currentUser, args.searchTerm);
- }
-
- ajax("/discourse-ai/embeddings/semantic-search", { data: args })
- .then(async (results) => {
- const model = (await translateResults(results)) || {};
-
- if (results.grouped_search_result) {
- searchController.set("q", results.grouped_search_result.term);
- }
-
- if (args.page > 1) {
- if (model) {
- searchController.model.posts.pushObjects(model.posts);
- searchController.model.topics.pushObjects(model.topics);
- searchController.model.set(
- "grouped_search_result",
- results.grouped_search_result
- );
- }
- } else {
- model.grouped_search_result = results.grouped_search_result;
- searchController.set("model", model);
- }
- searchController.set("error", null);
- })
- .catch((e) => {
- searchController.set("error", e.jqXHR.responseJSON?.message);
- })
- .finally(() => {
- searchController.setProperties({
- searching: false,
- loading: false,
- });
- });
- }
- );
-}
-
-export default {
- name: "discourse-ai-full-page-semantic-search",
-
- initialize(container) {
- const settings = container.lookup("service:site-settings");
- const semanticSearch = settings.ai_embeddings_semantic_search_enabled;
-
- if (settings.ai_embeddings_enabled && semanticSearch) {
- withPluginApi("1.6.0", initializeSemanticSearch);
- }
- },
-};
diff --git a/assets/stylesheets/modules/embeddings/common/semantic-search.scss b/assets/stylesheets/modules/embeddings/common/semantic-search.scss
new file mode 100644
index 00000000..fcf09cb6
--- /dev/null
+++ b/assets/stylesheets/modules/embeddings/common/semantic-search.scss
@@ -0,0 +1,39 @@
+.semantic-search__container {
+ background: var(--primary-very-low);
+ margin: 1rem 0 1rem 0;
+
+ .semantic-search__results {
+ display: flex;
+ flex-direction: column;
+ align-items: baseline;
+
+ .semantic-search {
+ &__searching-text {
+ display: inline-block;
+ margin-left: 3px;
+ }
+ &__indicator-wave {
+ flex: 0 0 auto;
+ display: inline-flex;
+ }
+ &__indicator-dot {
+ display: inline-block;
+ animation: ai-summary__indicator-wave 1.8s linear infinite;
+ &:nth-child(2) {
+ animation-delay: -1.6s;
+ }
+ &:nth-child(3) {
+ animation-delay: -1.4s;
+ }
+ }
+ }
+
+ .semantic-search__entries {
+ margin-top: 10px;
+ }
+
+ .semantic-search__searching {
+ margin-left: 5px;
+ }
+ }
+}
diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml
index f026b06a..271ea9e4 100644
--- a/config/locales/client.en.yml
+++ b/config/locales/client.en.yml
@@ -34,6 +34,10 @@ en:
embeddings:
semantic_search: "Topics (Semantic)"
+ semantic_search_loading: "Searching for more results using AI"
+ semantic_search_results:
+ toggle: "Found %{count} results using AI"
+ none: "Sorry, our AI search found no matching topics."
ai_bot:
pm_warning: "AI chatbot messages are monitored regularly by moderators."
diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml
index 2b5ec6b3..987e5013 100644
--- a/config/locales/server.en.yml
+++ b/config/locales/server.en.yml
@@ -55,6 +55,7 @@ en:
ai_embeddings_pg_connection_string: "PostgreSQL connection string for the embeddings module. Needs pgvector extension enabled and a series of tables created. See docs for more info."
ai_embeddings_semantic_search_enabled: "Enable full-page semantic search."
ai_embeddings_semantic_related_include_closed_topics: "Include closed topics in semantic search results"
+ ai_embeddings_semantic_search_hyde_model: "Model used to expand keywords to get better results during a semantic search"
ai_summarization_discourse_service_api_endpoint: "URL where the Discourse summarization API is running."
ai_summarization_discourse_service_api_key: "API key for the Discourse summarization API."
diff --git a/config/settings.yml b/config/settings.yml
index 621b2635..6663a133 100644
--- a/config/settings.yml
+++ b/config/settings.yml
@@ -177,6 +177,18 @@ discourse_ai:
ai_embeddings_semantic_search_enabled:
default: false
client: true
+ ai_embeddings_semantic_search_hyde_model:
+ default: "gpt-3.5-turbo"
+ type: enum
+ allow_any: false
+ choices:
+ - Llama2-*-chat-hf
+ - claude-instant-1
+ - claude-2
+ - gpt-3.5-turbo
+ - gpt-4
+ - StableBeluga2
+ - Upstage-Llama-2-*-instruct-v2
ai_summarization_discourse_service_api_endpoint: ""
ai_summarization_discourse_service_api_key:
diff --git a/lib/modules/ai_helper/semantic_categorizer.rb b/lib/modules/ai_helper/semantic_categorizer.rb
index 5acb3f5b..c3918c5d 100644
--- a/lib/modules/ai_helper/semantic_categorizer.rb
+++ b/lib/modules/ai_helper/semantic_categorizer.rb
@@ -11,13 +11,12 @@ module DiscourseAi
return [] if @text.blank?
return [] unless SiteSetting.ai_embeddings_enabled
+ strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
+ vector_rep =
+ DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
+
candidates =
- ::DiscourseAi::Embeddings::SemanticSearch.new(nil).asymmetric_semantic_search(
- @text,
- 100,
- 0,
- return_distance: true,
- )
+ vector_rep.asymmetric_topics_similarity_search(
+ vector_rep.vector_from(@text),
+ limit: 100,
+ offset: 0,
+ return_distance: true,
+ )
candidate_ids = candidates.map(&:first)
::Topic
diff --git a/lib/modules/embeddings/entry_point.rb b/lib/modules/embeddings/entry_point.rb
index bcbc7979..208e297b 100644
--- a/lib/modules/embeddings/entry_point.rb
+++ b/lib/modules/embeddings/entry_point.rb
@@ -4,16 +4,21 @@ module DiscourseAi
module Embeddings
class EntryPoint
def load_files
- require_relative "models/base"
- require_relative "models/all_mpnet_base_v2"
- require_relative "models/text_embedding_ada_002"
- require_relative "models/multilingual_e5_large"
+ require_relative "vector_representations/base"
+ require_relative "vector_representations/all_mpnet_base_v2"
+ require_relative "vector_representations/text_embedding_ada_002"
+ require_relative "vector_representations/multilingual_e5_large"
require_relative "strategies/truncation"
- require_relative "manager"
require_relative "jobs/regular/generate_embeddings"
require_relative "semantic_related"
- require_relative "semantic_search"
require_relative "semantic_topic_query"
+
+ require_relative "hyde_generators/base"
+ require_relative "hyde_generators/openai"
+ require_relative "hyde_generators/anthropic"
+ require_relative "hyde_generators/llama2"
+ require_relative "hyde_generators/llama2_ftos"
+ require_relative "semantic_search"
end
def inject_into(plugin)
diff --git a/lib/modules/embeddings/hyde_generators/anthropic.rb b/lib/modules/embeddings/hyde_generators/anthropic.rb
new file mode 100644
index 00000000..693ea0dc
--- /dev/null
+++ b/lib/modules/embeddings/hyde_generators/anthropic.rb
@@ -0,0 +1,32 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Embeddings
+ module HydeGenerators
+ class Anthropic < DiscourseAi::Embeddings::HydeGenerators::Base
+ def prompt(search_term)
+ <<~TEXT
+ Given a search term between <input> tags, generate a forum post about the search term.
+ Respond with the generated post between <ai> tags.
+
+ <input>#{search_term}</input>
+ TEXT
+ end
+
+ def models
+ %w[claude-instant-1 claude-2]
+ end
+
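+ # Claude tends to add a preamble, so the prompt asks for the post inside
+ # <ai> tags and we extract just that fragment with Nokogiri below.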
+ def hypothetical_post_from(query)
+ response =
+ ::DiscourseAi::Inference::AnthropicCompletions.perform!(
+ prompt(query),
+ SiteSetting.ai_embeddings_semantic_search_hyde_model,
+ ).dig(:completion)
+
+ Nokogiri::HTML5.fragment(response).at("ai").text
+ end
+ end
+ end
+ end
+end
diff --git a/lib/modules/embeddings/hyde_generators/base.rb b/lib/modules/embeddings/hyde_generators/base.rb
new file mode 100644
index 00000000..8514b414
--- /dev/null
+++ b/lib/modules/embeddings/hyde_generators/base.rb
@@ -0,0 +1,17 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Embeddings
+ module HydeGenerators
+ class Base
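+ # Picks the generator whose #models list includes the configured
+ # ai_embeddings_semantic_search_hyde_model site setting.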
+ def self.current_hyde_model
+ DiscourseAi::Embeddings::HydeGenerators::Base.descendants.find do |generator_klass|
+ generator_klass.new.models.include?(
+ SiteSetting.ai_embeddings_semantic_search_hyde_model,
+ )
+ end
+ end
+ end
+ end
+ end
+end
diff --git a/lib/modules/embeddings/hyde_generators/llama2.rb b/lib/modules/embeddings/hyde_generators/llama2.rb
new file mode 100644
index 00000000..6a72bb8c
--- /dev/null
+++ b/lib/modules/embeddings/hyde_generators/llama2.rb
@@ -0,0 +1,34 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Embeddings
+ module HydeGenerators
+ class Llama2 < DiscourseAi::Embeddings::HydeGenerators::Base
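+ # Llama 2 chat models expect the [INST] ... [/INST] format with an
+ # embedded <<SYS>> system prompt.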
+ def prompt(search_term)
+ <<~TEXT
+ [INST] <<SYS>>
+ You are a helpful bot
+ You create forum posts about a given topic
+ <</SYS>>
+
+ Topic: #{search_term}
+ [/INST]
+ Here is a forum post about the above topic:
+ TEXT
+ end
+
+ def models
+ ["Llama2-*-chat-hf"]
+ end
+
+ def hypothetical_post_from(query)
+ ::DiscourseAi::Inference::HuggingFaceTextGeneration.perform!(
+ prompt(query),
+ SiteSetting.ai_embeddings_semantic_search_hyde_model,
+ token_limit: 400,
+ ).dig(:generated_text)
+ end
+ end
+ end
+ end
+end
diff --git a/lib/modules/embeddings/hyde_generators/llama2_ftos.rb b/lib/modules/embeddings/hyde_generators/llama2_ftos.rb
new file mode 100644
index 00000000..fd4245ba
--- /dev/null
+++ b/lib/modules/embeddings/hyde_generators/llama2_ftos.rb
@@ -0,0 +1,27 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Embeddings
+ module HydeGenerators
+ class Llama2Ftos < DiscourseAi::Embeddings::HydeGenerators::Llama2
+ def prompt(search_term)
+ <<~TEXT
+ ### System:
+ You are a helpful bot
+ You create forum posts about a given topic
+
+ ### User:
+ Topic: #{search_term}
+
+ ### Assistant:
+ Here is a forum post about the above topic:
+ TEXT
+ end
+
+ def models
+ %w[StableBeluga2 Upstage-Llama-2-*-instruct-v2]
+ end
+ end
+ end
+ end
+end
diff --git a/lib/modules/embeddings/hyde_generators/openai.rb b/lib/modules/embeddings/hyde_generators/openai.rb
new file mode 100644
index 00000000..f44ca8fe
--- /dev/null
+++ b/lib/modules/embeddings/hyde_generators/openai.rb
@@ -0,0 +1,30 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Embeddings
+ module HydeGenerators
+ class OpenAi < DiscourseAi::Embeddings::HydeGenerators::Base
+ def prompt(search_term)
+ [
+ {
+ role: "system",
+ content: "You are a helpful bot. You create forum posts about a given topic.",
+ },
+ { role: "user", content: "Create a forum post about the topic: #{search_term}" },
+ ]
+ end
+
+ def models
+ %w[gpt-3.5-turbo gpt-4]
+ end
+
+ def hypothetical_post_from(query)
+ ::DiscourseAi::Inference::OpenAiCompletions.perform!(
+ prompt(query),
+ SiteSetting.ai_embeddings_semantic_search_hyde_model,
+ ).dig(:choices, 0, :message, :content)
+ end
+ end
+ end
+ end
+end
diff --git a/lib/modules/embeddings/jobs/regular/generate_embeddings.rb b/lib/modules/embeddings/jobs/regular/generate_embeddings.rb
index 7d41cd30..919b43c6 100644
--- a/lib/modules/embeddings/jobs/regular/generate_embeddings.rb
+++ b/lib/modules/embeddings/jobs/regular/generate_embeddings.rb
@@ -11,7 +11,11 @@ module Jobs
post = topic.first_post
return if post.nil? || post.raw.blank?
- DiscourseAi::Embeddings::Manager.new(topic).generate!
+ strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
+ vector_rep =
+ DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
+
+ vector_rep.generate_topic_representation_from(topic)
end
end
end
diff --git a/lib/modules/embeddings/manager.rb b/lib/modules/embeddings/manager.rb
deleted file mode 100644
index 2a0c5aee..00000000
--- a/lib/modules/embeddings/manager.rb
+++ /dev/null
@@ -1,64 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
- module Embeddings
- class Manager
- attr_reader :target, :model, :strategy
-
- def initialize(target)
- @target = target
- @model =
- DiscourseAi::Embeddings::Models::Base.subclasses.find do
- _1.name == SiteSetting.ai_embeddings_model
- end
- @strategy = DiscourseAi::Embeddings::Strategies::Truncation.new(@target, @model)
- end
-
- def generate!
- @strategy.process!
-
- # TODO bail here if we already have an embedding with matching version and digest
-
- @embeddings = @model.generate_embeddings(@strategy.processed_target)
-
- persist!
- end
-
- def persist!
- begin
- DB.exec(
- <<~SQL,
- INSERT INTO ai_topic_embeddings_#{table_suffix} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
- VALUES (:topic_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
- ON CONFLICT (topic_id)
- DO UPDATE SET
- model_version = :model_version,
- strategy_version = :strategy_version,
- digest = :digest,
- embeddings = '[:embeddings]',
- updated_at = CURRENT_TIMESTAMP
-
- SQL
- topic_id: @target.id,
- model_version: @model.version,
- strategy_version: @strategy.version,
- digest: @strategy.digest,
- embeddings: @embeddings,
- )
- rescue PG::Error => e
- Rails.logger.error(
- "Error #{e} persisting embedding for topic #{topic.id} and model #{model.name}",
- )
- end
- end
-
- def table_suffix
- "#{@model.id}_#{@strategy.id}"
- end
-
- def topic_embeddings_table
- "ai_topic_embeddings_#{table_suffix}"
- end
- end
- end
-end
diff --git a/lib/modules/embeddings/models/all_mpnet_base_v2.rb b/lib/modules/embeddings/models/all_mpnet_base_v2.rb
deleted file mode 100644
index 7160052a..00000000
--- a/lib/modules/embeddings/models/all_mpnet_base_v2.rb
+++ /dev/null
@@ -1,52 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
- module Embeddings
- module Models
- class AllMpnetBaseV2 < Base
- class << self
- def id
- 1
- end
-
- def version
- 1
- end
-
- def name
- "all-mpnet-base-v2"
- end
-
- def dimensions
- 768
- end
-
- def max_sequence_length
- 384
- end
-
- def pg_function
- "<#>"
- end
-
- def pg_index_type
- "vector_ip_ops"
- end
-
- def generate_embeddings(text)
- DiscourseAi::Inference::DiscourseClassifier.perform!(
- "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
- name,
- text,
- SiteSetting.ai_embeddings_discourse_service_api_key,
- )
- end
-
- def tokenizer
- DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
- end
- end
- end
- end
- end
-end
diff --git a/lib/modules/embeddings/models/base.rb b/lib/modules/embeddings/models/base.rb
deleted file mode 100644
index c888308a..00000000
--- a/lib/modules/embeddings/models/base.rb
+++ /dev/null
@@ -1,10 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
- module Embeddings
- module Models
- class Base
- end
- end
- end
-end
diff --git a/lib/modules/embeddings/models/multilingual_e5_large.rb b/lib/modules/embeddings/models/multilingual_e5_large.rb
deleted file mode 100644
index 4de1b4d4..00000000
--- a/lib/modules/embeddings/models/multilingual_e5_large.rb
+++ /dev/null
@@ -1,52 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
- module Embeddings
- module Models
- class MultilingualE5Large < Base
- class << self
- def id
- 3
- end
-
- def version
- 1
- end
-
- def name
- "multilingual-e5-large"
- end
-
- def dimensions
- 1024
- end
-
- def max_sequence_length
- 512
- end
-
- def pg_function
- "<=>"
- end
-
- def pg_index_type
- "vector_cosine_ops"
- end
-
- def generate_embeddings(text)
- DiscourseAi::Inference::DiscourseClassifier.perform!(
- "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
- name,
- "query: #{text}",
- SiteSetting.ai_embeddings_discourse_service_api_key,
- )
- end
-
- def tokenizer
- DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
- end
- end
- end
- end
- end
-end
diff --git a/lib/modules/embeddings/models/text_embedding_ada_002.rb b/lib/modules/embeddings/models/text_embedding_ada_002.rb
deleted file mode 100644
index 167418d4..00000000
--- a/lib/modules/embeddings/models/text_embedding_ada_002.rb
+++ /dev/null
@@ -1,48 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
- module Embeddings
- module Models
- class TextEmbeddingAda002 < Base
- class << self
- def id
- 2
- end
-
- def version
- 1
- end
-
- def name
- "text-embedding-ada-002"
- end
-
- def dimensions
- 1536
- end
-
- def max_sequence_length
- 8191
- end
-
- def pg_function
- "<=>"
- end
-
- def pg_index_type
- "vector_cosine_ops"
- end
-
- def generate_embeddings(text)
- response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text)
- response[:data].first[:embedding]
- end
-
- def tokenizer
- DiscourseAi::Tokenizer::OpenAiTokenizer
- end
- end
- end
- end
- end
-end
diff --git a/lib/modules/embeddings/semantic_related.rb b/lib/modules/embeddings/semantic_related.rb
index 667866ad..683cf66d 100644
--- a/lib/modules/embeddings/semantic_related.rb
+++ b/lib/modules/embeddings/semantic_related.rb
@@ -5,101 +5,67 @@ module DiscourseAi
class SemanticRelated
MissingEmbeddingError = Class.new(StandardError)
- class << self
- def semantic_suggested_key(topic_id)
- "semantic-suggested-topic-#{topic_id}"
- end
+ def self.clear_cache_for(topic)
+ Discourse.cache.delete("semantic-suggested-topic-#{topic.id}")
+ Discourse.redis.del("build-semantic-suggested-topic-#{topic.id}")
+ end
- def build_semantic_suggested_key(topic_id)
- "build-semantic-suggested-topic-#{topic_id}"
- end
+ def related_topic_ids_for(topic)
+ return [] if SiteSetting.ai_embeddings_semantic_related_topics < 1
- def clear_cache_for(topic)
- Discourse.cache.delete(semantic_suggested_key(topic.id))
- Discourse.redis.del(build_semantic_suggested_key(topic.id))
- end
+ strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
+ vector_rep =
+ DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
+ cache_for = results_ttl(topic)
- def related_topic_ids_for(topic)
- return [] if SiteSetting.ai_embeddings_semantic_related_topics < 1
-
- manager = DiscourseAi::Embeddings::Manager.new(topic)
- cache_for = results_ttl(topic)
-
- begin
- Discourse
- .cache
- .fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
- symmetric_semantic_search(manager)
- end
- rescue MissingEmbeddingError
- # avoid a flood of jobs when visiting topic
- if Discourse.redis.set(
- build_semantic_suggested_key(topic.id),
- "queued",
- ex: 15.minutes.to_i,
- nx: true,
- )
- Jobs.enqueue(:generate_embeddings, topic_id: topic.id)
+ Discourse
+ .cache
+ .fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
+ vector_rep
+ .symmetric_topics_similarity_search(topic)
+ .tap do |candidate_ids|
+ # Happens when the topic doesn't have any embeddings
+ # I'd rather not use Exceptions to control the flow, so this should be refactored soon
+ if candidate_ids.empty? || !candidate_ids.include?(topic.id)
+ raise MissingEmbeddingError, "No embeddings found for topic #{topic.id}"
+ end
+ end
end
- []
- end
+ rescue MissingEmbeddingError
+ # avoid a flood of jobs when visiting topic
+ if Discourse.redis.set(
+ build_semantic_suggested_key(topic.id),
+ "queued",
+ ex: 15.minutes.to_i,
+ nx: true,
+ )
+ Jobs.enqueue(:generate_embeddings, topic_id: topic.id)
end
+ []
+ end
- def symmetric_semantic_search(manager)
- topic = manager.target
- candidate_ids = self.query_symmetric_embeddings(manager)
-
- # Happens when the topic doesn't have any embeddings
- # I'd rather not use Exceptions to control the flow, so this should be refactored soon
- if candidate_ids.empty? || !candidate_ids.include?(topic.id)
- raise MissingEmbeddingError, "No embeddings found for topic #{topic.id}"
- end
-
- candidate_ids
+ def results_ttl(topic)
+ case topic.created_at
+ when 6.hour.ago..Time.now
+ 15.minutes
+ when 3.day.ago..6.hour.ago
+ 1.hour
+ when 15.days.ago..3.day.ago
+ 12.hours
+ else
+ 1.week
end
+ end
- def query_symmetric_embeddings(manager)
- topic = manager.target
- model = manager.model
- table = manager.topic_embeddings_table
- begin
- DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
- SELECT
- topic_id
- FROM
- #{table}
- ORDER BY
- embeddings #{model.pg_function} (
- SELECT
- embeddings
- FROM
- #{table}
- WHERE
- topic_id = :topic_id
- LIMIT 1
- )
- LIMIT 100
- SQL
- rescue PG::Error => e
- Rails.logger.error(
- "Error #{e} querying embeddings for topic #{topic.id} and model #{model.name}",
- )
- raise MissingEmbeddingError
- end
- end
+ private
- def results_ttl(topic)
- case topic.created_at
- when 6.hour.ago..Time.now
- 15.minutes
- when 3.day.ago..6.hour.ago
- 1.hour
- when 15.days.ago..3.day.ago
- 12.hours
- else
- 1.week
- end
- end
+ def semantic_suggested_key(topic_id)
+ "semantic-suggested-topic-#{topic_id}"
+ end
+
+ def build_semantic_suggested_key(topic_id)
+ "build-semantic-suggested-topic-#{topic_id}"
end
end
end
diff --git a/lib/modules/embeddings/semantic_search.rb b/lib/modules/embeddings/semantic_search.rb
index e35fdc56..2e2e3fd2 100644
--- a/lib/modules/embeddings/semantic_search.rb
+++ b/lib/modules/embeddings/semantic_search.rb
@@ -3,59 +3,66 @@
module DiscourseAi
module Embeddings
class SemanticSearch
+ def self.clear_cache_for(query)
+ digest = OpenSSL::Digest::SHA1.hexdigest(query)
+
+ Discourse.cache.delete("hyde-doc-#{digest}")
+ Discourse.cache.delete("hyde-doc-embedding-#{digest}")
+ end
+
def initialize(guardian)
@guardian = guardian
- @manager = DiscourseAi::Embeddings::Manager.new(nil)
- @model = @manager.model
+ end
+
+ def cached_query?(query)
+ digest = OpenSSL::Digest::SHA1.hexdigest(query)
+ Discourse.cache.read("hyde-doc-embedding-#{digest}").present?
end
def search_for_topics(query, page = 1)
- limit = Search.per_filter + 1
- offset = (page - 1) * Search.per_filter
+ max_results_per_page = 50
+ limit = [Search.per_filter, max_results_per_page].min + 1
+ offset = (page - 1) * limit
- candidate_ids = asymmetric_semantic_search(query, limit, offset)
+ strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
+ vector_rep =
+ DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
+
+ digest = OpenSSL::Digest::SHA1.hexdigest(query)
+
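+ # HyDE: embed an LLM-generated hypothetical post for the query instead of
+ # the raw keywords; document-to-document similarity in embedding space
+ # tends to match better than keyword-to-document similarity.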
+ hypothetical_post =
+ Discourse
+ .cache
+ .fetch("hyde-doc-#{digest}", expires_in: 1.week) do
+ hyde_generator = DiscourseAi::Embeddings::HydeGenerators::Base.current_hyde_model.new
+ hyde_generator.hypothetical_post_from(query)
+ end
+
+ hypothetical_post_embedding =
+ Discourse
+ .cache
+ .fetch("hyde-doc-embedding-#{digest}", expires_in: 1.week) do
+ vector_rep.vector_from(hypothetical_post)
+ end
+
+ candidate_topic_ids =
+ vector_rep.asymmetric_topics_similarity_search(
+ hypothetical_post_embedding,
+ limit: limit,
+ offset: offset,
+ )
::Post
.where(post_type: ::Topic.visible_post_types(guardian.user))
.public_posts
.where("topics.visible")
- .where(topic_id: candidate_ids, post_number: 1)
- .order("array_position(ARRAY#{candidate_ids}, topic_id)")
- end
-
- def asymmetric_semantic_search(query, limit, offset, return_distance: false)
- embedding = model.generate_embeddings(query)
- table = @manager.topic_embeddings_table
-
- begin
- candidate_ids = DB.query(<<~SQL, query_embedding: embedding, limit: limit, offset: offset)
- SELECT
- topic_id,
- embeddings #{@model.pg_function} '[:query_embedding]' AS distance
- FROM
- #{table}
- ORDER BY
- embeddings #{@model.pg_function} '[:query_embedding]'
- LIMIT :limit
- OFFSET :offset
- SQL
- rescue PG::Error => e
- Rails.logger.error(
- "Error #{e} querying embeddings for model #{model.name} and search #{query}",
- )
- raise MissingEmbeddingError
- end
-
- if return_distance
- candidate_ids.map { |c| [c.topic_id, c.distance] }
- else
- candidate_ids.map(&:topic_id)
- end
+ .where(topic_id: candidate_topic_ids, post_number: 1)
+ .order("array_position(ARRAY#{candidate_topic_ids}, topic_id)")
end
private
- attr_reader :model, :guardian
+ attr_reader :guardian
end
end
end
diff --git a/lib/modules/embeddings/semantic_topic_query.rb b/lib/modules/embeddings/semantic_topic_query.rb
index c1034baf..2ee85b65 100644
--- a/lib/modules/embeddings/semantic_topic_query.rb
+++ b/lib/modules/embeddings/semantic_topic_query.rb
@@ -14,7 +14,7 @@ class DiscourseAi::Embeddings::SemanticTopicQuery < TopicQuery
list =
create_list(:semantic_related, query_opts) do |topics|
- candidate_ids = DiscourseAi::Embeddings::SemanticRelated.related_topic_ids_for(topic)
+ candidate_ids = DiscourseAi::Embeddings::SemanticRelated.new.related_topic_ids_for(topic)
list =
topics
diff --git a/lib/modules/embeddings/strategies/truncation.rb b/lib/modules/embeddings/strategies/truncation.rb
index f7d76340..4b2c977a 100644
--- a/lib/modules/embeddings/strategies/truncation.rb
+++ b/lib/modules/embeddings/strategies/truncation.rb
@@ -4,77 +4,57 @@ module DiscourseAi
module Embeddings
module Strategies
class Truncation
- attr_reader :processed_target, :digest
-
- def self.id
- 1
- end
-
def id
- self.class.id
+ 1
end
def version
1
end
- def initialize(target, model)
- @model = model
- @target = target
- @tokenizer = @model.tokenizer
- @max_length = @model.max_sequence_length - 2
- @processed_target = nil
- end
-
- # Need a better name for this method
- def process!
- case @target
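+ # Strategies are now stateless: the caller supplies the tokenizer and the
+ # token budget, so one instance can serve any vector representation.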
+ def prepare_text_from(target, tokenizer, max_length)
+ case target
when Topic
- @processed_target = topic_truncation(@target)
+ topic_truncation(target, tokenizer, max_length)
when Post
- @processed_target = post_truncation(@target)
+ post_truncation(target, tokenizer, max_length)
else
raise ArgumentError, "Invalid target type"
end
-
- @digest = OpenSSL::Digest::SHA1.hexdigest(@processed_target)
end
- def topic_truncation(topic)
- t = +""
+ private
- t << topic.title
- t << "\n\n"
- t << topic.category.name
+ def topic_information(topic)
+ info = +""
+
+ info << topic.title
+ info << "\n\n"
+ info << topic.category.name
if SiteSetting.tagging_enabled
- t << "\n\n"
- t << topic.tags.pluck(:name).join(", ")
+ info << "\n\n"
+ info << topic.tags.pluck(:name).join(", ")
end
- t << "\n\n"
+ info << "\n\n"
+ end
+
+ def topic_truncation(topic, tokenizer, max_length)
+ text = +topic_information(topic)
topic.posts.find_each do |post|
- t << post.raw
- break if @tokenizer.size(t) >= @max_length #maybe keep a partial counter to speed this up?
- t << "\n\n"
+ text << post.raw
+ break if tokenizer.size(text) >= max_length #maybe keep a partial counter to speed this up?
+ text << "\n\n"
end
- @tokenizer.truncate(t, @max_length)
+ tokenizer.truncate(text, max_length)
end
- def post_truncation(post)
- t = +""
+ def post_truncation(post, tokenizer, max_length)
+ text = +topic_information(post.topic)
+ text << post.raw
- t << post.topic.title
- t << "\n\n"
- t << post.topic.category.name
- if SiteSetting.tagging_enabled
- t << "\n\n"
- t << post.topic.tags.pluck(:name).join(", ")
- end
- t << "\n\n"
- t << post.raw
-
- @tokenizer.truncate(t, @max_length)
+ tokenizer.truncate(text, max_length)
end
end
end
diff --git a/lib/modules/embeddings/vector_representations/all_mpnet_base_v2.rb b/lib/modules/embeddings/vector_representations/all_mpnet_base_v2.rb
new file mode 100644
index 00000000..8dfb2a47
--- /dev/null
+++ b/lib/modules/embeddings/vector_representations/all_mpnet_base_v2.rb
@@ -0,0 +1,50 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Embeddings
+ module VectorRepresentations
+ class AllMpnetBaseV2 < Base
+ def vector_from(text)
+ DiscourseAi::Inference::DiscourseClassifier.perform!(
+ "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
+ name,
+ text,
+ SiteSetting.ai_embeddings_discourse_service_api_key,
+ )
+ end
+
+ def name
+ "all-mpnet-base-v2"
+ end
+
+ def dimensions
+ 768
+ end
+
+ def max_sequence_length
+ 384
+ end
+
+ def id
+ 1
+ end
+
+ def version
+ 1
+ end
+
+ def pg_function
+ "<#>"
+ end
+
+ def pg_index_type
+ "vector_ip_ops"
+ end
+
+ def tokenizer
+ DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
+ end
+ end
+ end
+ end
+end
diff --git a/lib/modules/embeddings/vector_representations/base.rb b/lib/modules/embeddings/vector_representations/base.rb
new file mode 100644
index 00000000..e89bf259
--- /dev/null
+++ b/lib/modules/embeddings/vector_representations/base.rb
@@ -0,0 +1,166 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Embeddings
+ module VectorRepresentations
+ class Base
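+ # Raised when embeddings are missing or the embeddings query fails.
+ MissingEmbeddingError = Class.new(StandardError)
+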
+ def self.current_representation(strategy)
+ subclasses.map { _1.new(strategy) }.find { _1.name == SiteSetting.ai_embeddings_model }
+ end
+
+ def initialize(strategy)
+ @strategy = strategy
+ end
+
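+ # Rebuilds the ivfflat ANN index. `lists` controls how many clusters the
+ # vectors are partitioned into; the ai:embeddings:index task derives it
+ # from the row count and sets ivfflat.probes on the session afterwards.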
+ def create_index(lists, probes)
+ index_name = "#{table_name}_search"
+
+ DB.exec(<<~SQL)
+ DROP INDEX IF EXISTS #{index_name};
+ CREATE INDEX IF NOT EXISTS
+ #{index_name}
+ ON
+ #{table_name}
+ USING
+ ivfflat (embeddings #{pg_index_type})
+ WITH
+ (lists = #{lists})
+ WHERE
+ model_version = #{version} AND
+ strategy_version = #{@strategy.version};
+ SQL
+ end
+
+ def vector_from(text)
+ raise NotImplementedError
+ end
+
+ def generate_topic_representation_from(target, persist: true)
+ text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2)
+
+ vector_from(text).tap do |vector|
+ if persist
+ digest = OpenSSL::Digest::SHA1.hexdigest(text)
+ save_to_db(target, vector, digest)
+ end
+ end
+ end
+
+ def topic_id_from_representation(raw_vector)
+ DB.query_single(<<~SQL, query_embedding: raw_vector).first
+ SELECT
+ topic_id
+ FROM
+ #{table_name}
+ ORDER BY
+ embeddings #{pg_function} '[:query_embedding]'
+ LIMIT 1
+ SQL
+ end
+
+ def asymmetric_topics_similarity_search(raw_vector, limit:, offset:, return_distance: false)
+ results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset)
+ SELECT
+ topic_id,
+ embeddings #{pg_function} '[:query_embedding]' AS distance
+ FROM
+ #{table_name}
+ ORDER BY
+ embeddings #{pg_function} '[:query_embedding]'
+ LIMIT :limit
+ OFFSET :offset
+ SQL
+
+ if return_distance
+ results.map { |r| [r.topic_id, r.distance] }
+ else
+ results.map(&:topic_id)
+ end
+ rescue PG::Error => e
+ Rails.logger.error("Error #{e} querying embeddings for model #{name}")
+ raise MissingEmbeddingError
+ end
+
+ def symmetric_topics_similarity_search(topic)
+ DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
+ SELECT
+ topic_id
+ FROM
+ #{table_name}
+ ORDER BY
+ embeddings #{pg_function} (
+ SELECT
+ embeddings
+ FROM
+ #{table_name}
+ WHERE
+ topic_id = :topic_id
+ LIMIT 1
+ )
+ LIMIT 100
+ SQL
+ rescue PG::Error => e
+ Rails.logger.error(
+ "Error #{e} querying embeddings for topic #{topic.id} and model #{name}",
+ )
+ raise MissingEmbeddingError
+ end
+
+ def table_name
+ "ai_topic_embeddings_#{id}_#{@strategy.id}"
+ end
+
+ def name
+ raise NotImplementedError
+ end
+
+ def dimensions
+ raise NotImplementedError
+ end
+
+ def max_sequence_length
+ raise NotImplementedError
+ end
+
+ def id
+ raise NotImplementedError
+ end
+
+ def pg_function
+ raise NotImplementedError
+ end
+
+ def version
+ raise NotImplementedError
+ end
+
+ def tokenizer
+ raise NotImplementedError
+ end
+
+ protected
+
+ def save_to_db(target, vector, digest)
+ DB.exec(
+ <<~SQL,
+ INSERT INTO #{table_name} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
+ VALUES (:topic_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
+ ON CONFLICT (topic_id)
+ DO UPDATE SET
+ model_version = :model_version,
+ strategy_version = :strategy_version,
+ digest = :digest,
+ embeddings = '[:embeddings]',
+ updated_at = CURRENT_TIMESTAMP
+ SQL
+ topic_id: target.id,
+ model_version: version,
+ strategy_version: @strategy.version,
+ digest: digest,
+ embeddings: vector,
+ )
+ end
+ end
+ end
+ end
+end
diff --git a/lib/modules/embeddings/vector_representations/multilingual_e5_large.rb b/lib/modules/embeddings/vector_representations/multilingual_e5_large.rb
new file mode 100644
index 00000000..30ffb4d8
--- /dev/null
+++ b/lib/modules/embeddings/vector_representations/multilingual_e5_large.rb
@@ -0,0 +1,50 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Embeddings
+ module VectorRepresentations
+ class MultilingualE5Large < Base
+ def vector_from(text)
+ DiscourseAi::Inference::DiscourseClassifier.perform!(
+ "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
+ name,
+ "query: #{text}",
+ SiteSetting.ai_embeddings_discourse_service_api_key,
+ )
+ end
+
+ def id
+ 3
+ end
+
+ def version
+ 1
+ end
+
+ def name
+ "multilingual-e5-large"
+ end
+
+ def dimensions
+ 1024
+ end
+
+ def max_sequence_length
+ 512
+ end
+
+ def pg_function
+ "<=>"
+ end
+
+ def pg_index_type
+ "vector_cosine_ops"
+ end
+
+ def tokenizer
+ DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
+ end
+ end
+ end
+ end
+end
diff --git a/lib/modules/embeddings/vector_representations/text_embedding_ada_002.rb b/lib/modules/embeddings/vector_representations/text_embedding_ada_002.rb
new file mode 100644
index 00000000..3d1fc0ff
--- /dev/null
+++ b/lib/modules/embeddings/vector_representations/text_embedding_ada_002.rb
@@ -0,0 +1,46 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Embeddings
+ module VectorRepresentations
+ class TextEmbeddingAda002 < Base
+ def id
+ 2
+ end
+
+ def version
+ 1
+ end
+
+ def name
+ "text-embedding-ada-002"
+ end
+
+ def dimensions
+ 1536
+ end
+
+ def max_sequence_length
+ 8191
+ end
+
+ def pg_function
+ "<=>"
+ end
+
+ def pg_index_type
+ "vector_cosine_ops"
+ end
+
+ def vector_from(text)
+ response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text)
+ response[:data].first[:embedding]
+ end
+
+ def tokenizer
+ DiscourseAi::Tokenizer::OpenAiTokenizer
+ end
+ end
+ end
+ end
+end
diff --git a/lib/shared/inference/hugging_face_text_generation.rb b/lib/shared/inference/hugging_face_text_generation.rb
index f3753a21..9a8cd22e 100644
--- a/lib/shared/inference/hugging_face_text_generation.rb
+++ b/lib/shared/inference/hugging_face_text_generation.rb
@@ -4,7 +4,7 @@ module ::DiscourseAi
module Inference
class HuggingFaceTextGeneration
CompletionFailed = Class.new(StandardError)
- TIMEOUT = 60
+ TIMEOUT = 120
def self.perform!(
prompt,
diff --git a/lib/tasks/modules/embeddings/database.rake b/lib/tasks/modules/embeddings/database.rake
index 96b9f15d..4ee260cb 100644
--- a/lib/tasks/modules/embeddings/database.rake
+++ b/lib/tasks/modules/embeddings/database.rake
@@ -4,18 +4,22 @@ desc "Backfill embeddings for all topics"
task "ai:embeddings:backfill", [:start_topic] => [:environment] do |_, args|
public_categories = Category.where(read_restricted: false).pluck(:id)
- manager = DiscourseAi::Embeddings::Manager.new(Topic.first)
+
+ strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
+ vector_rep =
+ DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
+ table_name = vector_rep.table_name
+
Topic
- .joins(
- "LEFT JOIN #{manager.topic_embeddings_table} ON #{manager.topic_embeddings_table}.topic_id = topics.id",
- )
- .where("#{manager.topic_embeddings_table}.topic_id IS NULL")
+ .joins("LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id")
+ .where("#{table_name}.topic_id IS NULL")
.where("topics.id >= ?", args[:start_topic].to_i || 0)
.where("category_id IN (?)", public_categories)
.where(deleted_at: nil)
.order("topics.id ASC")
.find_each do |t|
print "."
- DiscourseAi::Embeddings::Manager.new(t).generate!
+ vector_rep.generate_topic_representation_from(t)
end
end
@@ -28,25 +32,11 @@ task "ai:embeddings:index", [:work_mem] => [:environment] do |_, args|
lists = count < 1_000_000 ? count / 1000 : Math.sqrt(count).to_i
probes = count < 1_000_000 ? lists / 10 : Math.sqrt(lists).to_i
- manager = DiscourseAi::Embeddings::Manager.new(Topic.first)
- table = manager.topic_embeddings_table
- index = "#{table}_search"
+ strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
+ vector_rep =
+ DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
DB.exec("SET work_mem TO '#{args[:work_mem] || "1GB"}';")
- DB.exec(<<~SQL)
- DROP INDEX IF EXISTS #{index};
- CREATE INDEX IF NOT EXISTS
- #{index}
- ON
- #{table}
- USING
- ivfflat (embeddings #{manager.model.pg_index_type})
- WITH
- (lists = #{lists})
- WHERE
- model_version = #{manager.model.version} AND
- strategy_version = #{manager.strategy.version};
- SQL
+ vector_rep.create_index(lists, probes)
DB.exec("RESET work_mem;")
DB.exec("SET ivfflat.probes = #{probes};")
end
diff --git a/plugin.rb b/plugin.rb
index cf01059c..0ad43c07 100644
--- a/plugin.rb
+++ b/plugin.rb
@@ -17,6 +17,7 @@ register_asset "stylesheets/modules/ai-helper/common/ai-helper.scss"
register_asset "stylesheets/modules/ai-bot/common/bot-replies.scss"
register_asset "stylesheets/modules/embeddings/common/semantic-related-topics.scss"
+register_asset "stylesheets/modules/embeddings/common/semantic-search.scss"
module ::DiscourseAi
PLUGIN_NAME = "discourse-ai"
diff --git a/spec/integration/embeddings/manager_spec.rb b/spec/integration/embeddings/manager_spec.rb
deleted file mode 100644
index d515c5ed..00000000
--- a/spec/integration/embeddings/manager_spec.rb
+++ /dev/null
@@ -1,44 +0,0 @@
-# frozen_string_literal: true
-
-require_relative "../../support/embeddings_generation_stubs"
-
-RSpec.describe DiscourseAi::Embeddings::Manager do
- let(:user) { Fabricate(:user) }
- let(:expected_embedding) do
- JSON.parse(
- File.read("#{Rails.root}/plugins/discourse-ai/spec/fixtures/embeddings/embedding.txt"),
- )
- end
- let(:discourse_model) { "all-mpnet-base-v2" }
-
- before do
- SiteSetting.discourse_ai_enabled = true
- SiteSetting.ai_embeddings_enabled = true
- SiteSetting.ai_embeddings_model = "all-mpnet-base-v2"
- SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
- Jobs.run_immediately!
- end
-
- it "generates embeddings for new topics automatically" do
- pc =
- PostCreator.new(
- user,
- raw: "this is the new content for my topic",
- title: "this is my new topic title",
- )
- input =
- "This is my new topic title\n\nUncategorized\n\n\n\nthis is the new content for my topic\n\n"
- EmbeddingsGenerationStubs.discourse_service(discourse_model, input, expected_embedding)
- post = pc.create
- manager = DiscourseAi::Embeddings::Manager.new(post.topic)
-
- embeddings =
- DB.query_single(
- "SELECT embeddings FROM #{manager.topic_embeddings_table} WHERE topic_id = #{post.topic.id}",
- ).first
-
- expect(embeddings.split(",")[1].to_f).to be_within(0.0001).of(expected_embedding[1])
- expect(embeddings.split(",")[13].to_f).to be_within(0.0001).of(expected_embedding[13])
- expect(embeddings.split(",")[135].to_f).to be_within(0.0001).of(expected_embedding[135])
- end
-end
diff --git a/spec/lib/modules/embeddings/entry_point_spec.rb b/spec/lib/modules/embeddings/entry_point_spec.rb
index 72545005..3d4f14e8 100644
--- a/spec/lib/modules/embeddings/entry_point_spec.rb
+++ b/spec/lib/modules/embeddings/entry_point_spec.rb
@@ -28,97 +28,4 @@ describe DiscourseAi::Embeddings::EntryPoint do
end
end
end
-
- describe "SemanticTopicQuery extension" do
- describe "#list_semantic_related_topics" do
- subject(:topic_query) { DiscourseAi::Embeddings::SemanticTopicQuery.new(user) }
-
- fab!(:target) { Fabricate(:topic) }
-
- def stub_semantic_search_with(results)
- DiscourseAi::Embeddings::SemanticRelated.expects(:related_topic_ids_for).returns(results)
- end
-
- context "when the semantic search returns an unlisted topic" do
- fab!(:unlisted_topic) { Fabricate(:topic, visible: false) }
-
- before { stub_semantic_search_with([unlisted_topic.id]) }
-
- it "filters it out" do
- expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
- end
- end
-
- context "when the semantic search returns a private topic" do
- fab!(:private_topic) { Fabricate(:private_message_topic) }
-
- before { stub_semantic_search_with([private_topic.id]) }
-
- it "filters it out" do
- expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
- end
- end
-
- context "when the semantic search returns a topic from a restricted category" do
- fab!(:group) { Fabricate(:group) }
- fab!(:category) { Fabricate(:private_category, group: group) }
- fab!(:secured_category_topic) { Fabricate(:topic, category: category) }
-
- before { stub_semantic_search_with([secured_category_topic.id]) }
-
- it "filters it out" do
- expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
- end
-
- it "doesn't filter it out if the user has access to the category" do
- group.add(user)
-
- expect(topic_query.list_semantic_related_topics(target).topics).to contain_exactly(
- secured_category_topic,
- )
- end
- end
-
- context "when the semantic search returns a closed topic and we explicitly exclude them" do
- fab!(:closed_topic) { Fabricate(:topic, closed: true) }
-
- before do
- SiteSetting.ai_embeddings_semantic_related_include_closed_topics = false
- stub_semantic_search_with([closed_topic.id])
- end
-
- it "filters it out" do
- expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
- end
- end
-
- context "when the semantic search returns public topics" do
- fab!(:normal_topic_1) { Fabricate(:topic) }
- fab!(:normal_topic_2) { Fabricate(:topic) }
- fab!(:normal_topic_3) { Fabricate(:topic) }
- fab!(:closed_topic) { Fabricate(:topic, closed: true) }
-
- before do
- stub_semantic_search_with(
- [closed_topic.id, normal_topic_1.id, normal_topic_2.id, normal_topic_3.id],
- )
- end
-
- it "filters it out" do
- expect(topic_query.list_semantic_related_topics(target).topics).to eq(
- [closed_topic, normal_topic_1, normal_topic_2, normal_topic_3],
- )
- end
-
- it "returns the plugin limit for the number of results" do
- SiteSetting.ai_embeddings_semantic_related_topics = 2
-
- expect(topic_query.list_semantic_related_topics(target).topics).to contain_exactly(
- closed_topic,
- normal_topic_1,
- )
- end
- end
- end
- end
end
diff --git a/spec/lib/modules/embeddings/models/all_mpnet_base_v2_spec.rb b/spec/lib/modules/embeddings/models/all_mpnet_base_v2_spec.rb
deleted file mode 100644
index 6766d72a..00000000
--- a/spec/lib/modules/embeddings/models/all_mpnet_base_v2_spec.rb
+++ /dev/null
@@ -1,24 +0,0 @@
-# frozen_string_literal: true
-
-require_relative "../../../../support/embeddings_generation_stubs"
-
-RSpec.describe DiscourseAi::Embeddings::Models::AllMpnetBaseV2 do
- describe "#generate_embeddings" do
- let(:input) { "test" }
- let(:expected_embedding) { [0.0038493, 0.482001] }
-
- context "when the model uses the discourse service to create embeddings" do
- before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
-
- let(:discourse_model) { "all-mpnet-base-v2" }
-
- it "returns an embedding for a given string" do
- EmbeddingsGenerationStubs.discourse_service(discourse_model, input, expected_embedding)
-
- embedding = described_class.generate_embeddings(input)
-
- expect(embedding).to contain_exactly(*expected_embedding)
- end
- end
- end
-end
diff --git a/spec/lib/modules/embeddings/models/text_embedding_ada_002_spec.rb b/spec/lib/modules/embeddings/models/text_embedding_ada_002_spec.rb
deleted file mode 100644
index 60d41d02..00000000
--- a/spec/lib/modules/embeddings/models/text_embedding_ada_002_spec.rb
+++ /dev/null
@@ -1,22 +0,0 @@
-# frozen_string_literal: true
-
-require_relative "../../../../support/embeddings_generation_stubs"
-
-RSpec.describe DiscourseAi::Embeddings::Models::TextEmbeddingAda002 do
- describe "#generate_embeddings" do
- let(:input) { "test" }
- let(:expected_embedding) { [0.0038493, 0.482001] }
-
- context "when the model uses OpenAI to create embeddings" do
- let(:openai_model) { "text-embedding-ada-002" }
-
- it "returns an embedding for a given string" do
- EmbeddingsGenerationStubs.openai_service(openai_model, input, expected_embedding)
-
- embedding = described_class.generate_embeddings(input)
-
- expect(embedding).to contain_exactly(*expected_embedding)
- end
- end
- end
-end
diff --git a/spec/lib/modules/embeddings/semantic_related_spec.rb b/spec/lib/modules/embeddings/semantic_related_spec.rb
index 22c85b51..1911218f 100644
--- a/spec/lib/modules/embeddings/semantic_related_spec.rb
+++ b/spec/lib/modules/embeddings/semantic_related_spec.rb
@@ -3,6 +3,8 @@
require "rails_helper"
describe DiscourseAi::Embeddings::SemanticRelated do
+ subject(:semantic_related) { described_class.new }
+
fab!(:target) { Fabricate(:topic) }
fab!(:normal_topic_1) { Fabricate(:topic) }
fab!(:normal_topic_2) { Fabricate(:topic) }
@@ -25,13 +27,13 @@ describe DiscourseAi::Embeddings::SemanticRelated do
results = nil
expect_enqueued_with(job: :generate_embeddings, args: { topic_id: topic.id }) do
- results = described_class.related_topic_ids_for(topic)
+ results = semantic_related.related_topic_ids_for(topic)
end
expect(results).to eq([])
expect_not_enqueued_with(job: :generate_embeddings, args: { topic_id: topic.id }) do
- results = described_class.related_topic_ids_for(topic)
+ results = semantic_related.related_topic_ids_for(topic)
end
expect(results).to eq([])
diff --git a/spec/lib/modules/embeddings/semantic_search_spec.rb b/spec/lib/modules/embeddings/semantic_search_spec.rb
index 48ac4584..49826dd0 100644
--- a/spec/lib/modules/embeddings/semantic_search_spec.rb
+++ b/spec/lib/modules/embeddings/semantic_search_spec.rb
@@ -1,5 +1,8 @@
# frozen_string_literal: true
+require_relative "../../../support/embeddings_generation_stubs"
+require_relative "../../../support/openai_completions_inference_stubs"
+
RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
fab!(:post) { Fabricate(:post) }
fab!(:user) { Fabricate(:user) }
@@ -8,10 +11,28 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
let(:subject) { described_class.new(Guardian.new(user)) }
describe "#search_for_topics" do
+ let(:hypothetical_post) { "This is a hypothetical post generated from the keyword test_query" }
+
+ before do
+ SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
+
+ prompt = DiscourseAi::Embeddings::HydeGenerators::OpenAi.new.prompt(query)
+ OpenAiCompletionsInferenceStubs.stub_response(prompt, hypothetical_post)
+
+ hyde_embedding = [0.049382, 0.9999]
+ EmbeddingsGenerationStubs.discourse_service(
+ SiteSetting.ai_embeddings_model,
+ hypothetical_post,
+ hyde_embedding,
+ )
+ end
+
+ after { described_class.clear_cache_for(query) }
+
def stub_candidate_ids(candidate_ids)
- DiscourseAi::Embeddings::SemanticSearch
+ DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2
.any_instance
- .expects(:asymmetric_semantic_search)
+ .expects(:asymmetric_topics_similarity_search)
.returns(candidate_ids)
end
diff --git a/spec/lib/modules/embeddings/semantic_topic_query_spec.rb b/spec/lib/modules/embeddings/semantic_topic_query_spec.rb
index cfc4bc20..911bb4d8 100644
--- a/spec/lib/modules/embeddings/semantic_topic_query_spec.rb
+++ b/spec/lib/modules/embeddings/semantic_topic_query_spec.rb
@@ -12,9 +12,14 @@ describe DiscourseAi::Embeddings::EntryPoint do
fab!(:target) { Fabricate(:topic) }
def stub_semantic_search_with(results)
- DiscourseAi::Embeddings::SemanticRelated.expects(:related_topic_ids_for).returns(results)
+ DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2
+ .any_instance
+ .expects(:symmetric_topics_similarity_search)
+ .returns(results.concat([target.id]))
end
+ after { DiscourseAi::Embeddings::SemanticRelated.clear_cache_for(target) }
+
context "when the semantic search returns an unlisted topic" do
fab!(:unlisted_topic) { Fabricate(:topic, visible: false) }
diff --git a/spec/lib/modules/embeddings/strategies/truncation_spec.rb b/spec/lib/modules/embeddings/strategies/truncation_spec.rb
index c25ade73..850db322 100644
--- a/spec/lib/modules/embeddings/strategies/truncation_spec.rb
+++ b/spec/lib/modules/embeddings/strategies/truncation_spec.rb
@@ -1,8 +1,10 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
- describe "#process!" do
- context "when the model uses OpenAI to create embeddings" do
+ subject(:truncation) { described_class.new }
+
+ describe "#prepare_text_from" do
+ context "when using vector from OpenAI" do
before { SiteSetting.max_post_length = 100_000 }
fab!(:topic) { Fabricate(:topic) }
@@ -18,13 +20,15 @@ RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
end
fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }
- let(:model) { DiscourseAi::Embeddings::Models::Base.descendants.sample(1).first }
- let(:truncation) { described_class.new(topic, model) }
+ let(:model) do
+ DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new(truncation)
+ end
it "truncates a topic" do
- truncation.process!
+ prepared_text =
+ truncation.prepare_text_from(topic, model.tokenizer, model.max_sequence_length)
- expect(model.tokenizer.size(truncation.processed_target)).to be <= model.max_sequence_length
+ expect(model.tokenizer.size(prepared_text)).to be <= model.max_sequence_length
end
end
end
diff --git a/spec/lib/modules/embeddings/vector_representations/all_mpnet_base_v2_spec.rb b/spec/lib/modules/embeddings/vector_representations/all_mpnet_base_v2_spec.rb
new file mode 100644
index 00000000..16f20abf
--- /dev/null
+++ b/spec/lib/modules/embeddings/vector_representations/all_mpnet_base_v2_spec.rb
@@ -0,0 +1,18 @@
+# frozen_string_literal: true
+
+require_relative "../../../../support/embeddings_generation_stubs"
+require_relative "vector_rep_shared_examples"
+
+RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2 do
+ subject(:vector_rep) { described_class.new(truncation) }
+
+ let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
+
+ before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
+
+ def stub_vector_mapping(text, expected_embedding)
+ EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding)
+ end
+
+ it_behaves_like "generates and store embedding using with vector representation"
+end
diff --git a/spec/lib/modules/embeddings/vector_representations/multilingual_e5_large_spec.rb b/spec/lib/modules/embeddings/vector_representations/multilingual_e5_large_spec.rb
new file mode 100644
index 00000000..1c1b2a5f
--- /dev/null
+++ b/spec/lib/modules/embeddings/vector_representations/multilingual_e5_large_spec.rb
@@ -0,0 +1,22 @@
+# frozen_string_literal: true
+
+require_relative "../../../../support/embeddings_generation_stubs"
+require_relative "vector_rep_shared_examples"
+
+RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large do
+ subject(:vector_rep) { described_class.new(truncation) }
+
+ let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
+
+ before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
+
+ def stub_vector_mapping(text, expected_embedding)
+ EmbeddingsGenerationStubs.discourse_service(
+ vector_rep.name,
+ "query: #{text}",
+ expected_embedding,
+ )
+ end
+
+ it_behaves_like "generates and store embedding using with vector representation"
+end
diff --git a/spec/lib/modules/embeddings/vector_representations/text_embedding_ada_002_spec.rb b/spec/lib/modules/embeddings/vector_representations/text_embedding_ada_002_spec.rb
new file mode 100644
index 00000000..59b48dd3
--- /dev/null
+++ b/spec/lib/modules/embeddings/vector_representations/text_embedding_ada_002_spec.rb
@@ -0,0 +1,16 @@
+# frozen_string_literal: true
+
+require_relative "../../../../support/embeddings_generation_stubs"
+require_relative "vector_rep_shared_examples"
+
+RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002 do
+ subject(:vector_rep) { described_class.new(truncation) }
+
+ let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
+
+ def stub_vector_mapping(text, expected_embedding)
+ EmbeddingsGenerationStubs.openai_service(vector_rep.name, text, expected_embedding)
+ end
+
+ it_behaves_like "generates and store embedding using with vector representation"
+end
diff --git a/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb b/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb
new file mode 100644
index 00000000..7d5cc213
--- /dev/null
+++ b/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb
@@ -0,0 +1,54 @@
+# frozen_string_literal: true
+
+RSpec.shared_examples "generates and store embedding using with vector representation" do
+ before { @expected_embedding = [0.0038493] * vector_rep.dimensions }
+
+ describe "#vector_from" do
+ it "creates a vector from a given string" do
+ text = "This is a piece of text"
+ stub_vector_mapping(text, @expected_embedding)
+
+ expect(vector_rep.vector_from(text)).to eq(@expected_embedding)
+ end
+ end
+
+ describe "#generate_topic_representation_from" do
+ fab!(:topic) { Fabricate(:topic) }
+ fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
+
+ it "creates a vector from a topic and stores it in the database" do
+ text =
+ truncation.prepare_text_from(
+ topic,
+ vector_rep.tokenizer,
+ vector_rep.max_sequence_length - 2,
+ )
+ stub_vector_mapping(text, @expected_embedding)
+
+ vector_rep.generate_topic_representation_from(topic)
+
+ expect(vector_rep.topic_id_from_representation(@expected_embedding)).to eq(topic.id)
+ end
+ end
+
+ describe "#asymmetric_topics_similarity_search" do
+ fab!(:topic) { Fabricate(:topic) }
+ fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
+
+ it "finds IDs of similar topics with a given embedding" do
+ similar_vector = [0.0038494] * vector_rep.dimensions
+ text =
+ truncation.prepare_text_from(
+ topic,
+ vector_rep.tokenizer,
+ vector_rep.max_sequence_length - 2,
+ )
+ stub_vector_mapping(text, @expected_embedding)
+ vector_rep.generate_topic_representation_from(topic)
+
+ expect(
+ vector_rep.asymmetric_topics_similarity_search(similar_vector, limit: 1, offset: 0),
+ ).to contain_exactly(topic.id)
+ end
+ end
+end
diff --git a/spec/requests/topic_spec.rb b/spec/requests/topic_spec.rb
index 25a00895..a4f2141d 100644
--- a/spec/requests/topic_spec.rb
+++ b/spec/requests/topic_spec.rb
@@ -16,9 +16,10 @@ describe ::TopicsController do
context "when a user is logged on" do
it "includes related topics in payload when configured" do
- DiscourseAi::Embeddings::SemanticRelated.stubs(:related_topic_ids_for).returns(
- [topic1.id, topic2.id, topic3.id],
- )
+ DiscourseAi::Embeddings::SemanticRelated
+ .any_instance
+ .stubs(:related_topic_ids_for)
+ .returns([topic1.id, topic2.id, topic3.id])
get("#{topic.relative_url}.json")
expect(response.status).to eq(200)