FEATURE: Embeddings to main db (#99)

* FEATURE: Embeddings to main db

This commit moves our embeddings store from an external, separately configured
PostgreSQL instance back into the main database, to simplify the setup.

There is a migration that will try to import the external embeddings into
the main DB if an external instance is configured and it contains rows.

It removes support for embedding models other than all-mpnet-base-v2 and OpenAI's
text-embedding-ada-002. However, it will now be easier to add new models; a rough
sketch of the model contract follows.
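
All a model now needs is a handful of class methods; table names, distance
operators, and truncation budgets are derived from them. A minimal sketch,
assuming a hypothetical model (the name, id, and values here are illustrative,
and a real model also needs its own table migration and an entry in the
ai_embeddings_model setting choices):

# a sketch only; MyHypotheticalModel is not a real model in this commit
module DiscourseAi
  module Embeddings
    module Models
      class MyHypotheticalModel < Base
        class << self
          def id
            3 # unique; becomes part of the ai_topic_embeddings_<id>_<strategy> table name
          end

          def version
            1 # bump to invalidate previously stored embeddings
          end

          def name
            "my-hypothetical-model" # matched against SiteSetting.ai_embeddings_model
          end

          def dimensions
            768 # size of the vector(...) column
          end

          def max_sequence_length
            512 # token budget the truncation strategy works against
          end

          def pg_function
            "<=>" # pgvector distance operator used when querying
          end

          def pg_index_type
            "vector_cosine_ops"
          end

          def generate_embeddings(text)
            # call whatever inference service backs this model
          end

          def tokenizer
            DiscourseAi::Tokenizer::BertTokenizer
          end
        end
      end
    end
  end
end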

It also now takes into account:
  - topic title
  - topic category
  - topic tags
  - replies (as much as the model allows)

We introduce an interface so we can eventually support multiple strategies
for handling long topics; the contract is sketched below.
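
Concretely, a strategy has to fill processed_target and digest. Truncation is
the only strategy shipped in this commit; here is a minimal sketch of the
contract inferred from the Truncation class in the diffs below (the class name
is hypothetical, and this is not a formal interface in the code):

# illustrative only; HypotheticalSummarization is not in this commit
module DiscourseAi
  module Embeddings
    module Strategies
      class HypotheticalSummarization
        attr_reader :processed_target, :digest

        def self.id
          2 # unique; becomes part of the per-strategy table name
        end

        def id
          self.class.id
        end

        def version
          1 # bump when the produced text format changes
        end

        def initialize(target, model)
          @target = target
          @model = model
          @processed_target = +""
        end

        def process!
          # build @processed_target into text that fits the model's
          # max_sequence_length, then fingerprint it so unchanged
          # topics can be skipped later
          @digest = OpenSSL::Digest::SHA1.hexdigest(@processed_target)
        end
      end
    end
  end
end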

This PR severely degrades semantic search performance, but only temporarily,
until we can adapt HyDE so that semantic search uses the same embeddings we
keep for semantic related topics, with good performance.

Here we also lay some groundwork for post-level embeddings, but those will be
added in a future PR.

Please note that this PR will also block Discourse from booting / updating if 
this plugin is installed and the pgvector extension isn't available on the 
PostgreSQL instance Discourse uses.
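
The plugin's boot-time guard checks pg_available_extensions (shown in the
plugin.rb hunk below). To verify the extension is actually enabled in the
database, a quick sketch, assuming a Rails console on the Discourse instance:

# returns ["vector"] once `enable_extension :vector` has run in this database
DB.query_single("SELECT extname FROM pg_extension WHERE extname = 'vector'")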
Rafael dos Santos Silva 2023-07-13 12:41:36 -03:00 committed by GitHub
parent 9d10a152b9
commit 5e3f4e1b78
28 changed files with 620 additions and 372 deletions


@@ -19,13 +19,8 @@ module DiscourseAi
           use_pg_headlines_for_excerpt: false,
         )

-      model =
-        DiscourseAi::Embeddings::Model.instantiate(
-          SiteSetting.ai_embeddings_semantic_search_model,
-        )
-
       DiscourseAi::Embeddings::SemanticSearch
-        .new(guardian, model)
+        .new(guardian)
         .search_for_topics(query, page)
         .each { |topic_post| grouped_results.add(topic_post) }


@@ -143,39 +143,14 @@ plugins:
   ai_embeddings_discourse_service_api_key:
     default: ""
     secret: true
-  ai_embeddings_models:
-    type: list
-    list_type: compact
-    default: ""
-    choices:
-      - all-mpnet-base-v2
-      - all-distilroberta-v1
-      - multi-qa-mpnet-base-dot-v1
-      - paraphrase-multilingual-mpnet-base-v2
-      - msmarco-distilbert-base-v4
-      - msmarco-distilbert-base-tas-b
-      - text-embedding-ada-002
-  ai_embeddings_semantic_related_model:
-    type: enum
-    default: all-mpnet-base-v2
-    choices:
-      - all-mpnet-base-v2
-      - text-embedding-ada-002
-      - all-distilroberta-v1
-      - multi-qa-mpnet-base-dot-v1
-      - paraphrase-multilingual-mpnet-base-v2
-  ai_embeddings_semantic_search_model:
-    type: enum
-    default: msmarco-distilbert-base-v4
-    choices:
-      - msmarco-distilbert-base-v4
-      - msmarco-distilbert-base-tas-b
-      - text-embedding-ada-002
-  ai_embeddings_semantic_related_instruction:
-    default: "Represent the Discourse topic for retrieving relevant topics:"
-    hidden: true
-    client: false
+  ai_embeddings_model:
+    type: enum
+    default: "all-mpnet-base-v2"
+    allow_any: false
+    choices:
+      - all-mpnet-base-v2
+      - text-embedding-ada-002
   ai_embeddings_generate_for_pms: false
   ai_embeddings_semantic_related_topics_enabled: false
   ai_embeddings_semantic_related_topics: 5


@@ -0,0 +1,7 @@
# frozen_string_literal: true

class EnablePgVectorExtension < ActiveRecord::Migration[7.0]
  def change
    enable_extension :vector
  end
end


@@ -0,0 +1,28 @@
# frozen_string_literal: true

class CreateAiTopicEmbeddingsTable < ActiveRecord::Migration[7.0]
  def change
    models = [
      DiscourseAi::Embeddings::Models::AllMpnetBaseV2,
      DiscourseAi::Embeddings::Models::TextEmbeddingAda002,
    ]
    strategies = [DiscourseAi::Embeddings::Strategies::Truncation]

    models.each do |model|
      strategies.each do |strategy|
        table_name = "ai_topic_embeddings_#{model.id}_#{strategy.id}".to_sym

        create_table table_name, id: false do |t|
          t.integer :topic_id, null: false
          t.integer :model_version, null: false
          t.integer :strategy_version, null: false
          t.text :digest, null: false
          t.column :embeddings, "vector(#{model.dimensions})", null: false
          t.timestamps

          t.index :topic_id, unique: true
        end
      end
    end
  end
end
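
With the two models and single strategy above, this migration creates
ai_topic_embeddings_1_1 (all-mpnet-base-v2 with Truncation, a vector(768)
column) and ai_topic_embeddings_2_1 (text-embedding-ada-002 with Truncation,
vector(1536)), using the ids and dimensions each class reports.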


@@ -0,0 +1,63 @@
# frozen_string_literal: true

class MigrateEmbeddingsFromDedicatedDatabase < ActiveRecord::Migration[7.0]
  def up
    return unless SiteSetting.ai_embeddings_enabled
    return unless SiteSetting.ai_embeddings_pg_connection_string.present?

    models = [
      DiscourseAi::Embeddings::Models::AllMpnetBaseV2,
      DiscourseAi::Embeddings::Models::TextEmbeddingAda002,
    ]
    strategies = [DiscourseAi::Embeddings::Strategies::Truncation]

    models.each do |model|
      strategies.each do |strategy|
        new_table_name = "ai_topic_embeddings_#{model.id}_#{strategy.id}"
        old_table_name = "topic_embeddings_#{model.name.underscore}"

        begin
          row_count =
            DiscourseAi::Database::Connection
              .db
              .query_single("SELECT COUNT(*) FROM #{old_table_name}")
              .first

          if row_count > 0
            puts "Migrating #{row_count} embeddings from #{old_table_name} to #{new_table_name}"

            last_topic_id = 0
            loop do
              batch = DiscourseAi::Database::Connection.db.query(<<-SQL)
                SELECT topic_id, embedding
                FROM #{old_table_name}
                WHERE topic_id > #{last_topic_id}
                ORDER BY topic_id ASC
                LIMIT 50
              SQL

              break if batch.empty?

              DB.exec(<<-SQL)
                INSERT INTO #{new_table_name} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
                VALUES #{batch.map { |r| "(#{r.topic_id}, 0, 0, '', '#{r.embedding}', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)" }.join(", ")}
                ON CONFLICT (topic_id)
                DO NOTHING
              SQL

              last_topic_id = batch.last.topic_id
            end
          end
        rescue PG::Error => e
          Rails.logger.error(
            "Error #{e} migrating embeddings from #{old_table_name} to #{new_table_name}",
          )
        end
      end
    end
  end

  def down
    # no-op
  end
end


@@ -4,8 +4,11 @@ module DiscourseAi
   module Embeddings
     class EntryPoint
       def load_files
-        require_relative "model"
-        require_relative "topic"
+        require_relative "models/base"
+        require_relative "models/all_mpnet_base_v2"
+        require_relative "models/text_embedding_ada_002"
+        require_relative "strategies/truncation"
+        require_relative "manager"
         require_relative "jobs/regular/generate_embeddings"
         require_relative "semantic_related"
         require_relative "semantic_search"


@@ -11,7 +11,7 @@ module Jobs
       post = topic.first_post
       return if post.nil? || post.raw.blank?

-      DiscourseAi::Embeddings::Topic.new.generate_and_store_embeddings_for(topic)
+      DiscourseAi::Embeddings::Manager.new(topic).generate!
     end
   end
 end


@@ -0,0 +1,64 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    class Manager
      attr_reader :target, :model, :strategy

      def initialize(target)
        @target = target
        @model =
          DiscourseAi::Embeddings::Models::Base.subclasses.find do
            _1.name == SiteSetting.ai_embeddings_model
          end
        @strategy = DiscourseAi::Embeddings::Strategies::Truncation.new(@target, @model)
      end

      def generate!
        @strategy.process!

        # TODO bail here if we already have an embedding with matching version and digest

        @embeddings = @model.generate_embeddings(@strategy.processed_target)

        persist!
      end

      def persist!
        begin
          DB.exec(
            <<~SQL,
              INSERT INTO ai_topic_embeddings_#{table_suffix} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
              VALUES (:topic_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
              ON CONFLICT (topic_id)
              DO UPDATE SET
                model_version = :model_version,
                strategy_version = :strategy_version,
                digest = :digest,
                embeddings = '[:embeddings]',
                updated_at = CURRENT_TIMESTAMP
            SQL
            topic_id: @target.id,
            model_version: @model.version,
            strategy_version: @strategy.version,
            digest: @strategy.digest,
            embeddings: @embeddings,
          )
        rescue PG::Error => e
          Rails.logger.error(
            "Error #{e} persisting embedding for topic #{@target.id} and model #{@model.name}",
          )
        end
      end

      def table_suffix
        "#{@model.id}_#{@strategy.id}"
      end

      def topic_embeddings_table
        "ai_topic_embeddings_#{table_suffix}"
      end
    end
  end
end
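
Callers never pick a table or model themselves: as the generate_embeddings job
hunk above shows, the whole pipeline is
DiscourseAi::Embeddings::Manager.new(topic).generate!, with the model resolved
from the ai_embeddings_model site setting.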


@@ -1,94 +0,0 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    class Model
      AVAILABLE_MODELS_TEMPLATES = {
        "all-mpnet-base-v2" => [768, 384, %i[dot cosine euclidean], %i[symmetric], "discourse"],
        "all-distilroberta-v1" => [768, 512, %i[dot cosine euclidean], %i[symmetric], "discourse"],
        "multi-qa-mpnet-base-dot-v1" => [768, 512, %i[dot], %i[symmetric], "discourse"],
        "paraphrase-multilingual-mpnet-base-v2" => [
          768,
          128,
          %i[cosine],
          %i[symmetric],
          "discourse",
        ],
        "msmarco-distilbert-base-tas-b" => [768, 512, %i[dot], %i[asymmetric], "discourse"],
        "msmarco-distilbert-base-v4" => [768, 512, %i[cosine], %i[asymmetric], "discourse"],
        "instructor-xl" => [768, 512, %i[cosine], %i[symmetric asymmetric], "discourse"],
        "text-embedding-ada-002" => [1536, 2048, %i[cosine], %i[symmetric asymmetric], "openai"],
      }

      SEARCH_FUNCTION_TO_PG_INDEX = {
        dot: "vector_ip_ops",
        cosine: "vector_cosine_ops",
        euclidean: "vector_l2_ops",
      }

      SEARCH_FUNCTION_TO_PG_FUNCTION = { dot: "<#>", cosine: "<=>", euclidean: "<->" }

      class << self
        def instantiate(model_name)
          new(model_name, *AVAILABLE_MODELS_TEMPLATES[model_name])
        end

        def enabled_models
          SiteSetting
            .ai_embeddings_models
            .split("|")
            .map { |model_name| instantiate(model_name.strip) }
        end
      end

      def initialize(name, dimensions, max_sequence_lenght, functions, type, provider)
        @name = name
        @dimensions = dimensions
        @max_sequence_lenght = max_sequence_lenght
        @functions = functions
        @type = type
        @provider = provider
      end

      def generate_embedding(input)
        send("#{provider}_embeddings", input)
      end

      def pg_function
        SEARCH_FUNCTION_TO_PG_FUNCTION[functions.first]
      end

      def pg_index
        SEARCH_FUNCTION_TO_PG_INDEX[functions.first]
      end

      attr_reader :name, :dimensions, :max_sequence_lenght, :functions, :type, :provider

      private

      def discourse_embeddings(input)
        truncated_input = DiscourseAi::Tokenizer::BertTokenizer.truncate(input, max_sequence_lenght)

        if name.start_with?("instructor")
          truncated_input = [
            [SiteSetting.ai_embeddings_semantic_related_instruction, truncated_input],
          ]
        end

        DiscourseAi::Inference::DiscourseClassifier.perform!(
          "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
          name.to_s,
          truncated_input,
          SiteSetting.ai_embeddings_discourse_service_api_key,
        )
      end

      def openai_embeddings(input)
        truncated_input =
          DiscourseAi::Tokenizer::OpenAiTokenizer.truncate(input, max_sequence_lenght)

        response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(truncated_input)
        response[:data].first[:embedding]
      end
    end
  end
end


@@ -0,0 +1,52 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    module Models
      class AllMpnetBaseV2 < Base
        class << self
          def id
            1
          end

          def version
            1
          end

          def name
            "all-mpnet-base-v2"
          end

          def dimensions
            768
          end

          def max_sequence_length
            384
          end

          def pg_function
            "<#>"
          end

          def pg_index_type
            "vector_ip_ops"
          end

          def generate_embeddings(text)
            DiscourseAi::Inference::DiscourseClassifier.perform!(
              "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
              name,
              text,
              SiteSetting.ai_embeddings_discourse_service_api_key,
            )
          end

          def tokenizer
            DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
          end
        end
      end
    end
  end
end


@@ -0,0 +1,10 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    module Models
      class Base
      end
    end
  end
end


@@ -0,0 +1,48 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    module Models
      class TextEmbeddingAda002 < Base
        class << self
          def id
            2
          end

          def version
            1
          end

          def name
            "text-embedding-ada-002"
          end

          def dimensions
            1536
          end

          def max_sequence_length
            8191
          end

          def pg_function
            "<=>"
          end

          def pg_index_type
            "vector_cosine_ops"
          end

          def generate_embeddings(text)
            response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text)
            response[:data].first[:embedding]
          end

          def tokenizer
            DiscourseAi::Tokenizer::OpenAiTokenizer
          end
        end
      end
    end
  end
end


@@ -3,69 +3,112 @@
 module DiscourseAi
   module Embeddings
     class SemanticRelated
-      def self.semantic_suggested_key(topic_id)
-        "semantic-suggested-topic-#{topic_id}"
-      end
-
-      def self.build_semantic_suggested_key(topic_id)
-        "build-semantic-suggested-topic-#{topic_id}"
-      end
-
-      def self.clear_cache_for(topic)
-        Discourse.cache.delete(semantic_suggested_key(topic.id))
-        Discourse.redis.del(build_semantic_suggested_key(topic.id))
-      end
-
-      def self.candidates_for(topic)
-        return ::Topic.none if SiteSetting.ai_embeddings_semantic_related_topics < 1
-
-        cache_for =
-          case topic.created_at
-          when 6.hour.ago..Time.now
-            15.minutes
-          when 1.day.ago..6.hour.ago
-            1.hour
-          else
-            1.day
-          end
-
-        model =
-          DiscourseAi::Embeddings::Model.instantiate(
-            SiteSetting.ai_embeddings_semantic_related_model,
-          )
-
-        begin
-          candidate_ids =
-            Discourse
-              .cache
-              .fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
-                DiscourseAi::Embeddings::Topic.new.symmetric_semantic_search(model, topic)
-              end
-        rescue MissingEmbeddingError
-          # avoid a flood of jobs when visiting topic
-          if Discourse.redis.set(
-               build_semantic_suggested_key(topic.id),
-               "queued",
-               ex: 15.minutes.to_i,
-               nx: true,
-             )
-            Jobs.enqueue(:generate_embeddings, topic_id: topic.id)
-          end
-          return ::Topic.none
-        end
-
-        topic_list = ::Topic.visible.listable_topics.secured
-
-        unless SiteSetting.ai_embeddings_semantic_related_include_closed_topics
-          topic_list = topic_list.where(closed: false)
-        end
-
-        topic_list
-          .where("id <> ?", topic.id)
-          .where(id: candidate_ids)
-          # array_position forces the order of the topics to be preserved
-          .order("array_position(ARRAY#{candidate_ids}, id)")
-          .limit(SiteSetting.ai_embeddings_semantic_related_topics)
-      end
+      MissingEmbeddingError = Class.new(StandardError)
+
+      class << self
+        def semantic_suggested_key(topic_id)
+          "semantic-suggested-topic-#{topic_id}"
+        end
+
+        def build_semantic_suggested_key(topic_id)
+          "build-semantic-suggested-topic-#{topic_id}"
+        end
+
+        def clear_cache_for(topic)
+          Discourse.cache.delete(semantic_suggested_key(topic.id))
+          Discourse.redis.del(build_semantic_suggested_key(topic.id))
+        end
+
+        def candidates_for(topic)
+          return ::Topic.none if SiteSetting.ai_embeddings_semantic_related_topics < 1
+
+          manager = DiscourseAi::Embeddings::Manager.new(topic)
+
+          cache_for =
+            case topic.created_at
+            when 6.hour.ago..Time.now
+              15.minutes
+            when 1.day.ago..6.hour.ago
+              1.hour
+            else
+              1.day
+            end
+
+          begin
+            candidate_ids =
+              Discourse
+                .cache
+                .fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
+                  self.symmetric_semantic_search(manager)
+                end
+          rescue MissingEmbeddingError
+            # avoid a flood of jobs when visiting topic
+            if Discourse.redis.set(
+                 build_semantic_suggested_key(topic.id),
+                 "queued",
+                 ex: 15.minutes.to_i,
+                 nx: true,
+               )
+              Jobs.enqueue(:generate_embeddings, topic_id: topic.id)
+            end
+            return ::Topic.none
+          end
+
+          topic_list = ::Topic.visible.listable_topics.secured
+
+          unless SiteSetting.ai_embeddings_semantic_related_include_closed_topics
+            topic_list = topic_list.where(closed: false)
+          end
+
+          topic_list
+            .where("id <> ?", topic.id)
+            .where(id: candidate_ids)
+            # array_position forces the order of the topics to be preserved
+            .order("array_position(ARRAY#{candidate_ids}, id)")
+            .limit(SiteSetting.ai_embeddings_semantic_related_topics)
+        end
+
+        def symmetric_semantic_search(manager)
+          topic = manager.target
+          candidate_ids = self.query_symmetric_embeddings(manager)
+
+          # Happens when the topic doesn't have any embeddings
+          # I'd rather not use Exceptions to control the flow, so this should be refactored soon
+          if candidate_ids.empty? || !candidate_ids.include?(topic.id)
+            raise MissingEmbeddingError, "No embeddings found for topic #{topic.id}"
+          end
+
+          candidate_ids
+        end
+
+        def query_symmetric_embeddings(manager)
+          topic = manager.target
+          model = manager.model
+          table = manager.topic_embeddings_table
+          begin
+            DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
+              SELECT
+                topic_id
+              FROM
+                #{table}
+              ORDER BY
+                embeddings #{model.pg_function} (
+                  SELECT
+                    embeddings
+                  FROM
+                    #{table}
+                  WHERE
+                    topic_id = :topic_id
+                  LIMIT 1
+                )
+              LIMIT 100
+            SQL
+          rescue PG::Error => e
+            Rails.logger.error(
+              "Error #{e} querying embeddings for topic #{topic.id} and model #{model.name}",
+            )
+            raise MissingEmbeddingError
+          end
+        end
+      end
     end
   end
 end


@@ -3,17 +3,17 @@
 module DiscourseAi
   module Embeddings
     class SemanticSearch
-      def initialize(guardian, model)
+      def initialize(guardian)
         @guardian = guardian
-        @model = model
+        @manager = DiscourseAi::Embeddings::Manager.new(nil)
+        @model = @manager.model
       end

       def search_for_topics(query, page = 1)
         limit = Search.per_filter + 1
         offset = (page - 1) * Search.per_filter

-        candidate_ids =
-          DiscourseAi::Embeddings::Topic.new.asymmetric_semantic_search(model, query, limit, offset)
+        candidate_ids = asymmetric_semantic_search(query, limit, offset)

         ::Post
           .where(post_type: ::Topic.visible_post_types(guardian.user))

@@ -23,6 +23,34 @@ module DiscourseAi
           .order("array_position(ARRAY#{candidate_ids}, topic_id)")
       end

+      def asymmetric_semantic_search(query, limit, offset)
+        embedding = model.generate_embeddings(query)
+        table = @manager.topic_embeddings_table
+
+        begin
+          candidate_ids =
+            DB.query(<<~SQL, query_embedding: embedding, limit: limit, offset: offset).map(
+              SELECT
+                topic_id
+              FROM
+                #{table}
+              ORDER BY
+                embeddings #{@model.pg_function} '[:query_embedding]'
+              LIMIT :limit
+              OFFSET :offset
+            SQL
+              &:topic_id
+            )
+        rescue PG::Error => e
+          Rails.logger.error(
+            "Error #{e} querying embeddings for model #{model.name} and search #{query}",
+          )
+          raise MissingEmbeddingError
+        end
+
+        candidate_ids
+      end
+
       private

       attr_reader :model, :guardian


@@ -0,0 +1,82 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    module Strategies
      class Truncation
        attr_reader :processed_target, :digest

        def self.id
          1
        end

        def id
          self.class.id
        end

        def version
          1
        end

        def initialize(target, model)
          @model = model
          @target = target
          @tokenizer = @model.tokenizer
          @max_length = @model.max_sequence_length
          @processed_target = +""
        end

        # Need a better name for this method
        def process!
          case @target
          when Topic
            topic_truncation(@target)
          when Post
            post_truncation(@target)
          else
            raise ArgumentError, "Invalid target type"
          end

          @digest = OpenSSL::Digest::SHA1.hexdigest(@processed_target)
        end

        def topic_truncation(topic)
          t = @processed_target

          t << topic.title
          t << "\n\n"
          t << topic.category.name
          if SiteSetting.tagging_enabled
            t << "\n\n"
            t << topic.tags.pluck(:name).join(", ")
          end
          t << "\n\n"

          topic.posts.each do |post|
            t << post.raw
            break if @tokenizer.size(t) >= @max_length
            t << "\n\n"
          end

          @tokenizer.truncate(t, @max_length)
        end

        def post_truncation(post)
          t = processed_target

          t << post.topic.title
          t << "\n\n"
          t << post.topic.category.name
          if SiteSetting.tagging_enabled
            t << "\n\n"
            t << post.topic.tags.pluck(:name).join(", ")
          end
          t << "\n\n"
          t << post.raw

          @tokenizer.truncate(t, @max_length)
        end
      end
    end
  end
end
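
Concretely, for a topic titled "this is my new topic title" in the default
category with tagging enabled and no tags, topic_truncation yields
"This is my new topic title\n\nUncategorized\n\n\n\nthis is the new content for my topic\n\n"
(title, category name, an empty tag segment, then the post raws), which is
exactly the input the Manager spec below stubs, before hashing and truncating
to the model's max_sequence_length.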


@@ -1,110 +0,0 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    MissingEmbeddingError = Class.new(StandardError)

    class Topic
      def generate_and_store_embeddings_for(topic)
        return unless SiteSetting.ai_embeddings_enabled
        return if topic.blank? || topic.first_post.blank?

        enabled_models = DiscourseAi::Embeddings::Model.enabled_models
        return if enabled_models.empty?

        enabled_models.each do |model|
          embedding = model.generate_embedding(topic.first_post.raw)
          persist_embedding(topic, model, embedding) if embedding
        end
      end

      def symmetric_semantic_search(model, topic)
        candidate_ids = query_symmetric_embeddings(model, topic)

        # Happens when the topic doesn't have any embeddings
        # I'd rather not use Exceptions to control the flow, so this should be refactored soon
        if candidate_ids.empty? || !candidate_ids.include?(topic.id)
          raise MissingEmbeddingError, "No embeddings found for topic #{topic.id}"
        end

        candidate_ids
      end

      def asymmetric_semantic_search(model, query, limit, offset)
        embedding = model.generate_embedding(query)

        begin
          candidate_ids =
            DiscourseAi::Database::Connection
              .db
              .query(<<~SQL, query_embedding: embedding, limit: limit, offset: offset)
                SELECT
                  topic_id
                FROM
                  topic_embeddings_#{model.name.underscore}
                ORDER BY
                  embedding #{model.pg_function} '[:query_embedding]'
                LIMIT :limit
                OFFSET :offset
              SQL
              .map(&:topic_id)
        rescue PG::Error => e
          Rails.logger.error(
            "Error #{e} querying embeddings for topic #{topic.id} and model #{model.name}",
          )
          raise MissingEmbeddingError
        end

        candidate_ids
      end

      private

      def query_symmetric_embeddings(model, topic)
        begin
          DiscourseAi::Database::Connection.db.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
            SELECT
              topic_id
            FROM
              topic_embeddings_#{model.name.underscore}
            ORDER BY
              embedding #{model.pg_function} (
                SELECT
                  embedding
                FROM
                  topic_embeddings_#{model.name.underscore}
                WHERE
                  topic_id = :topic_id
                LIMIT 1
              )
            LIMIT 100
          SQL
        rescue PG::Error => e
          Rails.logger.error(
            "Error #{e} querying embeddings for topic #{topic.id} and model #{model.name}",
          )
          raise MissingEmbeddingError
        end
      end

      def persist_embedding(topic, model, embedding)
        begin
          DiscourseAi::Database::Connection.db.exec(
            <<~SQL,
              INSERT INTO topic_embeddings_#{model.name.underscore} (topic_id, embedding)
              VALUES (:topic_id, '[:embedding]')
              ON CONFLICT (topic_id)
              DO UPDATE SET embedding = '[:embedding]'
            SQL
            topic_id: topic.id,
            embedding: embedding,
          )
        rescue PG::Error => e
          Rails.logger.error(
            "Error #{e} persisting embedding for topic #{topic.id} and model #{model.name}",
          )
        end
      end
    end
  end
end


@@ -45,6 +45,13 @@ module DiscourseAi
     end
   end

+  class AllMpnetBaseV2Tokenizer < BasicTokenizer
+    def self.tokenizer
+      @@tokenizer ||=
+        Tokenizers.from_file("./plugins/discourse-ai/tokenizers/all-mpnet-base-v2.json")
+    end
+  end
+
   class OpenAiTokenizer < BasicTokenizer
     class << self
       def tokenizer


@@ -7,8 +7,8 @@
 # url: https://meta.discourse.org/t/discourse-ai/259214
 # required_version: 2.7.0

-gem "tokenizers", "0.3.2", platform: RUBY_PLATFORM
-gem "tiktoken_ruby", "0.0.5", platform: RUBY_PLATFORM
+gem "tokenizers", "0.3.2"
+gem "tiktoken_ruby", "0.0.5"

 enabled_site_setting :discourse_ai_enabled

@@ -65,4 +65,13 @@ after_initialize do
   on(:reviewable_transitioned_to) do |new_status, reviewable|
     ModelAccuracy.adjust_model_accuracy(new_status, reviewable)
   end
+
+  if DB.query_single("SELECT 1 FROM pg_available_extensions WHERE name = 'vector';").empty?
+    STDERR.puts "------------------------------DISCOURSE AI ERROR----------------------------------"
+    STDERR.puts " Discourse AI requires the pgvector extension on the PostgreSQL database."
+    STDERR.puts " Run a `./launcher rebuild app` to fix it on a standard install."
+    STDERR.puts " Alternatively, you can remove Discourse AI to rebuild."
+    STDERR.puts "------------------------------DISCOURSE AI ERROR----------------------------------"
+    exit 1
+  end
 end

File diff suppressed because one or more lines are too long


@@ -0,0 +1,44 @@
# frozen_string_literal: true

require_relative "../../support/embeddings_generation_stubs"

RSpec.describe DiscourseAi::Embeddings::Manager do
  let(:user) { Fabricate(:user) }
  let(:expected_embedding) do
    JSON.parse(
      File.read("#{Rails.root}/plugins/discourse-ai/spec/fixtures/embeddings/embedding.txt"),
    )
  end
  let(:discourse_model) { "all-mpnet-base-v2" }

  before do
    SiteSetting.discourse_ai_enabled = true
    SiteSetting.ai_embeddings_enabled = true
    SiteSetting.ai_embeddings_model = "all-mpnet-base-v2"
    SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
    Jobs.run_immediately!
  end

  it "generates embeddings for new topics automatically" do
    pc =
      PostCreator.new(
        user,
        raw: "this is the new content for my topic",
        title: "this is my new topic title",
      )
    input =
      "This is my new topic title\n\nUncategorized\n\n\n\nthis is the new content for my topic\n\n"
    EmbeddingsGenerationStubs.discourse_service(discourse_model, input, expected_embedding)

    post = pc.create
    manager = DiscourseAi::Embeddings::Manager.new(post.topic)

    embeddings =
      DB.query_single(
        "SELECT embeddings FROM #{manager.topic_embeddings_table} WHERE topic_id = #{post.topic.id}",
      ).first

    expect(embeddings.split(",")[1].to_f).to be_within(0.0001).of(expected_embedding[1])
    expect(embeddings.split(",")[13].to_f).to be_within(0.0001).of(expected_embedding[13])
    expect(embeddings.split(",")[135].to_f).to be_within(0.0001).of(expected_embedding[135])
  end
end


@@ -1,36 +0,0 @@
# frozen_string_literal: true

require_relative "../../../support/embeddings_generation_stubs"

RSpec.describe DiscourseAi::Embeddings::Model do
  describe "#generate_embedding" do
    let(:input) { "test" }
    let(:expected_embedding) { [0.0038493, 0.482001] }

    context "when the model uses the discourse service to create embeddings" do
      before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }

      let(:discourse_model) { "all-mpnet-base-v2" }

      it "returns an embedding for a given string" do
        EmbeddingsGenerationStubs.discourse_service(discourse_model, input, expected_embedding)

        embedding = described_class.instantiate(discourse_model).generate_embedding(input)

        expect(embedding).to contain_exactly(*expected_embedding)
      end
    end

    context "when the model uses OpenAI to create embeddings" do
      let(:openai_model) { "text-embedding-ada-002" }

      it "returns an embedding for a given string" do
        EmbeddingsGenerationStubs.openai_service(openai_model, input, expected_embedding)

        embedding = described_class.instantiate(openai_model).generate_embedding(input)

        expect(embedding).to contain_exactly(*expected_embedding)
      end
    end
  end
end


@@ -0,0 +1,24 @@
# frozen_string_literal: true

require_relative "../../../../support/embeddings_generation_stubs"

RSpec.describe DiscourseAi::Embeddings::Models::AllMpnetBaseV2 do
  describe "#generate_embeddings" do
    let(:input) { "test" }
    let(:expected_embedding) { [0.0038493, 0.482001] }

    context "when the model uses the discourse service to create embeddings" do
      before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }

      let(:discourse_model) { "all-mpnet-base-v2" }

      it "returns an embedding for a given string" do
        EmbeddingsGenerationStubs.discourse_service(discourse_model, input, expected_embedding)

        embedding = described_class.generate_embeddings(input)

        expect(embedding).to contain_exactly(*expected_embedding)
      end
    end
  end
end


@@ -0,0 +1,22 @@
# frozen_string_literal: true

require_relative "../../../../support/embeddings_generation_stubs"

RSpec.describe DiscourseAi::Embeddings::Models::TextEmbeddingAda002 do
  describe "#generate_embeddings" do
    let(:input) { "test" }
    let(:expected_embedding) { [0.0038493, 0.482001] }

    context "when the model uses OpenAI to create embeddings" do
      let(:openai_model) { "text-embedding-ada-002" }

      it "returns an embedding for a given string" do
        EmbeddingsGenerationStubs.openai_service(openai_model, input, expected_embedding)

        embedding = described_class.generate_embeddings(input)

        expect(embedding).to contain_exactly(*expected_embedding)
      end
    end
  end
end


@@ -24,13 +24,6 @@ describe DiscourseAi::Embeddings::SemanticRelated do
   end

   it "queues job only once per 15 minutes" do
-    # sadly we need to mock a DB connection to return nothing
-    DiscourseAi::Embeddings::Topic
-      .any_instance
-      .expects(:query_symmetric_embeddings)
-      .returns([])
-      .twice
-
     results = nil

     expect_enqueued_with(job: :generate_embeddings, args: { topic_id: topic.id }) do

@@ -49,10 +42,9 @@ describe DiscourseAi::Embeddings::SemanticRelated do
   context "when embeddings exist" do
     before do
       Discourse.cache.clear
-      DiscourseAi::Embeddings::Topic
-        .any_instance
-        .expects(:symmetric_semantic_search)
-        .returns(Topic.unscoped.order(id: :desc).limit(100).pluck(:id))
+      DiscourseAi::Embeddings::SemanticRelated.expects(:symmetric_semantic_search).returns(
+        Topic.unscoped.order(id: :desc).limit(100).pluck(:id),
+      )
     end

     after { Discourse.cache.clear }


@@ -3,15 +3,13 @@
 RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
   fab!(:post) { Fabricate(:post) }
   fab!(:user) { Fabricate(:user) }
-  let(:model_name) { "msmarco-distilbert-base-v4" }
-  let(:query) { "test_query" }
-  let(:model) { DiscourseAi::Embeddings::Model.instantiate(model_name) }
-  let(:subject) { described_class.new(Guardian.new(user), model) }
+  let(:query) { "test_query" }
+  let(:subject) { described_class.new(Guardian.new(user)) }

   describe "#search_for_topics" do
     def stub_candidate_ids(candidate_ids)
-      DiscourseAi::Embeddings::Topic
+      DiscourseAi::Embeddings::SemanticSearch
         .any_instance
         .expects(:asymmetric_semantic_search)
         .returns(candidate_ids)


@@ -19,10 +19,9 @@ describe ::TopicsController do
   context "when a user is logged on" do
     it "includes related topics in payload when configured" do
-      DiscourseAi::Embeddings::Topic
-        .any_instance
-        .expects(:symmetric_semantic_search)
-        .returns([topic1.id, topic2.id, topic3.id])
+      DiscourseAi::Embeddings::SemanticRelated.expects(:symmetric_semantic_search).returns(
+        [topic1.id, topic2.id, topic3.id],
+      )

       get("#{topic.relative_url}.json")

       expect(response.status).to eq(200)

@@ -39,16 +38,5 @@ describe ::TopicsController do
       expect(json["suggested_topics"].length).to eq(0)
       expect(json["related_topics"].length).to eq(2)
     end
-
-    it "excludes embeddings when the database is offline" do
-      DiscourseAi::Database::Connection.stubs(:db).raises(PG::ConnectionBad)
-      get "#{topic.relative_url}.json"
-
-      expect(response.status).to eq(200)
-      json = response.parsed_body
-
-      expect(json["suggested_topics"].length).not_to eq(0)
-      expect(json["related_topics"].length).to eq(0)
-    end
   end
 end


@@ -5,3 +5,7 @@ Licensed under Apache License

 ## claude-v1-tokenization.json
 Licensed under MIT License
+
+## all-mpnet-base-v2.json
+
+Licensed under Apache License

File diff suppressed because one or more lines are too long