From 4e05763a99deb527bf9f280ce10df36e36049772 Mon Sep 17 00:00:00 2001
From: Roman Rizzi
Date: Fri, 31 Mar 2023 15:29:56 -0300
Subject: [PATCH] FEATURE: Semantic asymmetric full-page search (#34)

Depends on discourse/discourse#20915

Hooks into the full-page-search component using an experimental API and
performs an asymmetric similarity search using our embeddings database.
---
 .../embeddings/embeddings_controller.rb       |  35 ++++++
 .../initializers/semantic-full-page-search.js |  64 +++++++++++
 config/locales/client.en.yml                  |   3 +
 config/locales/server.en.yml                  |   1 +
 config/routes.rb                              |   5 +
 config/settings.yml                           |  11 +-
 lib/modules/embeddings/entry_point.rb         |   3 +-
 lib/modules/embeddings/model.rb               |  83 +++++++++++++++
 lib/modules/embeddings/models.rb              |  62 -----------
 lib/modules/embeddings/semantic_related.rb    |  41 ++-----
 lib/modules/embeddings/semantic_search.rb     |  31 ++++++
 lib/modules/embeddings/topic.rb               | 100 +++++++++++-------
 lib/tasks/modules/embeddings/database.rake    |  13 +--
 spec/lib/modules/embeddings/model_spec.rb     |  36 +++++++
 .../embeddings/semantic_related_spec.rb       |   7 +-
 .../embeddings/semantic_search_spec.rb        |  86 +++++++++++++++
 spec/requests/topic_spec.rb                   |   7 +-
 spec/support/embeddings_generation_stubs.rb   |  22 ++++
 18 files changed, 462 insertions(+), 148 deletions(-)
 create mode 100644 app/controllers/discourse_ai/embeddings/embeddings_controller.rb
 create mode 100644 assets/javascripts/initializers/semantic-full-page-search.js
 create mode 100644 lib/modules/embeddings/model.rb
 delete mode 100644 lib/modules/embeddings/models.rb
 create mode 100644 lib/modules/embeddings/semantic_search.rb
 create mode 100644 spec/lib/modules/embeddings/model_spec.rb
 create mode 100644 spec/lib/modules/embeddings/semantic_search_spec.rb
 create mode 100644 spec/support/embeddings_generation_stubs.rb

diff --git a/app/controllers/discourse_ai/embeddings/embeddings_controller.rb b/app/controllers/discourse_ai/embeddings/embeddings_controller.rb
new file mode 100644
index 00000000..8d6e6ee2
--- /dev/null
+++ b/app/controllers/discourse_ai/embeddings/embeddings_controller.rb
@@ -0,0 +1,35 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Embeddings
+    class EmbeddingsController < ::ApplicationController
+      requires_plugin ::DiscourseAi::PLUGIN_NAME
+
+      SEMANTIC_SEARCH_TYPE = "semantic_search"
+
+      def search
+        query = params[:q]
+        page = (params[:page] || 1).to_i
+
+        grouped_results =
+          Search::GroupedSearchResults.new(
+            type_filter: SEMANTIC_SEARCH_TYPE,
+            term: query,
+            search_context: guardian,
+          )
+
+        model =
+          DiscourseAi::Embeddings::Model.instantiate(
+            SiteSetting.ai_embeddings_semantic_search_model,
+          )
+
+        DiscourseAi::Embeddings::SemanticSearch
+          .new(guardian, model)
+          .search_for_topics(query, page)
+          .each { |topic_post| grouped_results.add(topic_post) }
+
+        render_serialized(grouped_results, GroupedSearchResultSerializer, result: grouped_results)
+      end
+    end
+  end
+end
diff --git a/assets/javascripts/initializers/semantic-full-page-search.js b/assets/javascripts/initializers/semantic-full-page-search.js
new file mode 100644
index 00000000..dc21b537
--- /dev/null
+++ b/assets/javascripts/initializers/semantic-full-page-search.js
@@ -0,0 +1,64 @@
+import { withPluginApi } from "discourse/lib/plugin-api";
+import { translateResults, updateRecentSearches } from "discourse/lib/search";
+import { setTransient } from "discourse/lib/page-tracker";
+import { ajax } from "discourse/lib/ajax";
+
+const SEMANTIC_SEARCH = "semantic_search";
+
+function initializeSemanticSearch(api) {
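+  // `addFullPageSearchType` is the experimental plugin API this commit
+  // depends on (discourse/discourse#20915). It registers an extra
+  // full-page search type: the first argument is the translation key for
+  // the tab label ("Topics (Semantic)" in client.en.yml below), the second
+  // is the type identifier, and the callback runs whenever a search is
+  // submitted with this type selected.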
api.addFullPageSearchType( + "discourse_ai.embeddings.semantic_search", + SEMANTIC_SEARCH, + (searchController, args, searchKey) => { + if (searchController.currentUser) { + updateRecentSearches(searchController.currentUser, args.searchTerm); + } + + ajax("/discourse-ai/embeddings/semantic-search", { data: args }) + .then(async (results) => { + const model = (await translateResults(results)) || {}; + + if (results.grouped_search_result) { + searchController.set("q", results.grouped_search_result.term); + } + + if (args.page > 1) { + if (model) { + searchController.model.posts.pushObjects(model.posts); + searchController.model.topics.pushObjects(model.topics); + searchController.model.set( + "grouped_search_result", + results.grouped_search_result + ); + } + } else { + setTransient("lastSearch", { searchKey, model }, 5); + model.grouped_search_result = results.grouped_search_result; + searchController.set("model", model); + } + searchController.set("error", null); + }) + .catch((e) => { + searchController.set("error", e.jqXHR.responseJSON?.message); + }) + .finally(() => { + searchController.setProperties({ + searching: false, + loading: false, + }); + }); + } + ); +} + +export default { + name: "discourse_ai-full-page-semantic-search", + + initialize(container) { + const settings = container.lookup("site-settings:main"); + + if (settings.ai_embeddings_enabled) { + withPluginApi("1.6.0", initializeSemanticSearch); + } + }, +}; diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml index 8bbd29cd..af639861 100644 --- a/config/locales/client.en.yml +++ b/config/locales/client.en.yml @@ -10,6 +10,9 @@ en: reviewables: model_used: "Model used:" accuracy: "Accuracy:" + + embeddings: + semantic_search: "Topics (Semantic)" review: types: reviewable_ai_post: diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml index f246836d..6a6b31e1 100644 --- a/config/locales/server.en.yml +++ b/config/locales/server.en.yml @@ -47,6 +47,7 @@ en: ai_embeddings_semantic_related_topics_enabled: "Use Semantic Search for related topics." ai_embeddings_semantic_related_topics: "Maximum number of topics to show in related topic section." ai_embeddings_pg_connection_string: "PostgreSQL connection string for the embeddings module. Needs pgvector extension enabled and a series of tables created. See docs for more info." + ai_embeddings_semantic_search_model: "Model to use for semantic search." 
reviewables: reasons: diff --git a/config/routes.rb b/config/routes.rb index 0c9fcf47..89660afc 100644 --- a/config/routes.rb +++ b/config/routes.rb @@ -6,6 +6,11 @@ DiscourseAi::Engine.routes.draw do get "prompts" => "assistant#prompts" post "suggest" => "assistant#suggest" end + + # Embedding routes + scope module: :embeddings, path: "/embeddings", defaults: { format: :json } do + get "semantic-search" => "embeddings#search" + end end Discourse::Application.routes.append { mount ::DiscourseAi::Engine, at: "discourse-ai" } diff --git a/config/settings.yml b/config/settings.yml index cb89a5e2..43907c22 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -108,7 +108,9 @@ plugins: - gpt-3.5-turbo - gpt-4 - ai_embeddings_enabled: false + ai_embeddings_enabled: + default: false + client: true ai_embeddings_discourse_service_api_endpoint: "" ai_embeddings_discourse_service_api_key: "" ai_embeddings_models: @@ -133,6 +135,13 @@ plugins: - all-distilroberta-v1 - multi-qa-mpnet-base-dot-v1 - paraphrase-multilingual-mpnet-base-v2 + ai_embeddings_semantic_search_model: + type: enum + default: msmarco-distilbert-base-v4 + choices: + - msmarco-distilbert-base-v4 + - msmarco-distilbert-base-tas-b + - text-embedding-ada-002 ai_embeddings_generate_for_pms: false ai_embeddings_semantic_related_topics_enabled: false ai_embeddings_semantic_related_topics: 5 diff --git a/lib/modules/embeddings/entry_point.rb b/lib/modules/embeddings/entry_point.rb index 5033bf38..7466afcc 100644 --- a/lib/modules/embeddings/entry_point.rb +++ b/lib/modules/embeddings/entry_point.rb @@ -4,10 +4,11 @@ module DiscourseAi module Embeddings class EntryPoint def load_files - require_relative "models" + require_relative "model" require_relative "topic" require_relative "jobs/regular/generate_embeddings" require_relative "semantic_related" + require_relative "semantic_search" end def inject_into(plugin) diff --git a/lib/modules/embeddings/model.rb b/lib/modules/embeddings/model.rb new file mode 100644 index 00000000..d2684b5c --- /dev/null +++ b/lib/modules/embeddings/model.rb @@ -0,0 +1,83 @@ +# frozen_string_literal: true + +module DiscourseAi + module Embeddings + class Model + AVAILABLE_MODELS_TEMPLATES = { + "all-mpnet-base-v2" => [768, 384, %i[dot cosine euclidean], %i[symmetric], "discourse"], + "all-distilroberta-v1" => [768, 512, %i[dot cosine euclidean], %i[symmetric], "discourse"], + "multi-qa-mpnet-base-dot-v1" => [768, 512, %i[dot], %i[symmetric], "discourse"], + "paraphrase-multilingual-mpnet-base-v2" => [ + 768, + 128, + %i[cosine], + %i[symmetric], + "discourse", + ], + "msmarco-distilbert-base-v4" => [768, 512, %i[cosine], %i[asymmetric], "discourse"], + "msmarco-distilbert-base-tas-b" => [768, 512, %i[dot], %i[asymmetric], "discourse"], + "text-embedding-ada-002" => [1536, 2048, %i[cosine], %i[symmetric asymmetric], "openai"], + } + + SEARCH_FUNCTION_TO_PG_INDEX = { + dot: "vector_ip_ops", + cosine: "vector_cosine_ops", + euclidean: "vector_l2_ops", + } + + SEARCH_FUNCTION_TO_PG_FUNCTION = { dot: "<#>", cosine: "<=>", euclidean: "<->" } + + class << self + def instantiate(model_name) + new(model_name, *AVAILABLE_MODELS_TEMPLATES[model_name]) + end + + def enabled_models + SiteSetting + .ai_embeddings_models + .split("|") + .map { |model_name| instantiate(model_name.strip) } + end + end + + def initialize(name, dimensions, max_sequence_lenght, functions, type, provider) + @name = name + @dimensions = dimensions + @max_sequence_lenght = max_sequence_lenght + @functions = functions + @type = type + @provider 
= provider + end + + def generate_embedding(input) + send("#{provider}_embeddings", input) + end + + def pg_function + SEARCH_FUNCTION_TO_PG_FUNCTION[functions.first] + end + + def pg_index + SEARCH_FUNCTION_TO_PG_INDEX[functions.first] + end + + attr_reader :name, :dimensions, :max_sequence_lenght, :functions, :type, :provider + + private + + def discourse_embeddings(input) + DiscourseAi::Inference::DiscourseClassifier.perform!( + "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", + name.to_s, + input, + SiteSetting.ai_embeddings_discourse_service_api_key, + ) + end + + def openai_embeddings(input) + response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(input) + response[:data].first[:embedding] + end + end + end +end diff --git a/lib/modules/embeddings/models.rb b/lib/modules/embeddings/models.rb deleted file mode 100644 index 1d1dc391..00000000 --- a/lib/modules/embeddings/models.rb +++ /dev/null @@ -1,62 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Embeddings - class Models - MODEL = Data.define(:name, :dimensions, :max_sequence_lenght, :functions, :type, :provider) - - SEARCH_FUNCTION_TO_PG_INDEX = { - dot: "vector_ip_ops", - cosine: "vector_cosine_ops", - euclidean: "vector_l2_ops", - } - - SEARCH_FUNCTION_TO_PG_FUNCTION = { dot: "<#>", cosine: "<=>", euclidean: "<->" } - - def self.enabled_models - setting = SiteSetting.ai_embeddings_models.split("|").map(&:strip) - list.filter { |model| setting.include?(model.name) } - end - - def self.list - @@list ||= [ - MODEL.new( - "all-mpnet-base-v2", - 768, - 384, - %i[dot cosine euclidean], - [:symmetric], - "discourse", - ), - MODEL.new( - "all-distilroberta-v1", - 768, - 512, - %i[dot cosine euclidean], - [:symmetric], - "discourse", - ), - MODEL.new("multi-qa-mpnet-base-dot-v1", 768, 512, [:dot], [:symmetric], "discourse"), - MODEL.new( - "paraphrase-multilingual-mpnet-base-v2", - 768, - 128, - [:cosine], - [:symmetric], - "discourse", - ), - MODEL.new("msmarco-distilbert-base-v4", 768, 512, [:cosine], [:asymmetric], "discourse"), - MODEL.new("msmarco-distilbert-base-tas-b", 768, 512, [:dot], [:asymmetric], "discourse"), - MODEL.new( - "text-embedding-ada-002", - 1536, - 2048, - [:cosine], - %i[:symmetric :asymmetric], - "openai", - ), - ] - end - end - end -end diff --git a/lib/modules/embeddings/semantic_related.rb b/lib/modules/embeddings/semantic_related.rb index db60ce91..93b50bec 100644 --- a/lib/modules/embeddings/semantic_related.rb +++ b/lib/modules/embeddings/semantic_related.rb @@ -16,12 +16,17 @@ module DiscourseAi 1.day end + model = + DiscourseAi::Embeddings::Model.instantiate( + SiteSetting.ai_embeddings_semantic_related_model, + ) + begin candidate_ids = Discourse .cache .fetch("semantic-suggested-topic-#{topic.id}", expires_in: cache_for) do - search_suggestions(topic) + DiscourseAi::Embeddings::Topic.new.symmetric_semantic_search(model, topic) end rescue StandardError => e Rails.logger.error("SemanticRelated: #{e}") @@ -39,40 +44,6 @@ module DiscourseAi .order("array_position(ARRAY#{candidate_ids}, id)") .limit(SiteSetting.ai_embeddings_semantic_related_topics) end - - def self.search_suggestions(topic) - model_name = SiteSetting.ai_embeddings_semantic_related_model - model = DiscourseAi::Embeddings::Models.list.find { |m| m.name == model_name } - function = - DiscourseAi::Embeddings::Models::SEARCH_FUNCTION_TO_PG_FUNCTION[model.functions.first] - - candidate_ids = - DiscourseAi::Database::Connection.db.query(<<~SQL, topic_id: topic.id).map(&:topic_id) - SELECT 
- topic_id - FROM - topic_embeddings_#{model_name.underscore} - ORDER BY - embedding #{function} ( - SELECT - embedding - FROM - topic_embeddings_#{model_name.underscore} - WHERE - topic_id = :topic_id - LIMIT 1 - ) - LIMIT 11 - SQL - - # Happens when the topic doesn't have any embeddings - # I'd rather not use Exceptions to control the flow, so this should be refactored soon - if candidate_ids.empty? || !candidate_ids.include?(topic.id) - raise StandardError, "No embeddings found for topic #{topic.id}" - end - - candidate_ids - end end end end diff --git a/lib/modules/embeddings/semantic_search.rb b/lib/modules/embeddings/semantic_search.rb new file mode 100644 index 00000000..11f42d2c --- /dev/null +++ b/lib/modules/embeddings/semantic_search.rb @@ -0,0 +1,31 @@ +# frozen_string_literal: true + +module DiscourseAi + module Embeddings + class SemanticSearch + def initialize(guardian, model) + @guardian = guardian + @model = model + end + + def search_for_topics(query, page = 1) + limit = Search.per_filter + 1 + offset = (page - 1) * Search.per_filter + + candidate_ids = + DiscourseAi::Embeddings::Topic.new.asymmetric_semantic_search(model, query, limit, offset) + + ::Post + .where(post_type: ::Topic.visible_post_types(guardian.user)) + .public_posts + .where("topics.visible") + .where(topic_id: candidate_ids, post_number: 1) + .order("array_position(ARRAY#{candidate_ids}, topic_id)") + end + + private + + attr_reader :model, :guardian + end + end +end diff --git a/lib/modules/embeddings/topic.rb b/lib/modules/embeddings/topic.rb index 73a3cc2b..b9072dc8 100644 --- a/lib/modules/embeddings/topic.rb +++ b/lib/modules/embeddings/topic.rb @@ -3,54 +3,80 @@ module DiscourseAi module Embeddings class Topic - def initialize(topic) - @topic = topic - @embeddings = {} - end - - def perform! + def generate_and_store_embeddings_for(topic) return unless SiteSetting.ai_embeddings_enabled - return if DiscourseAi::Embeddings::Models.enabled_models.empty? + return if topic.blank? || topic.first_post.blank? - calculate_embeddings! - persist_embeddings! unless @embeddings.empty? - end + enabled_models = DiscourseAi::Embeddings::Model.enabled_models + return if enabled_models.empty? - def calculate_embeddings! - return if @topic.blank? || @topic.first_post.blank? - - DiscourseAi::Embeddings::Models.enabled_models.each do |model| - @embeddings[model.name] = send("#{model.provider}_embeddings", model.name) + enabled_models.each do |model| + embedding = model.generate_embedding(topic.first_post.raw) + persist_embedding(topic, model, embedding) if embedding end end - def persist_embeddings! - @embeddings.each do |model, model_embedding| - DiscourseAi::Database::Connection.db.exec( - <<~SQL, - INSERT INTO topic_embeddings_#{model.underscore} (topic_id, embedding) - VALUES (:topic_id, '[:embedding]') - ON CONFLICT (topic_id) - DO UPDATE SET embedding = '[:embedding]' - SQL - topic_id: @topic.id, - embedding: model_embedding, - ) + def symmetric_semantic_search(model, topic) + candidate_ids = + DiscourseAi::Database::Connection.db.query(<<~SQL, topic_id: topic.id).map(&:topic_id) + SELECT + topic_id + FROM + topic_embeddings_#{model.name.underscore} + ORDER BY + embedding #{model.pg_function} ( + SELECT + embedding + FROM + topic_embeddings_#{model.name.underscore} + WHERE + topic_id = :topic_id + LIMIT 1 + ) + LIMIT 11 + SQL + + # Happens when the topic doesn't have any embeddings + # I'd rather not use Exceptions to control the flow, so this should be refactored soon + if candidate_ids.empty? 
|| !candidate_ids.include?(topic.id)
+          raise StandardError, "No embeddings found for topic #{topic.id}"
+        end
+
+        candidate_ids
       end
 
-      def discourse_embeddings(model)
-        DiscourseAi::Inference::DiscourseClassifier.perform!(
-          "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
-          model.to_s,
-          @topic.first_post.raw,
-          SiteSetting.ai_embeddings_discourse_service_api_key,
-        )
+      def asymmetric_semantic_search(model, query, limit, offset)
+        query_embedding = model.generate_embedding(query)
+
+        candidate_ids =
+          DiscourseAi::Database::Connection
+            .db
+            .query(<<~SQL, query_embedding: query_embedding, limit: limit, offset: offset)
+              SELECT
+                topic_id
+              FROM
+                topic_embeddings_#{model.name.underscore}
+              ORDER BY
+                embedding #{model.pg_function} '[:query_embedding]'
+              LIMIT :limit
+              OFFSET :offset
+            SQL
+            .map(&:topic_id)
+
+        raise StandardError, "No embeddings found for query" if candidate_ids.empty?
+
+        candidate_ids
       end
 
-      def openai_embeddings(model)
-        response = DiscourseAi::Inference::OpenAIEmbeddings.perform!(@topic.first_post.raw)
-        response[:data].first[:embedding]
+      private
+
+      def persist_embedding(topic, model, embedding)
+        DiscourseAi::Database::Connection.db.exec(<<~SQL, topic_id: topic.id, embedding: embedding)
+          INSERT INTO topic_embeddings_#{model.name.underscore} (topic_id, embedding)
+          VALUES (:topic_id, '[:embedding]')
+          ON CONFLICT (topic_id)
+          DO UPDATE SET embedding = '[:embedding]'
+        SQL
       end
     end
   end
diff --git a/lib/tasks/modules/embeddings/database.rake b/lib/tasks/modules/embeddings/database.rake
index 61b3e167..57707d87 100644
--- a/lib/tasks/modules/embeddings/database.rake
+++ b/lib/tasks/modules/embeddings/database.rake
@@ -3,10 +3,10 @@
 desc "Creates tables to store embeddings"
 task "ai:embeddings:create_table" => [:environment] do
   DiscourseAi::Database::Connection.db.exec(<<~SQL)
-    CREATE EXTENSION IF NOT EXISTS pg_vector;
+    CREATE EXTENSION IF NOT EXISTS vector;
   SQL
 
-  DiscourseAi::Embeddings::Models.enabled_models.each do |model|
+  DiscourseAi::Embeddings::Model.enabled_models.each do |model|
     DiscourseAi::Database::Connection.db.exec(<<~SQL)
       CREATE TABLE IF NOT EXISTS topic_embeddings_#{model.name.underscore} (
         topic_id bigint PRIMARY KEY,
@@ -19,12 +19,13 @@ end
 desc "Backfill embeddings for all topics"
 task "ai:embeddings:backfill" => [:environment] do
   public_categories = Category.where(read_restricted: false).pluck(:id)
+  topic_embeddings = DiscourseAi::Embeddings::Topic.new
   Topic
-    .where("category_id IN ?", public_categories)
+    .where("category_id IN (?)", public_categories)
     .where(deleted_at: nil)
    .find_each do |t|
       print "."
-      DiscourseAI::Embeddings::Topic.new(t).perform!
+ topic_embeddings.generate_and_store_embeddings_for(t) end end @@ -36,14 +37,14 @@ task "ai:embeddings:index", [:work_mem] => [:environment] do |_, args| lists = 4 * Math.sqrt(Topic.count).to_i DiscourseAi::Database::Connection.db.exec("SET work_mem TO '#{args[:work_mem] || "1GB"}';") - DiscourseAi::Embeddings::Models.enabled_models.each do |model| + DiscourseAi::Embeddings::Model.enabled_models.each do |model| DiscourseAi::Database::Connection.db.exec(<<~SQL) CREATE INDEX IF NOT EXISTS topic_embeddings_#{model.name.underscore}_search ON topic_embeddings_#{model.name.underscore} USING - ivfflat (embedding #{DiscourseAi::Embeddings::Models::SEARCH_FUNCTION_TO_PG_INDEX[model.functions.first]}) + ivfflat (embedding #{model.pg_index}) WITH (lists = #{lists}); SQL diff --git a/spec/lib/modules/embeddings/model_spec.rb b/spec/lib/modules/embeddings/model_spec.rb new file mode 100644 index 00000000..81078373 --- /dev/null +++ b/spec/lib/modules/embeddings/model_spec.rb @@ -0,0 +1,36 @@ +# frozen_string_literal: true + +require_relative "../../../support/embeddings_generation_stubs" + +RSpec.describe DiscourseAi::Embeddings::Model do + describe "#generate_embedding" do + let(:input) { "test" } + let(:expected_embedding) { [0.0038493, 0.482001] } + + context "when the model uses the discourse service to create embeddings" do + before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" } + + let(:discourse_model) { "all-mpnet-base-v2" } + + it "returns an embedding for a given string" do + EmbeddingsGenerationStubs.discourse_service(discourse_model, input, expected_embedding) + + embedding = described_class.instantiate(discourse_model).generate_embedding(input) + + expect(embedding).to contain_exactly(*expected_embedding) + end + end + + context "when the model uses OpenAI to create embeddings" do + let(:openai_model) { "text-embedding-ada-002" } + + it "returns an embedding for a given string" do + EmbeddingsGenerationStubs.openai_service(openai_model, input, expected_embedding) + + embedding = described_class.instantiate(openai_model).generate_embedding(input) + + expect(embedding).to contain_exactly(*expected_embedding) + end + end + end +end diff --git a/spec/lib/modules/embeddings/semantic_related_spec.rb b/spec/lib/modules/embeddings/semantic_related_spec.rb index b72a477e..e9b8b566 100644 --- a/spec/lib/modules/embeddings/semantic_related_spec.rb +++ b/spec/lib/modules/embeddings/semantic_related_spec.rb @@ -17,9 +17,10 @@ describe DiscourseAi::Embeddings::SemanticRelated do describe "#candidates_for" do before do Discourse.cache.clear - described_class.stubs(:search_suggestions).returns( - Topic.unscoped.order(id: :desc).limit(10).pluck(:id), - ) + DiscourseAi::Embeddings::Topic + .any_instance + .expects(:symmetric_semantic_search) + .returns(Topic.unscoped.order(id: :desc).limit(10).pluck(:id)) end after { Discourse.cache.clear } diff --git a/spec/lib/modules/embeddings/semantic_search_spec.rb b/spec/lib/modules/embeddings/semantic_search_spec.rb new file mode 100644 index 00000000..8242835e --- /dev/null +++ b/spec/lib/modules/embeddings/semantic_search_spec.rb @@ -0,0 +1,86 @@ +# frozen_string_literal: true + +RSpec.describe DiscourseAi::Embeddings::SemanticSearch do + fab!(:post) { Fabricate(:post) } + fab!(:user) { Fabricate(:user) } + let(:model_name) { "msmarco-distilbert-base-v4" } + let(:query) { "test_query" } + + let(:model) { DiscourseAi::Embeddings::Model.instantiate(model_name) } + let(:subject) { described_class.new(Guardian.new(user), model) } + + 
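+  # NOTE: rather than hitting a real embeddings store, these examples stub
+  # Topic#asymmetric_semantic_search (see stub_candidate_ids below) and
+  # exercise only the visibility scoping that SemanticSearch applies to the
+  # returned candidate topics.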
describe "#search_for_topics" do + def stub_candidate_ids(candidate_ids) + DiscourseAi::Embeddings::Topic + .any_instance + .expects(:asymmetric_semantic_search) + .returns(candidate_ids) + end + + it "returns the first post of a topic included in the asymmetric search results" do + stub_candidate_ids([post.topic_id]) + + posts = subject.search_for_topics(query) + + expect(posts).to contain_exactly(post) + end + + describe "applies different scopes to the candidates" do + context "when the topic is not visible" do + it "returns an empty list" do + post.topic.update!(visible: false) + stub_candidate_ids([post.topic_id]) + + posts = subject.search_for_topics(query) + + expect(posts).to be_empty + end + end + + context "when the post is not public" do + it "returns an empty list" do + pm_post = Fabricate(:private_message_post) + stub_candidate_ids([pm_post.topic_id]) + + posts = subject.search_for_topics(query) + + expect(posts).to be_empty + end + end + + context "when the post type is not visible" do + it "returns an empty list" do + post.update!(post_type: Post.types[:whisper]) + stub_candidate_ids([post.topic_id]) + + posts = subject.search_for_topics(query) + + expect(posts).to be_empty + end + end + + context "when the post is not the first post in the topic" do + it "returns an empty list" do + reply = Fabricate(:reply) + reply.topic.first_post.trash! + stub_candidate_ids([reply.topic_id]) + + posts = subject.search_for_topics(query) + + expect(posts).to be_empty + end + end + + context "when the post is not a candidate" do + it "doesn't include it in the results" do + post_2 = Fabricate(:post) + stub_candidate_ids([post.topic_id]) + + posts = subject.search_for_topics(query) + + expect(posts).not_to include(post_2) + end + end + end + end +end diff --git a/spec/requests/topic_spec.rb b/spec/requests/topic_spec.rb index 3000ac30..fd7b0e99 100644 --- a/spec/requests/topic_spec.rb +++ b/spec/requests/topic_spec.rb @@ -19,9 +19,10 @@ describe ::TopicsController do context "when a user is logged on" do it "includes related topics in payload when configured" do - DiscourseAi::Embeddings::SemanticRelated.stubs(:search_suggestions).returns( - [topic1.id, topic2.id, topic3.id], - ) + DiscourseAi::Embeddings::Topic + .any_instance + .expects(:symmetric_semantic_search) + .returns([topic1.id, topic2.id, topic3.id]) get("#{topic.relative_url}.json") json = response.parsed_body diff --git a/spec/support/embeddings_generation_stubs.rb b/spec/support/embeddings_generation_stubs.rb new file mode 100644 index 00000000..4724b57a --- /dev/null +++ b/spec/support/embeddings_generation_stubs.rb @@ -0,0 +1,22 @@ +# frozen_string_literal: true + +class EmbeddingsGenerationStubs + class << self + def discourse_service(model, string, embedding) + WebMock + .stub_request( + :post, + "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", + ) + .with(body: JSON.dump({ model: model, content: string })) + .to_return(status: 200, body: JSON.dump(embedding)) + end + + def openai_service(model, string, embedding) + WebMock + .stub_request(:post, "https://api.openai.com/v1/embeddings") + .with(body: JSON.dump({ model: model, input: string })) + .to_return(status: 200, body: JSON.dump({ data: [{ embedding: embedding }] })) + end + end +end