FEATURE: Semantic assymetric full-page search (#34)
Depends on discourse/discourse#20915 Hooks to the full-page-search component using an experimental API and performs an assymetric similarity search using our embeddings database.
This commit is contained in:
parent
99886fb64d
commit
4e05763a99
|
@ -0,0 +1,35 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
class EmbeddingsController < ::ApplicationController
|
||||
requires_plugin ::DiscourseAi::PLUGIN_NAME
|
||||
|
||||
SEMANTIC_SEARCH_TYPE = "semantic_search"
|
||||
|
||||
def search
|
||||
query = params[:q]
|
||||
page = (params[:page] || 1).to_i
|
||||
|
||||
grouped_results =
|
||||
Search::GroupedSearchResults.new(
|
||||
type_filter: SEMANTIC_SEARCH_TYPE,
|
||||
term: query,
|
||||
search_context: guardian,
|
||||
)
|
||||
|
||||
model =
|
||||
DiscourseAi::Embeddings::Model.instantiate(
|
||||
SiteSetting.ai_embeddings_semantic_search_model,
|
||||
)
|
||||
|
||||
DiscourseAi::Embeddings::SemanticSearch
|
||||
.new(guardian, model)
|
||||
.search_for_topics(query, page)
|
||||
.each { |topic_post| grouped_results.add(topic_post) }
|
||||
|
||||
render_serialized(grouped_results, GroupedSearchResultSerializer, result: grouped_results)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,64 @@
|
|||
import { withPluginApi } from "discourse/lib/plugin-api";
|
||||
import { translateResults, updateRecentSearches } from "discourse/lib/search";
|
||||
import { setTransient } from "discourse/lib/page-tracker";
|
||||
import { ajax } from "discourse/lib/ajax";
|
||||
|
||||
const SEMANTIC_SEARCH = "semantic_search";
|
||||
|
||||
function initializeSemanticSearch(api) {
|
||||
api.addFullPageSearchType(
|
||||
"discourse_ai.embeddings.semantic_search",
|
||||
SEMANTIC_SEARCH,
|
||||
(searchController, args, searchKey) => {
|
||||
if (searchController.currentUser) {
|
||||
updateRecentSearches(searchController.currentUser, args.searchTerm);
|
||||
}
|
||||
|
||||
ajax("/discourse-ai/embeddings/semantic-search", { data: args })
|
||||
.then(async (results) => {
|
||||
const model = (await translateResults(results)) || {};
|
||||
|
||||
if (results.grouped_search_result) {
|
||||
searchController.set("q", results.grouped_search_result.term);
|
||||
}
|
||||
|
||||
if (args.page > 1) {
|
||||
if (model) {
|
||||
searchController.model.posts.pushObjects(model.posts);
|
||||
searchController.model.topics.pushObjects(model.topics);
|
||||
searchController.model.set(
|
||||
"grouped_search_result",
|
||||
results.grouped_search_result
|
||||
);
|
||||
}
|
||||
} else {
|
||||
setTransient("lastSearch", { searchKey, model }, 5);
|
||||
model.grouped_search_result = results.grouped_search_result;
|
||||
searchController.set("model", model);
|
||||
}
|
||||
searchController.set("error", null);
|
||||
})
|
||||
.catch((e) => {
|
||||
searchController.set("error", e.jqXHR.responseJSON?.message);
|
||||
})
|
||||
.finally(() => {
|
||||
searchController.setProperties({
|
||||
searching: false,
|
||||
loading: false,
|
||||
});
|
||||
});
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
export default {
|
||||
name: "discourse_ai-full-page-semantic-search",
|
||||
|
||||
initialize(container) {
|
||||
const settings = container.lookup("site-settings:main");
|
||||
|
||||
if (settings.ai_embeddings_enabled) {
|
||||
withPluginApi("1.6.0", initializeSemanticSearch);
|
||||
}
|
||||
},
|
||||
};
|
|
@ -10,6 +10,9 @@ en:
|
|||
reviewables:
|
||||
model_used: "Model used:"
|
||||
accuracy: "Accuracy:"
|
||||
|
||||
embeddings:
|
||||
semantic_search: "Topics (Semantic)"
|
||||
review:
|
||||
types:
|
||||
reviewable_ai_post:
|
||||
|
|
|
@ -47,6 +47,7 @@ en:
|
|||
ai_embeddings_semantic_related_topics_enabled: "Use Semantic Search for related topics."
|
||||
ai_embeddings_semantic_related_topics: "Maximum number of topics to show in related topic section."
|
||||
ai_embeddings_pg_connection_string: "PostgreSQL connection string for the embeddings module. Needs pgvector extension enabled and a series of tables created. See docs for more info."
|
||||
ai_embeddings_semantic_search_model: "Model to use for semantic search."
|
||||
|
||||
reviewables:
|
||||
reasons:
|
||||
|
|
|
@ -6,6 +6,11 @@ DiscourseAi::Engine.routes.draw do
|
|||
get "prompts" => "assistant#prompts"
|
||||
post "suggest" => "assistant#suggest"
|
||||
end
|
||||
|
||||
# Embedding routes
|
||||
scope module: :embeddings, path: "/embeddings", defaults: { format: :json } do
|
||||
get "semantic-search" => "embeddings#search"
|
||||
end
|
||||
end
|
||||
|
||||
Discourse::Application.routes.append { mount ::DiscourseAi::Engine, at: "discourse-ai" }
|
||||
|
|
|
@ -108,7 +108,9 @@ plugins:
|
|||
- gpt-3.5-turbo
|
||||
- gpt-4
|
||||
|
||||
ai_embeddings_enabled: false
|
||||
ai_embeddings_enabled:
|
||||
default: false
|
||||
client: true
|
||||
ai_embeddings_discourse_service_api_endpoint: ""
|
||||
ai_embeddings_discourse_service_api_key: ""
|
||||
ai_embeddings_models:
|
||||
|
@ -133,6 +135,13 @@ plugins:
|
|||
- all-distilroberta-v1
|
||||
- multi-qa-mpnet-base-dot-v1
|
||||
- paraphrase-multilingual-mpnet-base-v2
|
||||
ai_embeddings_semantic_search_model:
|
||||
type: enum
|
||||
default: msmarco-distilbert-base-v4
|
||||
choices:
|
||||
- msmarco-distilbert-base-v4
|
||||
- msmarco-distilbert-base-tas-b
|
||||
- text-embedding-ada-002
|
||||
ai_embeddings_generate_for_pms: false
|
||||
ai_embeddings_semantic_related_topics_enabled: false
|
||||
ai_embeddings_semantic_related_topics: 5
|
||||
|
|
|
@ -4,10 +4,11 @@ module DiscourseAi
|
|||
module Embeddings
|
||||
class EntryPoint
|
||||
def load_files
|
||||
require_relative "models"
|
||||
require_relative "model"
|
||||
require_relative "topic"
|
||||
require_relative "jobs/regular/generate_embeddings"
|
||||
require_relative "semantic_related"
|
||||
require_relative "semantic_search"
|
||||
end
|
||||
|
||||
def inject_into(plugin)
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
class Model
|
||||
AVAILABLE_MODELS_TEMPLATES = {
|
||||
"all-mpnet-base-v2" => [768, 384, %i[dot cosine euclidean], %i[symmetric], "discourse"],
|
||||
"all-distilroberta-v1" => [768, 512, %i[dot cosine euclidean], %i[symmetric], "discourse"],
|
||||
"multi-qa-mpnet-base-dot-v1" => [768, 512, %i[dot], %i[symmetric], "discourse"],
|
||||
"paraphrase-multilingual-mpnet-base-v2" => [
|
||||
768,
|
||||
128,
|
||||
%i[cosine],
|
||||
%i[symmetric],
|
||||
"discourse",
|
||||
],
|
||||
"msmarco-distilbert-base-v4" => [768, 512, %i[cosine], %i[asymmetric], "discourse"],
|
||||
"msmarco-distilbert-base-tas-b" => [768, 512, %i[dot], %i[asymmetric], "discourse"],
|
||||
"text-embedding-ada-002" => [1536, 2048, %i[cosine], %i[symmetric asymmetric], "openai"],
|
||||
}
|
||||
|
||||
SEARCH_FUNCTION_TO_PG_INDEX = {
|
||||
dot: "vector_ip_ops",
|
||||
cosine: "vector_cosine_ops",
|
||||
euclidean: "vector_l2_ops",
|
||||
}
|
||||
|
||||
SEARCH_FUNCTION_TO_PG_FUNCTION = { dot: "<#>", cosine: "<=>", euclidean: "<->" }
|
||||
|
||||
class << self
|
||||
def instantiate(model_name)
|
||||
new(model_name, *AVAILABLE_MODELS_TEMPLATES[model_name])
|
||||
end
|
||||
|
||||
def enabled_models
|
||||
SiteSetting
|
||||
.ai_embeddings_models
|
||||
.split("|")
|
||||
.map { |model_name| instantiate(model_name.strip) }
|
||||
end
|
||||
end
|
||||
|
||||
def initialize(name, dimensions, max_sequence_lenght, functions, type, provider)
|
||||
@name = name
|
||||
@dimensions = dimensions
|
||||
@max_sequence_lenght = max_sequence_lenght
|
||||
@functions = functions
|
||||
@type = type
|
||||
@provider = provider
|
||||
end
|
||||
|
||||
def generate_embedding(input)
|
||||
send("#{provider}_embeddings", input)
|
||||
end
|
||||
|
||||
def pg_function
|
||||
SEARCH_FUNCTION_TO_PG_FUNCTION[functions.first]
|
||||
end
|
||||
|
||||
def pg_index
|
||||
SEARCH_FUNCTION_TO_PG_INDEX[functions.first]
|
||||
end
|
||||
|
||||
attr_reader :name, :dimensions, :max_sequence_lenght, :functions, :type, :provider
|
||||
|
||||
private
|
||||
|
||||
def discourse_embeddings(input)
|
||||
DiscourseAi::Inference::DiscourseClassifier.perform!(
|
||||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
||||
name.to_s,
|
||||
input,
|
||||
SiteSetting.ai_embeddings_discourse_service_api_key,
|
||||
)
|
||||
end
|
||||
|
||||
def openai_embeddings(input)
|
||||
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(input)
|
||||
response[:data].first[:embedding]
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,62 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
class Models
|
||||
MODEL = Data.define(:name, :dimensions, :max_sequence_lenght, :functions, :type, :provider)
|
||||
|
||||
SEARCH_FUNCTION_TO_PG_INDEX = {
|
||||
dot: "vector_ip_ops",
|
||||
cosine: "vector_cosine_ops",
|
||||
euclidean: "vector_l2_ops",
|
||||
}
|
||||
|
||||
SEARCH_FUNCTION_TO_PG_FUNCTION = { dot: "<#>", cosine: "<=>", euclidean: "<->" }
|
||||
|
||||
def self.enabled_models
|
||||
setting = SiteSetting.ai_embeddings_models.split("|").map(&:strip)
|
||||
list.filter { |model| setting.include?(model.name) }
|
||||
end
|
||||
|
||||
def self.list
|
||||
@@list ||= [
|
||||
MODEL.new(
|
||||
"all-mpnet-base-v2",
|
||||
768,
|
||||
384,
|
||||
%i[dot cosine euclidean],
|
||||
[:symmetric],
|
||||
"discourse",
|
||||
),
|
||||
MODEL.new(
|
||||
"all-distilroberta-v1",
|
||||
768,
|
||||
512,
|
||||
%i[dot cosine euclidean],
|
||||
[:symmetric],
|
||||
"discourse",
|
||||
),
|
||||
MODEL.new("multi-qa-mpnet-base-dot-v1", 768, 512, [:dot], [:symmetric], "discourse"),
|
||||
MODEL.new(
|
||||
"paraphrase-multilingual-mpnet-base-v2",
|
||||
768,
|
||||
128,
|
||||
[:cosine],
|
||||
[:symmetric],
|
||||
"discourse",
|
||||
),
|
||||
MODEL.new("msmarco-distilbert-base-v4", 768, 512, [:cosine], [:asymmetric], "discourse"),
|
||||
MODEL.new("msmarco-distilbert-base-tas-b", 768, 512, [:dot], [:asymmetric], "discourse"),
|
||||
MODEL.new(
|
||||
"text-embedding-ada-002",
|
||||
1536,
|
||||
2048,
|
||||
[:cosine],
|
||||
%i[:symmetric :asymmetric],
|
||||
"openai",
|
||||
),
|
||||
]
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -16,12 +16,17 @@ module DiscourseAi
|
|||
1.day
|
||||
end
|
||||
|
||||
model =
|
||||
DiscourseAi::Embeddings::Model.instantiate(
|
||||
SiteSetting.ai_embeddings_semantic_related_model,
|
||||
)
|
||||
|
||||
begin
|
||||
candidate_ids =
|
||||
Discourse
|
||||
.cache
|
||||
.fetch("semantic-suggested-topic-#{topic.id}", expires_in: cache_for) do
|
||||
search_suggestions(topic)
|
||||
DiscourseAi::Embeddings::Topic.new.symmetric_semantic_search(model, topic)
|
||||
end
|
||||
rescue StandardError => e
|
||||
Rails.logger.error("SemanticRelated: #{e}")
|
||||
|
@ -39,40 +44,6 @@ module DiscourseAi
|
|||
.order("array_position(ARRAY#{candidate_ids}, id)")
|
||||
.limit(SiteSetting.ai_embeddings_semantic_related_topics)
|
||||
end
|
||||
|
||||
def self.search_suggestions(topic)
|
||||
model_name = SiteSetting.ai_embeddings_semantic_related_model
|
||||
model = DiscourseAi::Embeddings::Models.list.find { |m| m.name == model_name }
|
||||
function =
|
||||
DiscourseAi::Embeddings::Models::SEARCH_FUNCTION_TO_PG_FUNCTION[model.functions.first]
|
||||
|
||||
candidate_ids =
|
||||
DiscourseAi::Database::Connection.db.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
|
||||
SELECT
|
||||
topic_id
|
||||
FROM
|
||||
topic_embeddings_#{model_name.underscore}
|
||||
ORDER BY
|
||||
embedding #{function} (
|
||||
SELECT
|
||||
embedding
|
||||
FROM
|
||||
topic_embeddings_#{model_name.underscore}
|
||||
WHERE
|
||||
topic_id = :topic_id
|
||||
LIMIT 1
|
||||
)
|
||||
LIMIT 11
|
||||
SQL
|
||||
|
||||
# Happens when the topic doesn't have any embeddings
|
||||
# I'd rather not use Exceptions to control the flow, so this should be refactored soon
|
||||
if candidate_ids.empty? || !candidate_ids.include?(topic.id)
|
||||
raise StandardError, "No embeddings found for topic #{topic.id}"
|
||||
end
|
||||
|
||||
candidate_ids
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
class SemanticSearch
|
||||
def initialize(guardian, model)
|
||||
@guardian = guardian
|
||||
@model = model
|
||||
end
|
||||
|
||||
def search_for_topics(query, page = 1)
|
||||
limit = Search.per_filter + 1
|
||||
offset = (page - 1) * Search.per_filter
|
||||
|
||||
candidate_ids =
|
||||
DiscourseAi::Embeddings::Topic.new.asymmetric_semantic_search(model, query, limit, offset)
|
||||
|
||||
::Post
|
||||
.where(post_type: ::Topic.visible_post_types(guardian.user))
|
||||
.public_posts
|
||||
.where("topics.visible")
|
||||
.where(topic_id: candidate_ids, post_number: 1)
|
||||
.order("array_position(ARRAY#{candidate_ids}, topic_id)")
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
attr_reader :model, :guardian
|
||||
end
|
||||
end
|
||||
end
|
|
@ -3,54 +3,80 @@
|
|||
module DiscourseAi
|
||||
module Embeddings
|
||||
class Topic
|
||||
def initialize(topic)
|
||||
@topic = topic
|
||||
@embeddings = {}
|
||||
end
|
||||
|
||||
def perform!
|
||||
def generate_and_store_embeddings_for(topic)
|
||||
return unless SiteSetting.ai_embeddings_enabled
|
||||
return if DiscourseAi::Embeddings::Models.enabled_models.empty?
|
||||
return if topic.blank? || topic.first_post.blank?
|
||||
|
||||
calculate_embeddings!
|
||||
persist_embeddings! unless @embeddings.empty?
|
||||
end
|
||||
enabled_models = DiscourseAi::Embeddings::Model.enabled_models
|
||||
return if enabled_models.empty?
|
||||
|
||||
def calculate_embeddings!
|
||||
return if @topic.blank? || @topic.first_post.blank?
|
||||
|
||||
DiscourseAi::Embeddings::Models.enabled_models.each do |model|
|
||||
@embeddings[model.name] = send("#{model.provider}_embeddings", model.name)
|
||||
enabled_models.each do |model|
|
||||
embedding = model.generate_embedding(topic.first_post.raw)
|
||||
persist_embedding(topic, model, embedding) if embedding
|
||||
end
|
||||
end
|
||||
|
||||
def persist_embeddings!
|
||||
@embeddings.each do |model, model_embedding|
|
||||
DiscourseAi::Database::Connection.db.exec(
|
||||
<<~SQL,
|
||||
INSERT INTO topic_embeddings_#{model.underscore} (topic_id, embedding)
|
||||
def symmetric_semantic_search(model, topic)
|
||||
candidate_ids =
|
||||
DiscourseAi::Database::Connection.db.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
|
||||
SELECT
|
||||
topic_id
|
||||
FROM
|
||||
topic_embeddings_#{model.name.underscore}
|
||||
ORDER BY
|
||||
embedding #{model.pg_function} (
|
||||
SELECT
|
||||
embedding
|
||||
FROM
|
||||
topic_embeddings_#{model.name.underscore}
|
||||
WHERE
|
||||
topic_id = :topic_id
|
||||
LIMIT 1
|
||||
)
|
||||
LIMIT 11
|
||||
SQL
|
||||
|
||||
# Happens when the topic doesn't have any embeddings
|
||||
# I'd rather not use Exceptions to control the flow, so this should be refactored soon
|
||||
if candidate_ids.empty? || !candidate_ids.include?(topic.id)
|
||||
raise StandardError, "No embeddings found for topic #{topic.id}"
|
||||
end
|
||||
|
||||
candidate_ids
|
||||
end
|
||||
|
||||
def asymmetric_semantic_search(model, query, limit, offset)
|
||||
query_embedding = model.generate_embedding(query)
|
||||
|
||||
candidate_ids =
|
||||
DiscourseAi::Database::Connection
|
||||
.db
|
||||
.query(<<~SQL, query_embedding: embedding, limit: limit, offset: offset)
|
||||
SELECT
|
||||
topic_id
|
||||
FROM
|
||||
topic_embeddings_#{model.name.underscore}
|
||||
ORDER BY
|
||||
embedding #{model.pg_function} '[:query_embedding]'
|
||||
LIMIT :limit
|
||||
OFFSET :offset
|
||||
SQL
|
||||
.map(&:topic_id)
|
||||
|
||||
raise StandardError, "No embeddings found for topic #{topic.id}" if candidate_ids.empty?
|
||||
|
||||
candidate_ids
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def persist_embedding(topic, model, embedding)
|
||||
DiscourseAi::Database::Connection.db.exec(<<~SQL, topic_id: topic.id, embedding: embedding)
|
||||
INSERT INTO topic_embeddings_#{model.name.underscore} (topic_id, embedding)
|
||||
VALUES (:topic_id, '[:embedding]')
|
||||
ON CONFLICT (topic_id)
|
||||
DO UPDATE SET embedding = '[:embedding]'
|
||||
SQL
|
||||
topic_id: @topic.id,
|
||||
embedding: model_embedding,
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
def discourse_embeddings(model)
|
||||
DiscourseAi::Inference::DiscourseClassifier.perform!(
|
||||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
||||
model.to_s,
|
||||
@topic.first_post.raw,
|
||||
SiteSetting.ai_embeddings_discourse_service_api_key,
|
||||
)
|
||||
end
|
||||
|
||||
def openai_embeddings(model)
|
||||
response = DiscourseAi::Inference::OpenAIEmbeddings.perform!(@topic.first_post.raw)
|
||||
response[:data].first[:embedding]
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -3,10 +3,10 @@
|
|||
desc "Creates tables to store embeddings"
|
||||
task "ai:embeddings:create_table" => [:environment] do
|
||||
DiscourseAi::Database::Connection.db.exec(<<~SQL)
|
||||
CREATE EXTENSION IF NOT EXISTS pg_vector;
|
||||
CREATE EXTENSION IF NOT EXISTS vector;
|
||||
SQL
|
||||
|
||||
DiscourseAi::Embeddings::Models.enabled_models.each do |model|
|
||||
DiscourseAi::Embeddings::Model.enabled_models.each do |model|
|
||||
DiscourseAi::Database::Connection.db.exec(<<~SQL)
|
||||
CREATE TABLE IF NOT EXISTS topic_embeddings_#{model.name.underscore} (
|
||||
topic_id bigint PRIMARY KEY,
|
||||
|
@ -19,12 +19,13 @@ end
|
|||
desc "Backfill embeddings for all topics"
|
||||
task "ai:embeddings:backfill" => [:environment] do
|
||||
public_categories = Category.where(read_restricted: false).pluck(:id)
|
||||
topic_embeddings = DiscourseAi::Embeddings::Topic.new
|
||||
Topic
|
||||
.where("category_id IN ?", public_categories)
|
||||
.where("category_id IN (?)", public_categories)
|
||||
.where(deleted_at: nil)
|
||||
.find_each do |t|
|
||||
print "."
|
||||
DiscourseAI::Embeddings::Topic.new(t).perform!
|
||||
topic_embeddings.generate_and_store_embeddings_for(t)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -36,14 +37,14 @@ task "ai:embeddings:index", [:work_mem] => [:environment] do |_, args|
|
|||
lists = 4 * Math.sqrt(Topic.count).to_i
|
||||
|
||||
DiscourseAi::Database::Connection.db.exec("SET work_mem TO '#{args[:work_mem] || "1GB"}';")
|
||||
DiscourseAi::Embeddings::Models.enabled_models.each do |model|
|
||||
DiscourseAi::Embeddings::Model.enabled_models.each do |model|
|
||||
DiscourseAi::Database::Connection.db.exec(<<~SQL)
|
||||
CREATE INDEX IF NOT EXISTS
|
||||
topic_embeddings_#{model.name.underscore}_search
|
||||
ON
|
||||
topic_embeddings_#{model.name.underscore}
|
||||
USING
|
||||
ivfflat (embedding #{DiscourseAi::Embeddings::Models::SEARCH_FUNCTION_TO_PG_INDEX[model.functions.first]})
|
||||
ivfflat (embedding #{model.pg_index})
|
||||
WITH
|
||||
(lists = #{lists});
|
||||
SQL
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require_relative "../../../support/embeddings_generation_stubs"
|
||||
|
||||
RSpec.describe DiscourseAi::Embeddings::Model do
|
||||
describe "#generate_embedding" do
|
||||
let(:input) { "test" }
|
||||
let(:expected_embedding) { [0.0038493, 0.482001] }
|
||||
|
||||
context "when the model uses the discourse service to create embeddings" do
|
||||
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
|
||||
|
||||
let(:discourse_model) { "all-mpnet-base-v2" }
|
||||
|
||||
it "returns an embedding for a given string" do
|
||||
EmbeddingsGenerationStubs.discourse_service(discourse_model, input, expected_embedding)
|
||||
|
||||
embedding = described_class.instantiate(discourse_model).generate_embedding(input)
|
||||
|
||||
expect(embedding).to contain_exactly(*expected_embedding)
|
||||
end
|
||||
end
|
||||
|
||||
context "when the model uses OpenAI to create embeddings" do
|
||||
let(:openai_model) { "text-embedding-ada-002" }
|
||||
|
||||
it "returns an embedding for a given string" do
|
||||
EmbeddingsGenerationStubs.openai_service(openai_model, input, expected_embedding)
|
||||
|
||||
embedding = described_class.instantiate(openai_model).generate_embedding(input)
|
||||
|
||||
expect(embedding).to contain_exactly(*expected_embedding)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -17,9 +17,10 @@ describe DiscourseAi::Embeddings::SemanticRelated do
|
|||
describe "#candidates_for" do
|
||||
before do
|
||||
Discourse.cache.clear
|
||||
described_class.stubs(:search_suggestions).returns(
|
||||
Topic.unscoped.order(id: :desc).limit(10).pluck(:id),
|
||||
)
|
||||
DiscourseAi::Embeddings::Topic
|
||||
.any_instance
|
||||
.expects(:symmetric_semantic_search)
|
||||
.returns(Topic.unscoped.order(id: :desc).limit(10).pluck(:id))
|
||||
end
|
||||
|
||||
after { Discourse.cache.clear }
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
|
||||
fab!(:post) { Fabricate(:post) }
|
||||
fab!(:user) { Fabricate(:user) }
|
||||
let(:model_name) { "msmarco-distilbert-base-v4" }
|
||||
let(:query) { "test_query" }
|
||||
|
||||
let(:model) { DiscourseAi::Embeddings::Model.instantiate(model_name) }
|
||||
let(:subject) { described_class.new(Guardian.new(user), model) }
|
||||
|
||||
describe "#search_for_topics" do
|
||||
def stub_candidate_ids(candidate_ids)
|
||||
DiscourseAi::Embeddings::Topic
|
||||
.any_instance
|
||||
.expects(:asymmetric_semantic_search)
|
||||
.returns(candidate_ids)
|
||||
end
|
||||
|
||||
it "returns the first post of a topic included in the asymmetric search results" do
|
||||
stub_candidate_ids([post.topic_id])
|
||||
|
||||
posts = subject.search_for_topics(query)
|
||||
|
||||
expect(posts).to contain_exactly(post)
|
||||
end
|
||||
|
||||
describe "applies different scopes to the candidates" do
|
||||
context "when the topic is not visible" do
|
||||
it "returns an empty list" do
|
||||
post.topic.update!(visible: false)
|
||||
stub_candidate_ids([post.topic_id])
|
||||
|
||||
posts = subject.search_for_topics(query)
|
||||
|
||||
expect(posts).to be_empty
|
||||
end
|
||||
end
|
||||
|
||||
context "when the post is not public" do
|
||||
it "returns an empty list" do
|
||||
pm_post = Fabricate(:private_message_post)
|
||||
stub_candidate_ids([pm_post.topic_id])
|
||||
|
||||
posts = subject.search_for_topics(query)
|
||||
|
||||
expect(posts).to be_empty
|
||||
end
|
||||
end
|
||||
|
||||
context "when the post type is not visible" do
|
||||
it "returns an empty list" do
|
||||
post.update!(post_type: Post.types[:whisper])
|
||||
stub_candidate_ids([post.topic_id])
|
||||
|
||||
posts = subject.search_for_topics(query)
|
||||
|
||||
expect(posts).to be_empty
|
||||
end
|
||||
end
|
||||
|
||||
context "when the post is not the first post in the topic" do
|
||||
it "returns an empty list" do
|
||||
reply = Fabricate(:reply)
|
||||
reply.topic.first_post.trash!
|
||||
stub_candidate_ids([reply.topic_id])
|
||||
|
||||
posts = subject.search_for_topics(query)
|
||||
|
||||
expect(posts).to be_empty
|
||||
end
|
||||
end
|
||||
|
||||
context "when the post is not a candidate" do
|
||||
it "doesn't include it in the results" do
|
||||
post_2 = Fabricate(:post)
|
||||
stub_candidate_ids([post.topic_id])
|
||||
|
||||
posts = subject.search_for_topics(query)
|
||||
|
||||
expect(posts).not_to include(post_2)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -19,9 +19,10 @@ describe ::TopicsController do
|
|||
|
||||
context "when a user is logged on" do
|
||||
it "includes related topics in payload when configured" do
|
||||
DiscourseAi::Embeddings::SemanticRelated.stubs(:search_suggestions).returns(
|
||||
[topic1.id, topic2.id, topic3.id],
|
||||
)
|
||||
DiscourseAi::Embeddings::Topic
|
||||
.any_instance
|
||||
.expects(:symmetric_semantic_search)
|
||||
.returns([topic1.id, topic2.id, topic3.id])
|
||||
|
||||
get("#{topic.relative_url}.json")
|
||||
json = response.parsed_body
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
class EmbeddingsGenerationStubs
|
||||
class << self
|
||||
def discourse_service(model, string, embedding)
|
||||
WebMock
|
||||
.stub_request(
|
||||
:post,
|
||||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
||||
)
|
||||
.with(body: JSON.dump({ model: model, content: string }))
|
||||
.to_return(status: 200, body: JSON.dump(embedding))
|
||||
end
|
||||
|
||||
def openai_service(model, string, embedding)
|
||||
WebMock
|
||||
.stub_request(:post, "https://api.openai.com/v1/embeddings")
|
||||
.with(body: JSON.dump({ model: model, input: string }))
|
||||
.to_return(status: 200, body: JSON.dump({ data: [{ embedding: embedding }] }))
|
||||
end
|
||||
end
|
||||
end
|
Loading…
Reference in New Issue