FEATURE: Embeddings to main db (#99)
* FEATURE: Embeddings to main db This commit moves our embeddings store from an external configurable PostgreSQL instance back into the main database. This is done to simplify the setup. There is a migration that will try to import the external embeddings into the main DB if it is configured and there are rows. It removes support from embeddings models that aren't all_mpnet_base_v2 or OpenAI text_embedding_ada_002. However it will now be easier to add new models. It also now takes into account: - topic title - topic category - topic tags - replies (as much as the model allows) We introduce an interface so we can eventually support multiple strategies for handling long topics. This PR severely damages the semantic search performance, but this is a temporary until we can get adapt HyDE to make semantic search use the same embeddings we have for semantic related with good performance. Here we also have some ground work to add post level embeddings, but this will be added in a future PR. Please note that this PR will also block Discourse from booting / updating if this plugin is installed and the pgvector extension isn't available on the PostgreSQL instance Discourse uses.
This commit is contained in:
parent
9d10a152b9
commit
5e3f4e1b78
|
@ -19,13 +19,8 @@ module DiscourseAi
|
||||||
use_pg_headlines_for_excerpt: false,
|
use_pg_headlines_for_excerpt: false,
|
||||||
)
|
)
|
||||||
|
|
||||||
model =
|
|
||||||
DiscourseAi::Embeddings::Model.instantiate(
|
|
||||||
SiteSetting.ai_embeddings_semantic_search_model,
|
|
||||||
)
|
|
||||||
|
|
||||||
DiscourseAi::Embeddings::SemanticSearch
|
DiscourseAi::Embeddings::SemanticSearch
|
||||||
.new(guardian, model)
|
.new(guardian)
|
||||||
.search_for_topics(query, page)
|
.search_for_topics(query, page)
|
||||||
.each { |topic_post| grouped_results.add(topic_post) }
|
.each { |topic_post| grouped_results.add(topic_post) }
|
||||||
|
|
||||||
|
|
|
@ -143,39 +143,14 @@ plugins:
|
||||||
ai_embeddings_discourse_service_api_key:
|
ai_embeddings_discourse_service_api_key:
|
||||||
default: ""
|
default: ""
|
||||||
secret: true
|
secret: true
|
||||||
ai_embeddings_models:
|
ai_embeddings_model:
|
||||||
type: list
|
type: enum
|
||||||
list_type: compact
|
list_type: compact
|
||||||
default: ""
|
default: "all-mpnet-base-v2"
|
||||||
allow_any: false
|
allow_any: false
|
||||||
choices:
|
choices:
|
||||||
- all-mpnet-base-v2
|
- all-mpnet-base-v2
|
||||||
- all-distilroberta-v1
|
|
||||||
- multi-qa-mpnet-base-dot-v1
|
|
||||||
- paraphrase-multilingual-mpnet-base-v2
|
|
||||||
- msmarco-distilbert-base-v4
|
|
||||||
- msmarco-distilbert-base-tas-b
|
|
||||||
- text-embedding-ada-002
|
- text-embedding-ada-002
|
||||||
ai_embeddings_semantic_related_model:
|
|
||||||
type: enum
|
|
||||||
default: all-mpnet-base-v2
|
|
||||||
choices:
|
|
||||||
- all-mpnet-base-v2
|
|
||||||
- text-embedding-ada-002
|
|
||||||
- all-distilroberta-v1
|
|
||||||
- multi-qa-mpnet-base-dot-v1
|
|
||||||
- paraphrase-multilingual-mpnet-base-v2
|
|
||||||
ai_embeddings_semantic_search_model:
|
|
||||||
type: enum
|
|
||||||
default: msmarco-distilbert-base-v4
|
|
||||||
choices:
|
|
||||||
- msmarco-distilbert-base-v4
|
|
||||||
- msmarco-distilbert-base-tas-b
|
|
||||||
- text-embedding-ada-002
|
|
||||||
ai_embeddings_semantic_related_instruction:
|
|
||||||
default: "Represent the Discourse topic for retrieving relevant topics:"
|
|
||||||
hidden: true
|
|
||||||
client: false
|
|
||||||
ai_embeddings_generate_for_pms: false
|
ai_embeddings_generate_for_pms: false
|
||||||
ai_embeddings_semantic_related_topics_enabled: false
|
ai_embeddings_semantic_related_topics_enabled: false
|
||||||
ai_embeddings_semantic_related_topics: 5
|
ai_embeddings_semantic_related_topics: 5
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
class EnablePgVectorExtension < ActiveRecord::Migration[7.0]
|
||||||
|
def change
|
||||||
|
enable_extension :vector
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,28 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
class CreateAiTopicEmbeddingsTable < ActiveRecord::Migration[7.0]
|
||||||
|
def change
|
||||||
|
models = [
|
||||||
|
DiscourseAi::Embeddings::Models::AllMpnetBaseV2,
|
||||||
|
DiscourseAi::Embeddings::Models::TextEmbeddingAda002,
|
||||||
|
]
|
||||||
|
strategies = [DiscourseAi::Embeddings::Strategies::Truncation]
|
||||||
|
|
||||||
|
models.each do |model|
|
||||||
|
strategies.each do |strategy|
|
||||||
|
table_name = "ai_topic_embeddings_#{model.id}_#{strategy.id}".to_sym
|
||||||
|
|
||||||
|
create_table table_name, id: false do |t|
|
||||||
|
t.integer :topic_id, null: false
|
||||||
|
t.integer :model_version, null: false
|
||||||
|
t.integer :strategy_version, null: false
|
||||||
|
t.text :digest, null: false
|
||||||
|
t.column :embeddings, "vector(#{model.dimensions})", null: false
|
||||||
|
t.timestamps
|
||||||
|
|
||||||
|
t.index :topic_id, unique: true
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,63 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
class MigrateEmbeddingsFromDedicatedDatabase < ActiveRecord::Migration[7.0]
|
||||||
|
def up
|
||||||
|
return unless SiteSetting.ai_embeddings_enabled
|
||||||
|
return unless SiteSetting.ai_embeddings_pg_connection_string.present?
|
||||||
|
|
||||||
|
models = [
|
||||||
|
DiscourseAi::Embeddings::Models::AllMpnetBaseV2,
|
||||||
|
DiscourseAi::Embeddings::Models::TextEmbeddingAda002,
|
||||||
|
]
|
||||||
|
strategies = [DiscourseAi::Embeddings::Strategies::Truncation]
|
||||||
|
|
||||||
|
models.each do |model|
|
||||||
|
strategies.each do |strategy|
|
||||||
|
new_table_name = "ai_topic_embeddings_#{model.id}_#{strategy.id}"
|
||||||
|
old_table_name = "topic_embeddings_#{model.name.underscore}"
|
||||||
|
|
||||||
|
begin
|
||||||
|
row_count =
|
||||||
|
DiscourseAi::Database::Connection
|
||||||
|
.db
|
||||||
|
.query_single("SELECT COUNT(*) FROM #{old_table_name}")
|
||||||
|
.first
|
||||||
|
|
||||||
|
if row_count > 0
|
||||||
|
puts "Migrating #{row_count} embeddings from #{old_table_name} to #{new_table_name}"
|
||||||
|
|
||||||
|
last_topic_id = 0
|
||||||
|
|
||||||
|
loop do
|
||||||
|
batch = DiscourseAi::Database::Connection.db.query(<<-SQL)
|
||||||
|
SELECT topic_id, embedding
|
||||||
|
FROM #{old_table_name}
|
||||||
|
WHERE topic_id > #{last_topic_id}
|
||||||
|
ORDER BY topic_id ASC
|
||||||
|
LIMIT 50
|
||||||
|
SQL
|
||||||
|
break if batch.empty?
|
||||||
|
|
||||||
|
DB.exec(<<-SQL)
|
||||||
|
INSERT INTO #{new_table_name} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
|
||||||
|
VALUES #{batch.map { |r| "(#{r.topic_id}, 0, 0, '', '#{r.embedding}', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)" }.join(", ")}
|
||||||
|
ON CONFLICT (topic_id)
|
||||||
|
DO NOTHING
|
||||||
|
SQL
|
||||||
|
|
||||||
|
last_topic_id = batch.last.topic_id
|
||||||
|
end
|
||||||
|
end
|
||||||
|
rescue PG::Error => e
|
||||||
|
Rails.logger.error(
|
||||||
|
"Error #{e} migrating embeddings from #{old_table_name} to #{new_table_name}",
|
||||||
|
)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def down
|
||||||
|
# no-op
|
||||||
|
end
|
||||||
|
end
|
|
@ -4,8 +4,11 @@ module DiscourseAi
|
||||||
module Embeddings
|
module Embeddings
|
||||||
class EntryPoint
|
class EntryPoint
|
||||||
def load_files
|
def load_files
|
||||||
require_relative "model"
|
require_relative "models/base"
|
||||||
require_relative "topic"
|
require_relative "models/all_mpnet_base_v2"
|
||||||
|
require_relative "models/text_embedding_ada_002"
|
||||||
|
require_relative "strategies/truncation"
|
||||||
|
require_relative "manager"
|
||||||
require_relative "jobs/regular/generate_embeddings"
|
require_relative "jobs/regular/generate_embeddings"
|
||||||
require_relative "semantic_related"
|
require_relative "semantic_related"
|
||||||
require_relative "semantic_search"
|
require_relative "semantic_search"
|
||||||
|
|
|
@ -11,7 +11,7 @@ module Jobs
|
||||||
post = topic.first_post
|
post = topic.first_post
|
||||||
return if post.nil? || post.raw.blank?
|
return if post.nil? || post.raw.blank?
|
||||||
|
|
||||||
DiscourseAi::Embeddings::Topic.new.generate_and_store_embeddings_for(topic)
|
DiscourseAi::Embeddings::Manager.new(topic).generate!
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -0,0 +1,64 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Embeddings
|
||||||
|
class Manager
|
||||||
|
attr_reader :target, :model, :strategy
|
||||||
|
|
||||||
|
def initialize(target)
|
||||||
|
@target = target
|
||||||
|
@model =
|
||||||
|
DiscourseAi::Embeddings::Models::Base.subclasses.find do
|
||||||
|
_1.name == SiteSetting.ai_embeddings_model
|
||||||
|
end
|
||||||
|
@strategy = DiscourseAi::Embeddings::Strategies::Truncation.new(@target, @model)
|
||||||
|
end
|
||||||
|
|
||||||
|
def generate!
|
||||||
|
@strategy.process!
|
||||||
|
|
||||||
|
# TODO bail here if we already have an embedding with matching version and digest
|
||||||
|
|
||||||
|
@embeddings = @model.generate_embeddings(@strategy.processed_target)
|
||||||
|
|
||||||
|
persist!
|
||||||
|
end
|
||||||
|
|
||||||
|
def persist!
|
||||||
|
begin
|
||||||
|
DB.exec(
|
||||||
|
<<~SQL,
|
||||||
|
INSERT INTO ai_topic_embeddings_#{table_suffix} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
|
||||||
|
VALUES (:topic_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||||
|
ON CONFLICT (topic_id)
|
||||||
|
DO UPDATE SET
|
||||||
|
model_version = :model_version,
|
||||||
|
strategy_version = :strategy_version,
|
||||||
|
digest = :digest,
|
||||||
|
embeddings = '[:embeddings]',
|
||||||
|
updated_at = CURRENT_TIMESTAMP
|
||||||
|
|
||||||
|
SQL
|
||||||
|
topic_id: @target.id,
|
||||||
|
model_version: @model.version,
|
||||||
|
strategy_version: @strategy.version,
|
||||||
|
digest: @strategy.digest,
|
||||||
|
embeddings: @embeddings,
|
||||||
|
)
|
||||||
|
rescue PG::Error => e
|
||||||
|
Rails.logger.error(
|
||||||
|
"Error #{e} persisting embedding for topic #{topic.id} and model #{model.name}",
|
||||||
|
)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def table_suffix
|
||||||
|
"#{@model.id}_#{@strategy.id}"
|
||||||
|
end
|
||||||
|
|
||||||
|
def topic_embeddings_table
|
||||||
|
"ai_topic_embeddings_#{table_suffix}"
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -1,94 +0,0 @@
|
||||||
# frozen_string_literal: true
|
|
||||||
|
|
||||||
module DiscourseAi
|
|
||||||
module Embeddings
|
|
||||||
class Model
|
|
||||||
AVAILABLE_MODELS_TEMPLATES = {
|
|
||||||
"all-mpnet-base-v2" => [768, 384, %i[dot cosine euclidean], %i[symmetric], "discourse"],
|
|
||||||
"all-distilroberta-v1" => [768, 512, %i[dot cosine euclidean], %i[symmetric], "discourse"],
|
|
||||||
"multi-qa-mpnet-base-dot-v1" => [768, 512, %i[dot], %i[symmetric], "discourse"],
|
|
||||||
"paraphrase-multilingual-mpnet-base-v2" => [
|
|
||||||
768,
|
|
||||||
128,
|
|
||||||
%i[cosine],
|
|
||||||
%i[symmetric],
|
|
||||||
"discourse",
|
|
||||||
],
|
|
||||||
"msmarco-distilbert-base-tas-b" => [768, 512, %i[dot], %i[asymmetric], "discourse"],
|
|
||||||
"msmarco-distilbert-base-v4" => [768, 512, %i[cosine], %i[asymmetric], "discourse"],
|
|
||||||
"instructor-xl" => [768, 512, %i[cosine], %i[symmetric asymmetric], "discourse"],
|
|
||||||
"text-embedding-ada-002" => [1536, 2048, %i[cosine], %i[symmetric asymmetric], "openai"],
|
|
||||||
}
|
|
||||||
|
|
||||||
SEARCH_FUNCTION_TO_PG_INDEX = {
|
|
||||||
dot: "vector_ip_ops",
|
|
||||||
cosine: "vector_cosine_ops",
|
|
||||||
euclidean: "vector_l2_ops",
|
|
||||||
}
|
|
||||||
|
|
||||||
SEARCH_FUNCTION_TO_PG_FUNCTION = { dot: "<#>", cosine: "<=>", euclidean: "<->" }
|
|
||||||
|
|
||||||
class << self
|
|
||||||
def instantiate(model_name)
|
|
||||||
new(model_name, *AVAILABLE_MODELS_TEMPLATES[model_name])
|
|
||||||
end
|
|
||||||
|
|
||||||
def enabled_models
|
|
||||||
SiteSetting
|
|
||||||
.ai_embeddings_models
|
|
||||||
.split("|")
|
|
||||||
.map { |model_name| instantiate(model_name.strip) }
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
def initialize(name, dimensions, max_sequence_lenght, functions, type, provider)
|
|
||||||
@name = name
|
|
||||||
@dimensions = dimensions
|
|
||||||
@max_sequence_lenght = max_sequence_lenght
|
|
||||||
@functions = functions
|
|
||||||
@type = type
|
|
||||||
@provider = provider
|
|
||||||
end
|
|
||||||
|
|
||||||
def generate_embedding(input)
|
|
||||||
send("#{provider}_embeddings", input)
|
|
||||||
end
|
|
||||||
|
|
||||||
def pg_function
|
|
||||||
SEARCH_FUNCTION_TO_PG_FUNCTION[functions.first]
|
|
||||||
end
|
|
||||||
|
|
||||||
def pg_index
|
|
||||||
SEARCH_FUNCTION_TO_PG_INDEX[functions.first]
|
|
||||||
end
|
|
||||||
|
|
||||||
attr_reader :name, :dimensions, :max_sequence_lenght, :functions, :type, :provider
|
|
||||||
|
|
||||||
private
|
|
||||||
|
|
||||||
def discourse_embeddings(input)
|
|
||||||
truncated_input = DiscourseAi::Tokenizer::BertTokenizer.truncate(input, max_sequence_lenght)
|
|
||||||
|
|
||||||
if name.start_with?("instructor")
|
|
||||||
truncated_input = [
|
|
||||||
[SiteSetting.ai_embeddings_semantic_related_instruction, truncated_input],
|
|
||||||
]
|
|
||||||
end
|
|
||||||
|
|
||||||
DiscourseAi::Inference::DiscourseClassifier.perform!(
|
|
||||||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
|
||||||
name.to_s,
|
|
||||||
truncated_input,
|
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_key,
|
|
||||||
)
|
|
||||||
end
|
|
||||||
|
|
||||||
def openai_embeddings(input)
|
|
||||||
truncated_input =
|
|
||||||
DiscourseAi::Tokenizer::OpenAiTokenizer.truncate(input, max_sequence_lenght)
|
|
||||||
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(truncated_input)
|
|
||||||
response[:data].first[:embedding]
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
|
@ -0,0 +1,52 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Embeddings
|
||||||
|
module Models
|
||||||
|
class AllMpnetBaseV2 < Base
|
||||||
|
class << self
|
||||||
|
def id
|
||||||
|
1
|
||||||
|
end
|
||||||
|
|
||||||
|
def version
|
||||||
|
1
|
||||||
|
end
|
||||||
|
|
||||||
|
def name
|
||||||
|
"all-mpnet-base-v2"
|
||||||
|
end
|
||||||
|
|
||||||
|
def dimensions
|
||||||
|
768
|
||||||
|
end
|
||||||
|
|
||||||
|
def max_sequence_length
|
||||||
|
384
|
||||||
|
end
|
||||||
|
|
||||||
|
def pg_function
|
||||||
|
"<#>"
|
||||||
|
end
|
||||||
|
|
||||||
|
def pg_index_type
|
||||||
|
"vector_ip_ops"
|
||||||
|
end
|
||||||
|
|
||||||
|
def generate_embeddings(text)
|
||||||
|
DiscourseAi::Inference::DiscourseClassifier.perform!(
|
||||||
|
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
||||||
|
name,
|
||||||
|
text,
|
||||||
|
SiteSetting.ai_embeddings_discourse_service_api_key,
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
|
def tokenizer
|
||||||
|
DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,10 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Embeddings
|
||||||
|
module Models
|
||||||
|
class Base
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,48 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Embeddings
|
||||||
|
module Models
|
||||||
|
class TextEmbeddingAda002 < Base
|
||||||
|
class << self
|
||||||
|
def id
|
||||||
|
2
|
||||||
|
end
|
||||||
|
|
||||||
|
def version
|
||||||
|
1
|
||||||
|
end
|
||||||
|
|
||||||
|
def name
|
||||||
|
"text-embedding-ada-002"
|
||||||
|
end
|
||||||
|
|
||||||
|
def dimensions
|
||||||
|
1536
|
||||||
|
end
|
||||||
|
|
||||||
|
def max_sequence_length
|
||||||
|
8191
|
||||||
|
end
|
||||||
|
|
||||||
|
def pg_function
|
||||||
|
"<=>"
|
||||||
|
end
|
||||||
|
|
||||||
|
def pg_index_type
|
||||||
|
"vector_cosine_ops"
|
||||||
|
end
|
||||||
|
|
||||||
|
def generate_embeddings(text)
|
||||||
|
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text)
|
||||||
|
response[:data].first[:embedding]
|
||||||
|
end
|
||||||
|
|
||||||
|
def tokenizer
|
||||||
|
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -3,69 +3,112 @@
|
||||||
module DiscourseAi
|
module DiscourseAi
|
||||||
module Embeddings
|
module Embeddings
|
||||||
class SemanticRelated
|
class SemanticRelated
|
||||||
def self.semantic_suggested_key(topic_id)
|
MissingEmbeddingError = Class.new(StandardError)
|
||||||
"semantic-suggested-topic-#{topic_id}"
|
|
||||||
end
|
|
||||||
|
|
||||||
def self.build_semantic_suggested_key(topic_id)
|
class << self
|
||||||
"build-semantic-suggested-topic-#{topic_id}"
|
def semantic_suggested_key(topic_id)
|
||||||
end
|
"semantic-suggested-topic-#{topic_id}"
|
||||||
|
|
||||||
def self.clear_cache_for(topic)
|
|
||||||
Discourse.cache.delete(semantic_suggested_key(topic.id))
|
|
||||||
Discourse.redis.del(build_semantic_suggested_key(topic.id))
|
|
||||||
end
|
|
||||||
|
|
||||||
def self.candidates_for(topic)
|
|
||||||
return ::Topic.none if SiteSetting.ai_embeddings_semantic_related_topics < 1
|
|
||||||
|
|
||||||
cache_for =
|
|
||||||
case topic.created_at
|
|
||||||
when 6.hour.ago..Time.now
|
|
||||||
15.minutes
|
|
||||||
when 1.day.ago..6.hour.ago
|
|
||||||
1.hour
|
|
||||||
else
|
|
||||||
1.day
|
|
||||||
end
|
|
||||||
|
|
||||||
model =
|
|
||||||
DiscourseAi::Embeddings::Model.instantiate(
|
|
||||||
SiteSetting.ai_embeddings_semantic_related_model,
|
|
||||||
)
|
|
||||||
|
|
||||||
begin
|
|
||||||
candidate_ids =
|
|
||||||
Discourse
|
|
||||||
.cache
|
|
||||||
.fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
|
|
||||||
DiscourseAi::Embeddings::Topic.new.symmetric_semantic_search(model, topic)
|
|
||||||
end
|
|
||||||
rescue MissingEmbeddingError
|
|
||||||
# avoid a flood of jobs when visiting topic
|
|
||||||
if Discourse.redis.set(
|
|
||||||
build_semantic_suggested_key(topic.id),
|
|
||||||
"queued",
|
|
||||||
ex: 15.minutes.to_i,
|
|
||||||
nx: true,
|
|
||||||
)
|
|
||||||
Jobs.enqueue(:generate_embeddings, topic_id: topic.id)
|
|
||||||
end
|
|
||||||
return ::Topic.none
|
|
||||||
end
|
end
|
||||||
|
|
||||||
topic_list = ::Topic.visible.listable_topics.secured
|
def build_semantic_suggested_key(topic_id)
|
||||||
|
"build-semantic-suggested-topic-#{topic_id}"
|
||||||
unless SiteSetting.ai_embeddings_semantic_related_include_closed_topics
|
|
||||||
topic_list = topic_list.where(closed: false)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
topic_list
|
def clear_cache_for(topic)
|
||||||
.where("id <> ?", topic.id)
|
Discourse.cache.delete(semantic_suggested_key(topic.id))
|
||||||
.where(id: candidate_ids)
|
Discourse.redis.del(build_semantic_suggested_key(topic.id))
|
||||||
# array_position forces the order of the topics to be preserved
|
end
|
||||||
.order("array_position(ARRAY#{candidate_ids}, id)")
|
|
||||||
.limit(SiteSetting.ai_embeddings_semantic_related_topics)
|
def candidates_for(topic)
|
||||||
|
return ::Topic.none if SiteSetting.ai_embeddings_semantic_related_topics < 1
|
||||||
|
|
||||||
|
manager = DiscourseAi::Embeddings::Manager.new(topic)
|
||||||
|
cache_for =
|
||||||
|
case topic.created_at
|
||||||
|
when 6.hour.ago..Time.now
|
||||||
|
15.minutes
|
||||||
|
when 1.day.ago..6.hour.ago
|
||||||
|
1.hour
|
||||||
|
else
|
||||||
|
1.day
|
||||||
|
end
|
||||||
|
|
||||||
|
begin
|
||||||
|
candidate_ids =
|
||||||
|
Discourse
|
||||||
|
.cache
|
||||||
|
.fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
|
||||||
|
self.symmetric_semantic_search(manager)
|
||||||
|
end
|
||||||
|
rescue MissingEmbeddingError
|
||||||
|
# avoid a flood of jobs when visiting topic
|
||||||
|
if Discourse.redis.set(
|
||||||
|
build_semantic_suggested_key(topic.id),
|
||||||
|
"queued",
|
||||||
|
ex: 15.minutes.to_i,
|
||||||
|
nx: true,
|
||||||
|
)
|
||||||
|
Jobs.enqueue(:generate_embeddings, topic_id: topic.id)
|
||||||
|
end
|
||||||
|
return ::Topic.none
|
||||||
|
end
|
||||||
|
|
||||||
|
topic_list = ::Topic.visible.listable_topics.secured
|
||||||
|
|
||||||
|
unless SiteSetting.ai_embeddings_semantic_related_include_closed_topics
|
||||||
|
topic_list = topic_list.where(closed: false)
|
||||||
|
end
|
||||||
|
|
||||||
|
topic_list
|
||||||
|
.where("id <> ?", topic.id)
|
||||||
|
.where(id: candidate_ids)
|
||||||
|
# array_position forces the order of the topics to be preserved
|
||||||
|
.order("array_position(ARRAY#{candidate_ids}, id)")
|
||||||
|
.limit(SiteSetting.ai_embeddings_semantic_related_topics)
|
||||||
|
end
|
||||||
|
|
||||||
|
def symmetric_semantic_search(manager)
|
||||||
|
topic = manager.target
|
||||||
|
candidate_ids = self.query_symmetric_embeddings(manager)
|
||||||
|
|
||||||
|
# Happens when the topic doesn't have any embeddings
|
||||||
|
# I'd rather not use Exceptions to control the flow, so this should be refactored soon
|
||||||
|
if candidate_ids.empty? || !candidate_ids.include?(topic.id)
|
||||||
|
raise MissingEmbeddingError, "No embeddings found for topic #{topic.id}"
|
||||||
|
end
|
||||||
|
|
||||||
|
candidate_ids
|
||||||
|
end
|
||||||
|
|
||||||
|
def query_symmetric_embeddings(manager)
|
||||||
|
topic = manager.target
|
||||||
|
model = manager.model
|
||||||
|
table = manager.topic_embeddings_table
|
||||||
|
begin
|
||||||
|
DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
|
||||||
|
SELECT
|
||||||
|
topic_id
|
||||||
|
FROM
|
||||||
|
#{table}
|
||||||
|
ORDER BY
|
||||||
|
embeddings #{model.pg_function} (
|
||||||
|
SELECT
|
||||||
|
embeddings
|
||||||
|
FROM
|
||||||
|
#{table}
|
||||||
|
WHERE
|
||||||
|
topic_id = :topic_id
|
||||||
|
LIMIT 1
|
||||||
|
)
|
||||||
|
LIMIT 100
|
||||||
|
SQL
|
||||||
|
rescue PG::Error => e
|
||||||
|
Rails.logger.error(
|
||||||
|
"Error #{e} querying embeddings for topic #{topic.id} and model #{model.name}",
|
||||||
|
)
|
||||||
|
raise MissingEmbeddingError
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -3,17 +3,17 @@
|
||||||
module DiscourseAi
|
module DiscourseAi
|
||||||
module Embeddings
|
module Embeddings
|
||||||
class SemanticSearch
|
class SemanticSearch
|
||||||
def initialize(guardian, model)
|
def initialize(guardian)
|
||||||
@guardian = guardian
|
@guardian = guardian
|
||||||
@model = model
|
@manager = DiscourseAi::Embeddings::Manager.new(nil)
|
||||||
|
@model = @manager.model
|
||||||
end
|
end
|
||||||
|
|
||||||
def search_for_topics(query, page = 1)
|
def search_for_topics(query, page = 1)
|
||||||
limit = Search.per_filter + 1
|
limit = Search.per_filter + 1
|
||||||
offset = (page - 1) * Search.per_filter
|
offset = (page - 1) * Search.per_filter
|
||||||
|
|
||||||
candidate_ids =
|
candidate_ids = asymmetric_semantic_search(query, limit, offset)
|
||||||
DiscourseAi::Embeddings::Topic.new.asymmetric_semantic_search(model, query, limit, offset)
|
|
||||||
|
|
||||||
::Post
|
::Post
|
||||||
.where(post_type: ::Topic.visible_post_types(guardian.user))
|
.where(post_type: ::Topic.visible_post_types(guardian.user))
|
||||||
|
@ -23,6 +23,34 @@ module DiscourseAi
|
||||||
.order("array_position(ARRAY#{candidate_ids}, topic_id)")
|
.order("array_position(ARRAY#{candidate_ids}, topic_id)")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def asymmetric_semantic_search(query, limit, offset)
|
||||||
|
embedding = model.generate_embeddings(query)
|
||||||
|
table = @manager.topic_embeddings_table
|
||||||
|
|
||||||
|
begin
|
||||||
|
candidate_ids =
|
||||||
|
DB.query(<<~SQL, query_embedding: embedding, limit: limit, offset: offset).map(
|
||||||
|
SELECT
|
||||||
|
topic_id
|
||||||
|
FROM
|
||||||
|
#{table}
|
||||||
|
ORDER BY
|
||||||
|
embeddings #{@model.pg_function} '[:query_embedding]'
|
||||||
|
LIMIT :limit
|
||||||
|
OFFSET :offset
|
||||||
|
SQL
|
||||||
|
&:topic_id
|
||||||
|
)
|
||||||
|
rescue PG::Error => e
|
||||||
|
Rails.logger.error(
|
||||||
|
"Error #{e} querying embeddings for model #{model.name} and search #{query}",
|
||||||
|
)
|
||||||
|
raise MissingEmbeddingError
|
||||||
|
end
|
||||||
|
|
||||||
|
candidate_ids
|
||||||
|
end
|
||||||
|
|
||||||
private
|
private
|
||||||
|
|
||||||
attr_reader :model, :guardian
|
attr_reader :model, :guardian
|
||||||
|
|
|
@ -0,0 +1,82 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Embeddings
|
||||||
|
module Strategies
|
||||||
|
class Truncation
|
||||||
|
attr_reader :processed_target, :digest
|
||||||
|
|
||||||
|
def self.id
|
||||||
|
1
|
||||||
|
end
|
||||||
|
|
||||||
|
def id
|
||||||
|
self.class.id
|
||||||
|
end
|
||||||
|
|
||||||
|
def version
|
||||||
|
1
|
||||||
|
end
|
||||||
|
|
||||||
|
def initialize(target, model)
|
||||||
|
@model = model
|
||||||
|
@target = target
|
||||||
|
@tokenizer = @model.tokenizer
|
||||||
|
@max_length = @model.max_sequence_length
|
||||||
|
@processed_target = +""
|
||||||
|
end
|
||||||
|
|
||||||
|
# Need a better name for this method
|
||||||
|
def process!
|
||||||
|
case @target
|
||||||
|
when Topic
|
||||||
|
topic_truncation(@target)
|
||||||
|
when Post
|
||||||
|
post_truncation(@target)
|
||||||
|
else
|
||||||
|
raise ArgumentError, "Invalid target type"
|
||||||
|
end
|
||||||
|
|
||||||
|
@digest = OpenSSL::Digest::SHA1.hexdigest(@processed_target)
|
||||||
|
end
|
||||||
|
|
||||||
|
def topic_truncation(topic)
|
||||||
|
t = @processed_target
|
||||||
|
|
||||||
|
t << topic.title
|
||||||
|
t << "\n\n"
|
||||||
|
t << topic.category.name
|
||||||
|
if SiteSetting.tagging_enabled
|
||||||
|
t << "\n\n"
|
||||||
|
t << topic.tags.pluck(:name).join(", ")
|
||||||
|
end
|
||||||
|
t << "\n\n"
|
||||||
|
|
||||||
|
topic.posts.each do |post|
|
||||||
|
t << post.raw
|
||||||
|
break if @tokenizer.size(t) >= @max_length
|
||||||
|
t << "\n\n"
|
||||||
|
end
|
||||||
|
|
||||||
|
@tokenizer.truncate(t, @max_length)
|
||||||
|
end
|
||||||
|
|
||||||
|
def post_truncation(post)
|
||||||
|
t = processed_target
|
||||||
|
|
||||||
|
t << post.topic.title
|
||||||
|
t << "\n\n"
|
||||||
|
t << post.topic.category.name
|
||||||
|
if SiteSetting.tagging_enabled
|
||||||
|
t << "\n\n"
|
||||||
|
t << post.topic.tags.pluck(:name).join(", ")
|
||||||
|
end
|
||||||
|
t << "\n\n"
|
||||||
|
t << post.raw
|
||||||
|
|
||||||
|
@tokenizer.truncate(t, @max_length)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -1,110 +0,0 @@
|
||||||
# frozen_string_literal: true
|
|
||||||
|
|
||||||
module DiscourseAi
|
|
||||||
module Embeddings
|
|
||||||
MissingEmbeddingError = Class.new(StandardError)
|
|
||||||
|
|
||||||
class Topic
|
|
||||||
def generate_and_store_embeddings_for(topic)
|
|
||||||
return unless SiteSetting.ai_embeddings_enabled
|
|
||||||
return if topic.blank? || topic.first_post.blank?
|
|
||||||
|
|
||||||
enabled_models = DiscourseAi::Embeddings::Model.enabled_models
|
|
||||||
return if enabled_models.empty?
|
|
||||||
|
|
||||||
enabled_models.each do |model|
|
|
||||||
embedding = model.generate_embedding(topic.first_post.raw)
|
|
||||||
persist_embedding(topic, model, embedding) if embedding
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
def symmetric_semantic_search(model, topic)
|
|
||||||
candidate_ids = query_symmetric_embeddings(model, topic)
|
|
||||||
|
|
||||||
# Happens when the topic doesn't have any embeddings
|
|
||||||
# I'd rather not use Exceptions to control the flow, so this should be refactored soon
|
|
||||||
if candidate_ids.empty? || !candidate_ids.include?(topic.id)
|
|
||||||
raise MissingEmbeddingError, "No embeddings found for topic #{topic.id}"
|
|
||||||
end
|
|
||||||
|
|
||||||
candidate_ids
|
|
||||||
end
|
|
||||||
|
|
||||||
def asymmetric_semantic_search(model, query, limit, offset)
|
|
||||||
embedding = model.generate_embedding(query)
|
|
||||||
|
|
||||||
begin
|
|
||||||
candidate_ids =
|
|
||||||
DiscourseAi::Database::Connection
|
|
||||||
.db
|
|
||||||
.query(<<~SQL, query_embedding: embedding, limit: limit, offset: offset)
|
|
||||||
SELECT
|
|
||||||
topic_id
|
|
||||||
FROM
|
|
||||||
topic_embeddings_#{model.name.underscore}
|
|
||||||
ORDER BY
|
|
||||||
embedding #{model.pg_function} '[:query_embedding]'
|
|
||||||
LIMIT :limit
|
|
||||||
OFFSET :offset
|
|
||||||
SQL
|
|
||||||
.map(&:topic_id)
|
|
||||||
rescue PG::Error => e
|
|
||||||
Rails.logger.error(
|
|
||||||
"Error #{e} querying embeddings for topic #{topic.id} and model #{model.name}",
|
|
||||||
)
|
|
||||||
raise MissingEmbeddingError
|
|
||||||
end
|
|
||||||
|
|
||||||
candidate_ids
|
|
||||||
end
|
|
||||||
|
|
||||||
private
|
|
||||||
|
|
||||||
def query_symmetric_embeddings(model, topic)
|
|
||||||
begin
|
|
||||||
DiscourseAi::Database::Connection.db.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
|
|
||||||
SELECT
|
|
||||||
topic_id
|
|
||||||
FROM
|
|
||||||
topic_embeddings_#{model.name.underscore}
|
|
||||||
ORDER BY
|
|
||||||
embedding #{model.pg_function} (
|
|
||||||
SELECT
|
|
||||||
embedding
|
|
||||||
FROM
|
|
||||||
topic_embeddings_#{model.name.underscore}
|
|
||||||
WHERE
|
|
||||||
topic_id = :topic_id
|
|
||||||
LIMIT 1
|
|
||||||
)
|
|
||||||
LIMIT 100
|
|
||||||
SQL
|
|
||||||
rescue PG::Error => e
|
|
||||||
Rails.logger.error(
|
|
||||||
"Error #{e} querying embeddings for topic #{topic.id} and model #{model.name}",
|
|
||||||
)
|
|
||||||
raise MissingEmbeddingError
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
def persist_embedding(topic, model, embedding)
|
|
||||||
begin
|
|
||||||
DiscourseAi::Database::Connection.db.exec(
|
|
||||||
<<~SQL,
|
|
||||||
INSERT INTO topic_embeddings_#{model.name.underscore} (topic_id, embedding)
|
|
||||||
VALUES (:topic_id, '[:embedding]')
|
|
||||||
ON CONFLICT (topic_id)
|
|
||||||
DO UPDATE SET embedding = '[:embedding]'
|
|
||||||
SQL
|
|
||||||
topic_id: topic.id,
|
|
||||||
embedding: embedding,
|
|
||||||
)
|
|
||||||
rescue PG::Error => e
|
|
||||||
Rails.logger.error(
|
|
||||||
"Error #{e} persisting embedding for topic #{topic.id} and model #{model.name}",
|
|
||||||
)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
|
@ -45,6 +45,13 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
class AllMpnetBaseV2Tokenizer < BasicTokenizer
|
||||||
|
def self.tokenizer
|
||||||
|
@@tokenizer ||=
|
||||||
|
Tokenizers.from_file("./plugins/discourse-ai/tokenizers/all-mpnet-base-v2.json")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
class OpenAiTokenizer < BasicTokenizer
|
class OpenAiTokenizer < BasicTokenizer
|
||||||
class << self
|
class << self
|
||||||
def tokenizer
|
def tokenizer
|
||||||
|
|
13
plugin.rb
13
plugin.rb
|
@ -7,8 +7,8 @@
|
||||||
# url: https://meta.discourse.org/t/discourse-ai/259214
|
# url: https://meta.discourse.org/t/discourse-ai/259214
|
||||||
# required_version: 2.7.0
|
# required_version: 2.7.0
|
||||||
|
|
||||||
gem "tokenizers", "0.3.2", platform: RUBY_PLATFORM
|
gem "tokenizers", "0.3.2"
|
||||||
gem "tiktoken_ruby", "0.0.5", platform: RUBY_PLATFORM
|
gem "tiktoken_ruby", "0.0.5"
|
||||||
|
|
||||||
enabled_site_setting :discourse_ai_enabled
|
enabled_site_setting :discourse_ai_enabled
|
||||||
|
|
||||||
|
@ -65,4 +65,13 @@ after_initialize do
|
||||||
on(:reviewable_transitioned_to) do |new_status, reviewable|
|
on(:reviewable_transitioned_to) do |new_status, reviewable|
|
||||||
ModelAccuracy.adjust_model_accuracy(new_status, reviewable)
|
ModelAccuracy.adjust_model_accuracy(new_status, reviewable)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
if DB.query_single("SELECT 1 FROM pg_available_extensions WHERE name = 'vector';").empty?
|
||||||
|
STDERR.puts "------------------------------DISCOURSE AI ERROR----------------------------------"
|
||||||
|
STDERR.puts " Discourse AI requires the pgvector extension on the PostgreSQL database."
|
||||||
|
STDERR.puts " Run a `./launcher rebuild app` to fix it on a standard install."
|
||||||
|
STDERR.puts " Alternatively, you can remove Discourse AI to rebuild."
|
||||||
|
STDERR.puts "------------------------------DISCOURSE AI ERROR----------------------------------"
|
||||||
|
exit 1
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,44 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
require_relative "../../support/embeddings_generation_stubs"
|
||||||
|
|
||||||
|
RSpec.describe DiscourseAi::Embeddings::Manager do
|
||||||
|
let(:user) { Fabricate(:user) }
|
||||||
|
let(:expected_embedding) do
|
||||||
|
JSON.parse(
|
||||||
|
File.read("#{Rails.root}/plugins/discourse-ai/spec/fixtures/embeddings/embedding.txt"),
|
||||||
|
)
|
||||||
|
end
|
||||||
|
let(:discourse_model) { "all-mpnet-base-v2" }
|
||||||
|
|
||||||
|
before do
|
||||||
|
SiteSetting.discourse_ai_enabled = true
|
||||||
|
SiteSetting.ai_embeddings_enabled = true
|
||||||
|
SiteSetting.ai_embeddings_model = "all-mpnet-base-v2"
|
||||||
|
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||||
|
Jobs.run_immediately!
|
||||||
|
end
|
||||||
|
|
||||||
|
it "generates embeddings for new topics automatically" do
|
||||||
|
pc =
|
||||||
|
PostCreator.new(
|
||||||
|
user,
|
||||||
|
raw: "this is the new content for my topic",
|
||||||
|
title: "this is my new topic title",
|
||||||
|
)
|
||||||
|
input =
|
||||||
|
"This is my new topic title\n\nUncategorized\n\n\n\nthis is the new content for my topic\n\n"
|
||||||
|
EmbeddingsGenerationStubs.discourse_service(discourse_model, input, expected_embedding)
|
||||||
|
post = pc.create
|
||||||
|
manager = DiscourseAi::Embeddings::Manager.new(post.topic)
|
||||||
|
|
||||||
|
embeddings =
|
||||||
|
DB.query_single(
|
||||||
|
"SELECT embeddings FROM #{manager.topic_embeddings_table} WHERE topic_id = #{post.topic.id}",
|
||||||
|
).first
|
||||||
|
|
||||||
|
expect(embeddings.split(",")[1].to_f).to be_within(0.0001).of(expected_embedding[1])
|
||||||
|
expect(embeddings.split(",")[13].to_f).to be_within(0.0001).of(expected_embedding[13])
|
||||||
|
expect(embeddings.split(",")[135].to_f).to be_within(0.0001).of(expected_embedding[135])
|
||||||
|
end
|
||||||
|
end
|
|
@ -1,36 +0,0 @@
|
||||||
# frozen_string_literal: true
|
|
||||||
|
|
||||||
require_relative "../../../support/embeddings_generation_stubs"
|
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Embeddings::Model do
|
|
||||||
describe "#generate_embedding" do
|
|
||||||
let(:input) { "test" }
|
|
||||||
let(:expected_embedding) { [0.0038493, 0.482001] }
|
|
||||||
|
|
||||||
context "when the model uses the discourse service to create embeddings" do
|
|
||||||
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
|
|
||||||
|
|
||||||
let(:discourse_model) { "all-mpnet-base-v2" }
|
|
||||||
|
|
||||||
it "returns an embedding for a given string" do
|
|
||||||
EmbeddingsGenerationStubs.discourse_service(discourse_model, input, expected_embedding)
|
|
||||||
|
|
||||||
embedding = described_class.instantiate(discourse_model).generate_embedding(input)
|
|
||||||
|
|
||||||
expect(embedding).to contain_exactly(*expected_embedding)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
context "when the model uses OpenAI to create embeddings" do
|
|
||||||
let(:openai_model) { "text-embedding-ada-002" }
|
|
||||||
|
|
||||||
it "returns an embedding for a given string" do
|
|
||||||
EmbeddingsGenerationStubs.openai_service(openai_model, input, expected_embedding)
|
|
||||||
|
|
||||||
embedding = described_class.instantiate(openai_model).generate_embedding(input)
|
|
||||||
|
|
||||||
expect(embedding).to contain_exactly(*expected_embedding)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
require_relative "../../../../support/embeddings_generation_stubs"
|
||||||
|
|
||||||
|
RSpec.describe DiscourseAi::Embeddings::Models::AllMpnetBaseV2 do
|
||||||
|
describe "#generate_embeddings" do
|
||||||
|
let(:input) { "test" }
|
||||||
|
let(:expected_embedding) { [0.0038493, 0.482001] }
|
||||||
|
|
||||||
|
context "when the model uses the discourse service to create embeddings" do
|
||||||
|
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
|
||||||
|
|
||||||
|
let(:discourse_model) { "all-mpnet-base-v2" }
|
||||||
|
|
||||||
|
it "returns an embedding for a given string" do
|
||||||
|
EmbeddingsGenerationStubs.discourse_service(discourse_model, input, expected_embedding)
|
||||||
|
|
||||||
|
embedding = described_class.generate_embeddings(input)
|
||||||
|
|
||||||
|
expect(embedding).to contain_exactly(*expected_embedding)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,22 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
require_relative "../../../../support/embeddings_generation_stubs"
|
||||||
|
|
||||||
|
RSpec.describe DiscourseAi::Embeddings::Models::TextEmbeddingAda002 do
|
||||||
|
describe "#generate_embeddings" do
|
||||||
|
let(:input) { "test" }
|
||||||
|
let(:expected_embedding) { [0.0038493, 0.482001] }
|
||||||
|
|
||||||
|
context "when the model uses OpenAI to create embeddings" do
|
||||||
|
let(:openai_model) { "text-embedding-ada-002" }
|
||||||
|
|
||||||
|
it "returns an embedding for a given string" do
|
||||||
|
EmbeddingsGenerationStubs.openai_service(openai_model, input, expected_embedding)
|
||||||
|
|
||||||
|
embedding = described_class.generate_embeddings(input)
|
||||||
|
|
||||||
|
expect(embedding).to contain_exactly(*expected_embedding)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -24,13 +24,6 @@ describe DiscourseAi::Embeddings::SemanticRelated do
|
||||||
end
|
end
|
||||||
|
|
||||||
it "queues job only once per 15 minutes" do
|
it "queues job only once per 15 minutes" do
|
||||||
# sadly we need to mock a DB connection to return nothing
|
|
||||||
DiscourseAi::Embeddings::Topic
|
|
||||||
.any_instance
|
|
||||||
.expects(:query_symmetric_embeddings)
|
|
||||||
.returns([])
|
|
||||||
.twice
|
|
||||||
|
|
||||||
results = nil
|
results = nil
|
||||||
|
|
||||||
expect_enqueued_with(job: :generate_embeddings, args: { topic_id: topic.id }) do
|
expect_enqueued_with(job: :generate_embeddings, args: { topic_id: topic.id }) do
|
||||||
|
@ -49,10 +42,9 @@ describe DiscourseAi::Embeddings::SemanticRelated do
|
||||||
context "when embeddings exist" do
|
context "when embeddings exist" do
|
||||||
before do
|
before do
|
||||||
Discourse.cache.clear
|
Discourse.cache.clear
|
||||||
DiscourseAi::Embeddings::Topic
|
DiscourseAi::Embeddings::SemanticRelated.expects(:symmetric_semantic_search).returns(
|
||||||
.any_instance
|
Topic.unscoped.order(id: :desc).limit(100).pluck(:id),
|
||||||
.expects(:symmetric_semantic_search)
|
)
|
||||||
.returns(Topic.unscoped.order(id: :desc).limit(100).pluck(:id))
|
|
||||||
end
|
end
|
||||||
|
|
||||||
after { Discourse.cache.clear }
|
after { Discourse.cache.clear }
|
||||||
|
|
|
@ -3,15 +3,13 @@
|
||||||
RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
|
RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
|
||||||
fab!(:post) { Fabricate(:post) }
|
fab!(:post) { Fabricate(:post) }
|
||||||
fab!(:user) { Fabricate(:user) }
|
fab!(:user) { Fabricate(:user) }
|
||||||
let(:model_name) { "msmarco-distilbert-base-v4" }
|
|
||||||
let(:query) { "test_query" }
|
|
||||||
|
|
||||||
let(:model) { DiscourseAi::Embeddings::Model.instantiate(model_name) }
|
let(:query) { "test_query" }
|
||||||
let(:subject) { described_class.new(Guardian.new(user), model) }
|
let(:subject) { described_class.new(Guardian.new(user)) }
|
||||||
|
|
||||||
describe "#search_for_topics" do
|
describe "#search_for_topics" do
|
||||||
def stub_candidate_ids(candidate_ids)
|
def stub_candidate_ids(candidate_ids)
|
||||||
DiscourseAi::Embeddings::Topic
|
DiscourseAi::Embeddings::SemanticSearch
|
||||||
.any_instance
|
.any_instance
|
||||||
.expects(:asymmetric_semantic_search)
|
.expects(:asymmetric_semantic_search)
|
||||||
.returns(candidate_ids)
|
.returns(candidate_ids)
|
||||||
|
|
|
@ -19,10 +19,9 @@ describe ::TopicsController do
|
||||||
|
|
||||||
context "when a user is logged on" do
|
context "when a user is logged on" do
|
||||||
it "includes related topics in payload when configured" do
|
it "includes related topics in payload when configured" do
|
||||||
DiscourseAi::Embeddings::Topic
|
DiscourseAi::Embeddings::SemanticRelated.expects(:symmetric_semantic_search).returns(
|
||||||
.any_instance
|
[topic1.id, topic2.id, topic3.id],
|
||||||
.expects(:symmetric_semantic_search)
|
)
|
||||||
.returns([topic1.id, topic2.id, topic3.id])
|
|
||||||
|
|
||||||
get("#{topic.relative_url}.json")
|
get("#{topic.relative_url}.json")
|
||||||
expect(response.status).to eq(200)
|
expect(response.status).to eq(200)
|
||||||
|
@ -39,16 +38,5 @@ describe ::TopicsController do
|
||||||
expect(json["suggested_topics"].length).to eq(0)
|
expect(json["suggested_topics"].length).to eq(0)
|
||||||
expect(json["related_topics"].length).to eq(2)
|
expect(json["related_topics"].length).to eq(2)
|
||||||
end
|
end
|
||||||
|
|
||||||
it "excludes embeddings when the database is offline" do
|
|
||||||
DiscourseAi::Database::Connection.stubs(:db).raises(PG::ConnectionBad)
|
|
||||||
|
|
||||||
get "#{topic.relative_url}.json"
|
|
||||||
expect(response.status).to eq(200)
|
|
||||||
json = response.parsed_body
|
|
||||||
|
|
||||||
expect(json["suggested_topics"].length).not_to eq(0)
|
|
||||||
expect(json["related_topics"].length).to eq(0)
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -5,3 +5,7 @@ Licensed under Apache License
|
||||||
## claude-v1-tokenization.json
|
## claude-v1-tokenization.json
|
||||||
|
|
||||||
Licensed under MIT License
|
Licensed under MIT License
|
||||||
|
|
||||||
|
## all-mpnet-base-v2.json
|
||||||
|
|
||||||
|
Licensed under Apache License
|
||||||
|
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue