FEATURE: HyDE-powered semantic search. (#136)

* FEATURE: HyDE-powered semantic search.

It relies on the new outlet added on discourse/discourse#23390 to display semantic search results in an unobtrusive way.

We'll use a HyDE-backed approach for semantic search, which consists on generating an hypothetical document from a given keywords, which gets transformed into a vector and used in a asymmetric similarity topic search.

This PR also reorganizes the internals to have less moving parts, maintaining one hierarchy of DAOish classes for vector-related operations like transformations and querying.

Completions and vectors created by HyDE will remain cached on Redis for now, but we could later use Postgres instead.

* Missing translation and rate limiting

---------

Co-authored-by: Roman Rizzi <rizziromanalejandro@gmail.com>
This commit is contained in:
Rafael dos Santos Silva 2023-09-05 11:08:23 -03:00 committed by GitHub
parent 3d83d062a1
commit 2c0f535bab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
45 changed files with 970 additions and 700 deletions

View File

@ -9,7 +9,6 @@ module DiscourseAi
def search
query = params[:q]
page = (params[:page] || 1).to_i
grouped_results =
Search::GroupedSearchResults.new(
@ -19,12 +18,19 @@ module DiscourseAi
use_pg_headlines_for_excerpt: false,
)
DiscourseAi::Embeddings::SemanticSearch
.new(guardian)
.search_for_topics(query, page)
.each { |topic_post| grouped_results.add(topic_post) }
semantic_search = DiscourseAi::Embeddings::SemanticSearch.new(guardian)
render_serialized(grouped_results, GroupedSearchResultSerializer, result: grouped_results)
if !semantic_search.cached_query?(query)
RateLimiter.new(current_user, "semantic-search", 4, 1.minutes).performed!
end
hijack do
semantic_search
.search_for_topics(query)
.each { |topic_post| grouped_results.add(topic_post) }
render_serialized(grouped_results, GroupedSearchResultSerializer, result: grouped_results)
end
end
end
end

View File

@ -0,0 +1,51 @@
{{#if this.searchEnabled}}
<div class="semantic-search__container search-results" role="region">
<div
class="semantic-search__results"
{{did-insert this.setup}}
{{did-insert this.debouncedSearch}}
{{will-destroy this.teardown}}
>
{{#if this.searching}}
<div class="semantic-search__searching">
<div class="semantic-search__searching-text">
{{i18n "discourse_ai.embeddings.semantic_search_loading"}}
</div>
<span class="semantic-search__indicator-wave">
<span class="semantic-search__indicator-dot">.</span>
<span class="semantic-search__indicator-dot">.</span>
<span class="semantic-search__indicator-dot">.</span>
</span>
</div>
{{else}}
{{#if this.results.length}}
<div class="semantic-search__toggle-button-container">
<DButton
@translatedTitle={{this.collapsedResultsTitle}}
@translatedLabel={{this.collapsedResultsTitle}}
@action={{fn
(mut this.collapsedResults)
(not this.collapsedResults)
}}
@class="btn-flat"
@icon={{if this.collapsedResults "chevron-right" "chevron-down"}}
/>
</div>
{{#unless this.collapsedResults}}
<div class="semantic-search__entries">
<SearchResultEntries
@posts={{this.results}}
@highlightQuery={{this.highlightQuery}}
/>
</div>
{{/unless}}
{{else}}
<div class="semantic-search__searching">
{{i18n "discourse_ai.embeddings.semantic_search_results.none"}}
</div>
{{/if}}
{{/if}}
</div>
</div>
{{/if}}

View File

@ -0,0 +1,82 @@
import Component from "@glimmer/component";
import { action, computed } from "@ember/object";
import I18n from "I18n";
import { tracked } from "@glimmer/tracking";
import { ajax } from "discourse/lib/ajax";
import { translateResults } from "discourse/lib/search";
import discourseDebounce from "discourse-common/lib/debounce";
import { inject as service } from "@ember/service";
import { bind } from "discourse-common/utils/decorators";
import { SEARCH_TYPE_DEFAULT } from "discourse/controllers/full-page-search";
export default class extends Component {
static shouldRender(_args, { siteSettings }) {
return siteSettings.ai_embeddings_semantic_search_enabled;
}
@service appEvents;
@tracked searching = false;
@tracked collapsedResults = true;
@tracked results = [];
@computed("args.outletArgs.search")
get searchTerm() {
return this.args.outletArgs.search;
}
@computed("args.outletArgs.type")
get searchEnabled() {
return this.args.outletArgs.type === SEARCH_TYPE_DEFAULT;
}
@computed("results")
get collapsedResultsTitle() {
return I18n.t("discourse_ai.embeddings.semantic_search_results.toggle", {
count: this.results.length,
});
}
@action
setup() {
this.appEvents.on(
"full-page-search:trigger-search",
this,
"debouncedSearch"
);
}
@action
teardown() {
this.appEvents.off(
"full-page-search:trigger-search",
this,
"debouncedSearch"
);
}
@bind
performHyDESearch() {
if (!this.searchTerm || !this.searchEnabled || this.searching) {
return;
}
this.searching = true;
this.collapsedResults = true;
this.results = [];
ajax("/discourse-ai/embeddings/semantic-search", {
data: { q: this.searchTerm },
})
.then(async (results) => {
const model = (await translateResults(results)) || {};
this.results = model.posts;
})
.finally(() => (this.searching = false));
}
@action
debouncedSearch() {
discourseDebounce(this, this.performHyDESearch, 500);
}
}

View File

@ -1,63 +0,0 @@
import { withPluginApi } from "discourse/lib/plugin-api";
import { translateResults, updateRecentSearches } from "discourse/lib/search";
import { ajax } from "discourse/lib/ajax";
const SEMANTIC_SEARCH = "semantic_search";
function initializeSemanticSearch(api) {
api.addFullPageSearchType(
"discourse_ai.embeddings.semantic_search",
SEMANTIC_SEARCH,
(searchController, args) => {
if (searchController.currentUser) {
updateRecentSearches(searchController.currentUser, args.searchTerm);
}
ajax("/discourse-ai/embeddings/semantic-search", { data: args })
.then(async (results) => {
const model = (await translateResults(results)) || {};
if (results.grouped_search_result) {
searchController.set("q", results.grouped_search_result.term);
}
if (args.page > 1) {
if (model) {
searchController.model.posts.pushObjects(model.posts);
searchController.model.topics.pushObjects(model.topics);
searchController.model.set(
"grouped_search_result",
results.grouped_search_result
);
}
} else {
model.grouped_search_result = results.grouped_search_result;
searchController.set("model", model);
}
searchController.set("error", null);
})
.catch((e) => {
searchController.set("error", e.jqXHR.responseJSON?.message);
})
.finally(() => {
searchController.setProperties({
searching: false,
loading: false,
});
});
}
);
}
export default {
name: "discourse-ai-full-page-semantic-search",
initialize(container) {
const settings = container.lookup("service:site-settings");
const semanticSearch = settings.ai_embeddings_semantic_search_enabled;
if (settings.ai_embeddings_enabled && semanticSearch) {
withPluginApi("1.6.0", initializeSemanticSearch);
}
},
};

View File

@ -0,0 +1,39 @@
.semantic-search__container {
background: var(--primary-very-low);
margin: 1rem 0 1rem 0;
.semantic-search__results {
display: flex;
flex-direction: column;
align-items: baseline;
.semantic-search {
&__searching-text {
display: inline-block;
margin-left: 3px;
}
&__indicator-wave {
flex: 0 0 auto;
display: inline-flex;
}
&__indicator-dot {
display: inline-block;
animation: ai-summary__indicator-wave 1.8s linear infinite;
&:nth-child(2) {
animation-delay: -1.6s;
}
&:nth-child(3) {
animation-delay: -1.4s;
}
}
}
.semantic-search__entries {
margin-top: 10px;
}
.semantic-search__searching {
margin-left: 5px;
}
}
}

View File

@ -34,6 +34,10 @@ en:
embeddings:
semantic_search: "Topics (Semantic)"
semantic_search_loading: "Searching for more results using AI"
semantic_search_results:
toggle: "Found %{count} results using AI"
none: "Sorry, our AI search found no matching topics."
ai_bot:
pm_warning: "AI chatbot messages are monitored regularly by moderators."

View File

@ -55,6 +55,7 @@ en:
ai_embeddings_pg_connection_string: "PostgreSQL connection string for the embeddings module. Needs pgvector extension enabled and a series of tables created. See docs for more info."
ai_embeddings_semantic_search_enabled: "Enable full-page semantic search."
ai_embeddings_semantic_related_include_closed_topics: "Include closed topics in semantic search results"
ai_embeddings_semantic_search_hyde_model: "Model used to expand keywords to get better results during a semantic search"
ai_summarization_discourse_service_api_endpoint: "URL where the Discourse summarization API is running."
ai_summarization_discourse_service_api_key: "API key for the Discourse summarization API."

View File

@ -177,6 +177,18 @@ discourse_ai:
ai_embeddings_semantic_search_enabled:
default: false
client: true
ai_embeddings_semantic_search_hyde_model:
default: "gpt-3.5-turbo"
type: enum
allow_any: false
choices:
- Llama2-*-chat-hf
- claude-instant-1
- claude-2
- gpt-3.5-turbo
- gpt-4
- StableBeluga2
- Upstage-Llama-2-*-instruct-v2
ai_summarization_discourse_service_api_endpoint: ""
ai_summarization_discourse_service_api_key:

View File

@ -11,13 +11,12 @@ module DiscourseAi
return [] if @text.blank?
return [] unless SiteSetting.ai_embeddings_enabled
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
candidates =
::DiscourseAi::Embeddings::SemanticSearch.new(nil).asymmetric_semantic_search(
@text,
100,
0,
return_distance: true,
)
vector_rep.asymmetric_semantic_search(@text, limit: 100, offset: 0, return_distance: true)
candidate_ids = candidates.map(&:first)
::Topic

View File

@ -4,16 +4,21 @@ module DiscourseAi
module Embeddings
class EntryPoint
def load_files
require_relative "models/base"
require_relative "models/all_mpnet_base_v2"
require_relative "models/text_embedding_ada_002"
require_relative "models/multilingual_e5_large"
require_relative "vector_representations/base"
require_relative "vector_representations/all_mpnet_base_v2"
require_relative "vector_representations/text_embedding_ada_002"
require_relative "vector_representations/multilingual_e5_large"
require_relative "strategies/truncation"
require_relative "manager"
require_relative "jobs/regular/generate_embeddings"
require_relative "semantic_related"
require_relative "semantic_search"
require_relative "semantic_topic_query"
require_relative "hyde_generators/base"
require_relative "hyde_generators/openai"
require_relative "hyde_generators/anthropic"
require_relative "hyde_generators/llama2"
require_relative "hyde_generators/llama2_ftos"
require_relative "semantic_search"
end
def inject_into(plugin)

View File

@ -0,0 +1,32 @@
# frozen_string_literal: true
module DiscourseAi
module Embeddings
module HydeGenerators
class Anthropic < DiscourseAi::Embeddings::HydeGenerators::Base
def prompt(search_term)
<<~TEXT
Given a search term given between <input> tags, generate a forum post about the search term.
Respond with the generated post between <ai> tags.
<input>#{search_term}</input>
TEXT
end
def models
%w[claude-instant-1 claude-2]
end
def hypothetical_post_from(query)
response =
::DiscourseAi::Inference::AnthropicCompletions.perform!(
prompt(query),
SiteSetting.ai_embeddings_semantic_search_hyde_model,
).dig(:completion)
Nokogiri::HTML5.fragment(response).at("ai").text
end
end
end
end
end

View File

@ -0,0 +1,17 @@
# frozen_string_literal: true
module DiscourseAi
module Embeddings
module HydeGenerators
class Base
def self.current_hyde_model
DiscourseAi::Embeddings::HydeGenerators::Base.descendants.find do |generator_klass|
generator_klass.new.models.include?(
SiteSetting.ai_embeddings_semantic_search_hyde_model,
)
end
end
end
end
end
end

View File

@ -0,0 +1,34 @@
# frozen_string_literal: true
module DiscourseAi
module Embeddings
module HydeGenerators
class Llama2 < DiscourseAi::Embeddings::HydeGenerators::Base
def prompt(search_term)
<<~TEXT
[INST] <<SYS>>
You are a helpful bot
You create forum posts about a given topic
<</SYS>>
Topic: #{search_term}
[/INST]
Here is a forum post about the above topic:
TEXT
end
def models
["Llama2-*-chat-hf"]
end
def hypothetical_post_from(query)
::DiscourseAi::Inference::HuggingFaceTextGeneration.perform!(
prompt(query),
SiteSetting.ai_embeddings_semantic_search_hyde_model,
token_limit: 400,
).dig(:generated_text)
end
end
end
end
end

View File

@ -0,0 +1,27 @@
# frozen_string_literal: true
module DiscourseAi
module Embeddings
module HydeGenerators
class Llama2Ftos < DiscourseAi::Embeddings::HydeGenerators::Llama2
def prompt(search_term)
<<~TEXT
### System:
You are a helpful bot
You create forum posts about a given topic
### User:
Topic: #{search_term}
### Assistant:
Here is a forum post about the above topic:
TEXT
end
def models
%w[StableBeluga2 Upstage-Llama-2-*-instruct-v2]
end
end
end
end
end

View File

@ -0,0 +1,30 @@
# frozen_string_literal: true
module DiscourseAi
module Embeddings
module HydeGenerators
class OpenAi < DiscourseAi::Embeddings::HydeGenerators::Base
def prompt(search_term)
[
{
role: "system",
content: "You are a helpful bot. You create forum posts about a given topic.",
},
{ role: "user", content: "Create a forum post about the topic: #{search_term}" },
]
end
def models
%w[gpt-3.5-turbo gpt-4]
end
def hypothetical_post_from(query)
::DiscourseAi::Inference::OpenAiCompletions.perform!(
prompt(query),
SiteSetting.ai_embeddings_semantic_search_hyde_model,
).dig(:choices, 0, :message, :content)
end
end
end
end
end

View File

@ -11,7 +11,11 @@ module Jobs
post = topic.first_post
return if post.nil? || post.raw.blank?
DiscourseAi::Embeddings::Manager.new(topic).generate!
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.find_vector_representation.new
vector_rep.generate_topic_representation_from(topic, strategy)
end
end
end

View File

@ -1,64 +0,0 @@
# frozen_string_literal: true
module DiscourseAi
module Embeddings
class Manager
attr_reader :target, :model, :strategy
def initialize(target)
@target = target
@model =
DiscourseAi::Embeddings::Models::Base.subclasses.find do
_1.name == SiteSetting.ai_embeddings_model
end
@strategy = DiscourseAi::Embeddings::Strategies::Truncation.new(@target, @model)
end
def generate!
@strategy.process!
# TODO bail here if we already have an embedding with matching version and digest
@embeddings = @model.generate_embeddings(@strategy.processed_target)
persist!
end
def persist!
begin
DB.exec(
<<~SQL,
INSERT INTO ai_topic_embeddings_#{table_suffix} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
VALUES (:topic_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
ON CONFLICT (topic_id)
DO UPDATE SET
model_version = :model_version,
strategy_version = :strategy_version,
digest = :digest,
embeddings = '[:embeddings]',
updated_at = CURRENT_TIMESTAMP
SQL
topic_id: @target.id,
model_version: @model.version,
strategy_version: @strategy.version,
digest: @strategy.digest,
embeddings: @embeddings,
)
rescue PG::Error => e
Rails.logger.error(
"Error #{e} persisting embedding for topic #{topic.id} and model #{model.name}",
)
end
end
def table_suffix
"#{@model.id}_#{@strategy.id}"
end
def topic_embeddings_table
"ai_topic_embeddings_#{table_suffix}"
end
end
end
end

View File

@ -1,52 +0,0 @@
# frozen_string_literal: true
module DiscourseAi
module Embeddings
module Models
class AllMpnetBaseV2 < Base
class << self
def id
1
end
def version
1
end
def name
"all-mpnet-base-v2"
end
def dimensions
768
end
def max_sequence_length
384
end
def pg_function
"<#>"
end
def pg_index_type
"vector_ip_ops"
end
def generate_embeddings(text)
DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
name,
text,
SiteSetting.ai_embeddings_discourse_service_api_key,
)
end
def tokenizer
DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
end
end
end
end
end
end

View File

@ -1,10 +0,0 @@
# frozen_string_literal: true
module DiscourseAi
module Embeddings
module Models
class Base
end
end
end
end

View File

@ -1,52 +0,0 @@
# frozen_string_literal: true
module DiscourseAi
module Embeddings
module Models
class MultilingualE5Large < Base
class << self
def id
3
end
def version
1
end
def name
"multilingual-e5-large"
end
def dimensions
1024
end
def max_sequence_length
512
end
def pg_function
"<=>"
end
def pg_index_type
"vector_cosine_ops"
end
def generate_embeddings(text)
DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
name,
"query: #{text}",
SiteSetting.ai_embeddings_discourse_service_api_key,
)
end
def tokenizer
DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
end
end
end
end
end
end

View File

@ -1,48 +0,0 @@
# frozen_string_literal: true
module DiscourseAi
module Embeddings
module Models
class TextEmbeddingAda002 < Base
class << self
def id
2
end
def version
1
end
def name
"text-embedding-ada-002"
end
def dimensions
1536
end
def max_sequence_length
8191
end
def pg_function
"<=>"
end
def pg_index_type
"vector_cosine_ops"
end
def generate_embeddings(text)
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text)
response[:data].first[:embedding]
end
def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer
end
end
end
end
end
end

View File

@ -5,101 +5,67 @@ module DiscourseAi
class SemanticRelated
MissingEmbeddingError = Class.new(StandardError)
class << self
def semantic_suggested_key(topic_id)
"semantic-suggested-topic-#{topic_id}"
end
def self.clear_cache_for(topic)
Discourse.cache.delete("semantic-suggested-topic-#{topic.id}")
Discourse.redis.del("build-semantic-suggested-topic-#{topic.id}")
end
def build_semantic_suggested_key(topic_id)
"build-semantic-suggested-topic-#{topic_id}"
end
def related_topic_ids_for(topic)
return [] if SiteSetting.ai_embeddings_semantic_related_topics < 1
def clear_cache_for(topic)
Discourse.cache.delete(semantic_suggested_key(topic.id))
Discourse.redis.del(build_semantic_suggested_key(topic.id))
end
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
cache_for = results_ttl(topic)
def related_topic_ids_for(topic)
return [] if SiteSetting.ai_embeddings_semantic_related_topics < 1
manager = DiscourseAi::Embeddings::Manager.new(topic)
cache_for = results_ttl(topic)
begin
Discourse
.cache
.fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
symmetric_semantic_search(manager)
end
rescue MissingEmbeddingError
# avoid a flood of jobs when visiting topic
if Discourse.redis.set(
build_semantic_suggested_key(topic.id),
"queued",
ex: 15.minutes.to_i,
nx: true,
)
Jobs.enqueue(:generate_embeddings, topic_id: topic.id)
asd =
Discourse
.cache
.fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
vector_rep
.symmetric_topics_similarity_search(topic)
.tap do |candidate_ids|
# Happens when the topic doesn't have any embeddings
# I'd rather not use Exceptions to control the flow, so this should be refactored soon
if candidate_ids.empty? || !candidate_ids.include?(topic.id)
raise MissingEmbeddingError, "No embeddings found for topic #{topic.id}"
end
end
end
[]
end
rescue MissingEmbeddingError
# avoid a flood of jobs when visiting topic
if Discourse.redis.set(
build_semantic_suggested_key(topic.id),
"queued",
ex: 15.minutes.to_i,
nx: true,
)
Jobs.enqueue(:generate_embeddings, topic_id: topic.id)
end
[]
end
def symmetric_semantic_search(manager)
topic = manager.target
candidate_ids = self.query_symmetric_embeddings(manager)
# Happens when the topic doesn't have any embeddings
# I'd rather not use Exceptions to control the flow, so this should be refactored soon
if candidate_ids.empty? || !candidate_ids.include?(topic.id)
raise MissingEmbeddingError, "No embeddings found for topic #{topic.id}"
end
candidate_ids
def results_ttl(topic)
case topic.created_at
when 6.hour.ago..Time.now
15.minutes
when 3.day.ago..6.hour.ago
1.hour
when 15.days.ago..3.day.ago
12.hours
else
1.week
end
end
def query_symmetric_embeddings(manager)
topic = manager.target
model = manager.model
table = manager.topic_embeddings_table
begin
DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
SELECT
topic_id
FROM
#{table}
ORDER BY
embeddings #{model.pg_function} (
SELECT
embeddings
FROM
#{table}
WHERE
topic_id = :topic_id
LIMIT 1
)
LIMIT 100
SQL
rescue PG::Error => e
Rails.logger.error(
"Error #{e} querying embeddings for topic #{topic.id} and model #{model.name}",
)
raise MissingEmbeddingError
end
end
private
def results_ttl(topic)
case topic.created_at
when 6.hour.ago..Time.now
15.minutes
when 3.day.ago..6.hour.ago
1.hour
when 15.days.ago..3.day.ago
12.hours
else
1.week
end
end
def semantic_suggested_key(topic_id)
"semantic-suggested-topic-#{topic_id}"
end
def build_semantic_suggested_key(topic_id)
"build-semantic-suggested-topic-#{topic_id}"
end
end
end

View File

@ -3,59 +3,66 @@
module DiscourseAi
module Embeddings
class SemanticSearch
def self.clear_cache_for(query)
digest = OpenSSL::Digest::SHA1.hexdigest(query)
Discourse.cache.delete("hyde-doc-#{digest}")
Discourse.cache.delete("hyde-doc-embedding-#{digest}")
end
def initialize(guardian)
@guardian = guardian
@manager = DiscourseAi::Embeddings::Manager.new(nil)
@model = @manager.model
end
def cached_query?(query)
digest = OpenSSL::Digest::SHA1.hexdigest(query)
Discourse.cache.read("hyde-doc-embedding-#{digest}").present?
end
def search_for_topics(query, page = 1)
limit = Search.per_filter + 1
offset = (page - 1) * Search.per_filter
max_results_per_page = 50
limit = [Search.per_filter, max_results_per_page].min + 1
offset = (page - 1) * limit
candidate_ids = asymmetric_semantic_search(query, limit, offset)
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
digest = OpenSSL::Digest::SHA1.hexdigest(query)
hypothetical_post =
Discourse
.cache
.fetch("hyde-doc-#{digest}", expires_in: 1.week) do
hyde_generator = DiscourseAi::Embeddings::HydeGenerators::Base.current_hyde_model.new
hyde_generator.hypothetical_post_from(query)
end
hypothetical_post_embedding =
Discourse
.cache
.fetch("hyde-doc-embedding-#{digest}", expires_in: 1.week) do
vector_rep.vector_from(hypothetical_post)
end
candidate_topic_ids =
vector_rep.asymmetric_topics_similarity_search(
hypothetical_post_embedding,
limit: limit,
offset: offset,
)
::Post
.where(post_type: ::Topic.visible_post_types(guardian.user))
.public_posts
.where("topics.visible")
.where(topic_id: candidate_ids, post_number: 1)
.order("array_position(ARRAY#{candidate_ids}, topic_id)")
end
def asymmetric_semantic_search(query, limit, offset, return_distance: false)
embedding = model.generate_embeddings(query)
table = @manager.topic_embeddings_table
begin
candidate_ids = DB.query(<<~SQL, query_embedding: embedding, limit: limit, offset: offset)
SELECT
topic_id,
embeddings #{@model.pg_function} '[:query_embedding]' AS distance
FROM
#{table}
ORDER BY
embeddings #{@model.pg_function} '[:query_embedding]'
LIMIT :limit
OFFSET :offset
SQL
rescue PG::Error => e
Rails.logger.error(
"Error #{e} querying embeddings for model #{model.name} and search #{query}",
)
raise MissingEmbeddingError
end
if return_distance
candidate_ids.map { |c| [c.topic_id, c.distance] }
else
candidate_ids.map(&:topic_id)
end
.where(topic_id: candidate_topic_ids, post_number: 1)
.order("array_position(ARRAY#{candidate_topic_ids}, topic_id)")
end
private
attr_reader :model, :guardian
attr_reader :guardian
end
end
end

View File

@ -14,7 +14,7 @@ class DiscourseAi::Embeddings::SemanticTopicQuery < TopicQuery
list =
create_list(:semantic_related, query_opts) do |topics|
candidate_ids = DiscourseAi::Embeddings::SemanticRelated.related_topic_ids_for(topic)
candidate_ids = DiscourseAi::Embeddings::SemanticRelated.new.related_topic_ids_for(topic)
list =
topics

View File

@ -4,77 +4,57 @@ module DiscourseAi
module Embeddings
module Strategies
class Truncation
attr_reader :processed_target, :digest
def self.id
1
end
def id
self.class.id
1
end
def version
1
end
def initialize(target, model)
@model = model
@target = target
@tokenizer = @model.tokenizer
@max_length = @model.max_sequence_length - 2
@processed_target = nil
end
# Need a better name for this method
def process!
case @target
def prepare_text_from(target, tokenizer, max_length)
case target
when Topic
@processed_target = topic_truncation(@target)
topic_truncation(target, tokenizer, max_length)
when Post
@processed_target = post_truncation(@target)
post_truncation(target, tokenizer, max_length)
else
raise ArgumentError, "Invalid target type"
end
@digest = OpenSSL::Digest::SHA1.hexdigest(@processed_target)
end
def topic_truncation(topic)
t = +""
private
t << topic.title
t << "\n\n"
t << topic.category.name
def topic_information(topic)
info = +""
info << topic.title
info << "\n\n"
info << topic.category.name
if SiteSetting.tagging_enabled
t << "\n\n"
t << topic.tags.pluck(:name).join(", ")
info << "\n\n"
info << topic.tags.pluck(:name).join(", ")
end
t << "\n\n"
info << "\n\n"
end
def topic_truncation(topic, tokenizer, max_length)
text = +topic_information(topic)
topic.posts.find_each do |post|
t << post.raw
break if @tokenizer.size(t) >= @max_length #maybe keep a partial counter to speed this up?
t << "\n\n"
text << post.raw
break if tokenizer.size(text) >= max_length #maybe keep a partial counter to speed this up?
text << "\n\n"
end
@tokenizer.truncate(t, @max_length)
tokenizer.truncate(text, max_length)
end
def post_truncation(post)
t = +""
def post_truncation(topic, tokenizer, max_length)
text = +topic_information(post.topic)
text << post.raw
t << post.topic.title
t << "\n\n"
t << post.topic.category.name
if SiteSetting.tagging_enabled
t << "\n\n"
t << post.topic.tags.pluck(:name).join(", ")
end
t << "\n\n"
t << post.raw
@tokenizer.truncate(t, @max_length)
tokenizer.truncate(text, max_length)
end
end
end

View File

@ -0,0 +1,50 @@
# frozen_string_literal: true
module DiscourseAi
module Embeddings
module VectorRepresentations
class AllMpnetBaseV2 < Base
def vector_from(text)
DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
name,
text,
SiteSetting.ai_embeddings_discourse_service_api_key,
)
end
def name
"all-mpnet-base-v2"
end
def dimensions
768
end
def max_sequence_length
384
end
def id
1
end
def version
1
end
def pg_function
"<#>"
end
def pg_index_type
"vector_ip_ops"
end
def tokenizer
DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
end
end
end
end
end

View File

@ -0,0 +1,166 @@
# frozen_string_literal: true
module DiscourseAi
module Embeddings
module VectorRepresentations
class Base
def self.current_representation(strategy)
subclasses.map { _1.new(strategy) }.find { _1.name == SiteSetting.ai_embeddings_model }
end
def initialize(strategy)
@strategy = strategy
end
def create_index(lists, probes)
index_name = "#{table_name}_search"
DB.exec(<<~SQL)
DROP INDEX IF EXISTS #{index_name};
CREATE INDEX IF NOT EXISTS
#{index}
ON
#{table_name}
USING
ivfflat (embeddings #{pg_index_type})
WITH
(lists = #{lists})
WHERE
model_version = #{version} AND
strategy_version = #{@strategy.version};
SQL
end
def vector_from(text)
raise NotImplementedError
end
def generate_topic_representation_from(target, persist: true)
text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2)
vector_from(text).tap do |vector|
if persist
digest = OpenSSL::Digest::SHA1.hexdigest(text)
save_to_db(target, vector, digest)
end
end
end
def topic_id_from_representation(raw_vector)
DB.query_single(<<~SQL, query_embedding: raw_vector).first
SELECT
topic_id
FROM
#{table_name}
ORDER BY
embeddings #{pg_function} '[:query_embedding]'
LIMIT 1
SQL
end
def asymmetric_topics_similarity_search(raw_vector, limit:, offset:, return_distance: false)
results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset)
SELECT
topic_id,
embeddings #{pg_function} '[:query_embedding]' AS distance
FROM
#{table_name}
ORDER BY
embeddings #{pg_function} '[:query_embedding]'
LIMIT :limit
OFFSET :offset
SQL
if return_distance
results.map { |r| [r.topic_id, r.distance] }
else
results.map(&:topic_id)
end
rescue PG::Error => e
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
raise MissingEmbeddingError
end
def symmetric_topics_similarity_search(topic)
DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
SELECT
topic_id
FROM
#{table_name}
ORDER BY
embeddings #{pg_function} (
SELECT
embeddings
FROM
#{table_name}
WHERE
topic_id = :topic_id
LIMIT 1
)
LIMIT 100
SQL
rescue PG::Error => e
Rails.logger.error(
"Error #{e} querying embeddings for topic #{topic.id} and model #{name}",
)
raise MissingEmbeddingError
end
def table_name
"ai_topic_embeddings_#{id}_#{@strategy.id}"
end
def name
raise NotImplementedError
end
def dimensions
raise NotImplementedError
end
def max_sequence_length
raise NotImplementedError
end
def id
raise NotImplementedError
end
def pg_function
raise NotImplementedError
end
def version
raise NotImplementedError
end
def tokenizer
raise NotImplementedError
end
protected
def save_to_db(target, vector, digest)
DB.exec(
<<~SQL,
INSERT INTO #{table_name} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
VALUES (:topic_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
ON CONFLICT (topic_id)
DO UPDATE SET
model_version = :model_version,
strategy_version = :strategy_version,
digest = :digest,
embeddings = '[:embeddings]',
updated_at = CURRENT_TIMESTAMP
SQL
topic_id: target.id,
model_version: version,
strategy_version: @strategy.version,
digest: digest,
embeddings: vector,
)
end
end
end
end
end

View File

@ -0,0 +1,50 @@
# frozen_string_literal: true
module DiscourseAi
module Embeddings
module VectorRepresentations
class MultilingualE5Large < Base
def vector_from(text)
DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
name,
"query: #{text}",
SiteSetting.ai_embeddings_discourse_service_api_key,
)
end
def id
3
end
def version
1
end
def name
"multilingual-e5-large"
end
def dimensions
1024
end
def max_sequence_length
512
end
def pg_function
"<=>"
end
def pg_index_type
"vector_cosine_ops"
end
def tokenizer
DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
end
end
end
end
end

View File

@ -0,0 +1,46 @@
# frozen_string_literal: true
module DiscourseAi
module Embeddings
module VectorRepresentations
class TextEmbeddingAda002 < Base
def id
2
end
def version
1
end
def name
"text-embedding-ada-002"
end
def dimensions
1536
end
def max_sequence_length
8191
end
def pg_function
"<=>"
end
def pg_index_type
"vector_cosine_ops"
end
def vector_from(text)
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text)
response[:data].first[:embedding]
end
def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer
end
end
end
end
end

View File

@ -4,7 +4,7 @@ module ::DiscourseAi
module Inference
class HuggingFaceTextGeneration
CompletionFailed = Class.new(StandardError)
TIMEOUT = 60
TIMEOUT = 120
def self.perform!(
prompt,

View File

@ -4,18 +4,22 @@ desc "Backfill embeddings for all topics"
task "ai:embeddings:backfill", [:start_topic] => [:environment] do |_, args|
public_categories = Category.where(read_restricted: false).pluck(:id)
manager = DiscourseAi::Embeddings::Manager.new(Topic.first)
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.find_vector_representation.new(strategy)
table_name = vector_rep.table_name
Topic
.joins(
"LEFT JOIN #{manager.topic_embeddings_table} ON #{manager.topic_embeddings_table}.topic_id = topics.id",
)
.where("#{manager.topic_embeddings_table}.topic_id IS NULL")
.joins("LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id")
.where("#{table_name}.topic_id IS NULL")
.where("topics.id >= ?", args[:start_topic].to_i || 0)
.where("category_id IN (?)", public_categories)
.where(deleted_at: nil)
.order("topics.id ASC")
.find_each do |t|
print "."
DiscourseAi::Embeddings::Manager.new(t).generate!
vector_rep.generate_topic_representation_from(t)
end
end
@ -28,25 +32,11 @@ task "ai:embeddings:index", [:work_mem] => [:environment] do |_, args|
lists = count < 1_000_000 ? count / 1000 : Math.sqrt(count).to_i
probes = count < 1_000_000 ? lists / 10 : Math.sqrt(lists).to_i
manager = DiscourseAi::Embeddings::Manager.new(Topic.first)
table = manager.topic_embeddings_table
index = "#{table}_search"
vector_representation_klass = DiscourseAi::Embeddings::Vectors::Base.find_vector_representation
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
DB.exec("SET work_mem TO '#{args[:work_mem] || "1GB"}';")
DB.exec(<<~SQL)
DROP INDEX IF EXISTS #{index};
CREATE INDEX IF NOT EXISTS
#{index}
ON
#{table}
USING
ivfflat (embeddings #{manager.model.pg_index_type})
WITH
(lists = #{lists})
WHERE
model_version = #{manager.model.version} AND
strategy_version = #{manager.strategy.version};
SQL
vector_representation_klass.new(strategy).create_index(lists, probes)
DB.exec("RESET work_mem;")
DB.exec("SET ivfflat.probes = #{probes};")
end

View File

@ -17,6 +17,7 @@ register_asset "stylesheets/modules/ai-helper/common/ai-helper.scss"
register_asset "stylesheets/modules/ai-bot/common/bot-replies.scss"
register_asset "stylesheets/modules/embeddings/common/semantic-related-topics.scss"
register_asset "stylesheets/modules/embeddings/common/semantic-search.scss"
module ::DiscourseAi
PLUGIN_NAME = "discourse-ai"

View File

@ -1,44 +0,0 @@
# frozen_string_literal: true
require_relative "../../support/embeddings_generation_stubs"
RSpec.describe DiscourseAi::Embeddings::Manager do
let(:user) { Fabricate(:user) }
let(:expected_embedding) do
JSON.parse(
File.read("#{Rails.root}/plugins/discourse-ai/spec/fixtures/embeddings/embedding.txt"),
)
end
let(:discourse_model) { "all-mpnet-base-v2" }
before do
SiteSetting.discourse_ai_enabled = true
SiteSetting.ai_embeddings_enabled = true
SiteSetting.ai_embeddings_model = "all-mpnet-base-v2"
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
Jobs.run_immediately!
end
it "generates embeddings for new topics automatically" do
pc =
PostCreator.new(
user,
raw: "this is the new content for my topic",
title: "this is my new topic title",
)
input =
"This is my new topic title\n\nUncategorized\n\n\n\nthis is the new content for my topic\n\n"
EmbeddingsGenerationStubs.discourse_service(discourse_model, input, expected_embedding)
post = pc.create
manager = DiscourseAi::Embeddings::Manager.new(post.topic)
embeddings =
DB.query_single(
"SELECT embeddings FROM #{manager.topic_embeddings_table} WHERE topic_id = #{post.topic.id}",
).first
expect(embeddings.split(",")[1].to_f).to be_within(0.0001).of(expected_embedding[1])
expect(embeddings.split(",")[13].to_f).to be_within(0.0001).of(expected_embedding[13])
expect(embeddings.split(",")[135].to_f).to be_within(0.0001).of(expected_embedding[135])
end
end

View File

@ -28,97 +28,4 @@ describe DiscourseAi::Embeddings::EntryPoint do
end
end
end
describe "SemanticTopicQuery extension" do
describe "#list_semantic_related_topics" do
subject(:topic_query) { DiscourseAi::Embeddings::SemanticTopicQuery.new(user) }
fab!(:target) { Fabricate(:topic) }
def stub_semantic_search_with(results)
DiscourseAi::Embeddings::SemanticRelated.expects(:related_topic_ids_for).returns(results)
end
context "when the semantic search returns an unlisted topic" do
fab!(:unlisted_topic) { Fabricate(:topic, visible: false) }
before { stub_semantic_search_with([unlisted_topic.id]) }
it "filters it out" do
expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
end
end
context "when the semantic search returns a private topic" do
fab!(:private_topic) { Fabricate(:private_message_topic) }
before { stub_semantic_search_with([private_topic.id]) }
it "filters it out" do
expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
end
end
context "when the semantic search returns a topic from a restricted category" do
fab!(:group) { Fabricate(:group) }
fab!(:category) { Fabricate(:private_category, group: group) }
fab!(:secured_category_topic) { Fabricate(:topic, category: category) }
before { stub_semantic_search_with([secured_category_topic.id]) }
it "filters it out" do
expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
end
it "doesn't filter it out if the user has access to the category" do
group.add(user)
expect(topic_query.list_semantic_related_topics(target).topics).to contain_exactly(
secured_category_topic,
)
end
end
context "when the semantic search returns a closed topic and we explicitly exclude them" do
fab!(:closed_topic) { Fabricate(:topic, closed: true) }
before do
SiteSetting.ai_embeddings_semantic_related_include_closed_topics = false
stub_semantic_search_with([closed_topic.id])
end
it "filters it out" do
expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
end
end
context "when the semantic search returns public topics" do
fab!(:normal_topic_1) { Fabricate(:topic) }
fab!(:normal_topic_2) { Fabricate(:topic) }
fab!(:normal_topic_3) { Fabricate(:topic) }
fab!(:closed_topic) { Fabricate(:topic, closed: true) }
before do
stub_semantic_search_with(
[closed_topic.id, normal_topic_1.id, normal_topic_2.id, normal_topic_3.id],
)
end
it "filters it out" do
expect(topic_query.list_semantic_related_topics(target).topics).to eq(
[closed_topic, normal_topic_1, normal_topic_2, normal_topic_3],
)
end
it "returns the plugin limit for the number of results" do
SiteSetting.ai_embeddings_semantic_related_topics = 2
expect(topic_query.list_semantic_related_topics(target).topics).to contain_exactly(
closed_topic,
normal_topic_1,
)
end
end
end
end
end

View File

@ -1,24 +0,0 @@
# frozen_string_literal: true
require_relative "../../../../support/embeddings_generation_stubs"
RSpec.describe DiscourseAi::Embeddings::Models::AllMpnetBaseV2 do
describe "#generate_embeddings" do
let(:input) { "test" }
let(:expected_embedding) { [0.0038493, 0.482001] }
context "when the model uses the discourse service to create embeddings" do
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
let(:discourse_model) { "all-mpnet-base-v2" }
it "returns an embedding for a given string" do
EmbeddingsGenerationStubs.discourse_service(discourse_model, input, expected_embedding)
embedding = described_class.generate_embeddings(input)
expect(embedding).to contain_exactly(*expected_embedding)
end
end
end
end

View File

@ -1,22 +0,0 @@
# frozen_string_literal: true
require_relative "../../../../support/embeddings_generation_stubs"
RSpec.describe DiscourseAi::Embeddings::Models::TextEmbeddingAda002 do
describe "#generate_embeddings" do
let(:input) { "test" }
let(:expected_embedding) { [0.0038493, 0.482001] }
context "when the model uses OpenAI to create embeddings" do
let(:openai_model) { "text-embedding-ada-002" }
it "returns an embedding for a given string" do
EmbeddingsGenerationStubs.openai_service(openai_model, input, expected_embedding)
embedding = described_class.generate_embeddings(input)
expect(embedding).to contain_exactly(*expected_embedding)
end
end
end
end

View File

@ -3,6 +3,8 @@
require "rails_helper"
describe DiscourseAi::Embeddings::SemanticRelated do
subject(:semantic_related) { described_class.new }
fab!(:target) { Fabricate(:topic) }
fab!(:normal_topic_1) { Fabricate(:topic) }
fab!(:normal_topic_2) { Fabricate(:topic) }
@ -25,13 +27,13 @@ describe DiscourseAi::Embeddings::SemanticRelated do
results = nil
expect_enqueued_with(job: :generate_embeddings, args: { topic_id: topic.id }) do
results = described_class.related_topic_ids_for(topic)
results = semantic_related.related_topic_ids_for(topic)
end
expect(results).to eq([])
expect_not_enqueued_with(job: :generate_embeddings, args: { topic_id: topic.id }) do
results = described_class.related_topic_ids_for(topic)
results = semantic_related.related_topic_ids_for(topic)
end
expect(results).to eq([])

View File

@ -1,5 +1,8 @@
# frozen_string_literal: true
require_relative "../../../support/embeddings_generation_stubs"
require_relative "../../../support/openai_completions_inference_stubs"
RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
fab!(:post) { Fabricate(:post) }
fab!(:user) { Fabricate(:user) }
@ -8,10 +11,28 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
let(:subject) { described_class.new(Guardian.new(user)) }
describe "#search_for_topics" do
let(:hypothetical_post) { "This is an hypothetical post generated from the keyword test_query" }
before do
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
prompt = DiscourseAi::Embeddings::HydeGenerators::OpenAi.new.prompt(query)
OpenAiCompletionsInferenceStubs.stub_response(prompt, hypothetical_post)
hyde_embedding = [0.049382, 0.9999]
EmbeddingsGenerationStubs.discourse_service(
SiteSetting.ai_embeddings_model,
hypothetical_post,
hyde_embedding,
)
end
after { described_class.clear_cache_for(query) }
def stub_candidate_ids(candidate_ids)
DiscourseAi::Embeddings::SemanticSearch
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2
.any_instance
.expects(:asymmetric_semantic_search)
.expects(:asymmetric_topics_similarity_search)
.returns(candidate_ids)
end

View File

@ -12,9 +12,14 @@ describe DiscourseAi::Embeddings::EntryPoint do
fab!(:target) { Fabricate(:topic) }
def stub_semantic_search_with(results)
DiscourseAi::Embeddings::SemanticRelated.expects(:related_topic_ids_for).returns(results)
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2
.any_instance
.expects(:symmetric_topics_similarity_search)
.returns(results.concat([target.id]))
end
after { DiscourseAi::Embeddings::SemanticRelated.clear_cache_for(target) }
context "when the semantic search returns an unlisted topic" do
fab!(:unlisted_topic) { Fabricate(:topic, visible: false) }

View File

@ -1,8 +1,10 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
describe "#process!" do
context "when the model uses OpenAI to create embeddings" do
subject(:truncation) { described_class.new }
describe "#prepare_text_from" do
context "when using vector from OpenAI" do
before { SiteSetting.max_post_length = 100_000 }
fab!(:topic) { Fabricate(:topic) }
@ -18,13 +20,15 @@ RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
end
fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }
let(:model) { DiscourseAi::Embeddings::Models::Base.descendants.sample(1).first }
let(:truncation) { described_class.new(topic, model) }
let(:model) do
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new(truncation)
end
it "truncates a topic" do
truncation.process!
prepared_text =
truncation.prepare_text_from(topic, model.tokenizer, model.max_sequence_length)
expect(model.tokenizer.size(truncation.processed_target)).to be <= model.max_sequence_length
expect(model.tokenizer.size(prepared_text)).to be <= model.max_sequence_length
end
end
end

View File

@ -0,0 +1,18 @@
# frozen_string_literal: true
require_relative "../../../../support/embeddings_generation_stubs"
require_relative "vector_rep_shared_examples"
RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2 do
subject(:vector_rep) { described_class.new(truncation) }
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding)
end
it_behaves_like "generates and store embedding using with vector representation"
end

View File

@ -0,0 +1,22 @@
# frozen_string_literal: true
require_relative "../../../../support/embeddings_generation_stubs"
require_relative "vector_rep_shared_examples"
RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large do
subject(:vector_rep) { described_class.new(truncation) }
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service(
vector_rep.name,
"query: #{text}",
expected_embedding,
)
end
it_behaves_like "generates and store embedding using with vector representation"
end

View File

@ -0,0 +1,16 @@
# frozen_string_literal: true
require_relative "../../../../support/embeddings_generation_stubs"
require_relative "vector_rep_shared_examples"
RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002 do
subject(:vector_rep) { described_class.new(truncation) }
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.openai_service(vector_rep.name, text, expected_embedding)
end
it_behaves_like "generates and store embedding using with vector representation"
end

View File

@ -0,0 +1,54 @@
# frozen_string_literal: true
RSpec.shared_examples "generates and store embedding using with vector representation" do
before { @expected_embedding = [0.0038493] * vector_rep.dimensions }
describe "#vector_from" do
it "creates a vector from a given string" do
text = "This is a piece of text"
stub_vector_mapping(text, @expected_embedding)
expect(vector_rep.vector_from(text)).to eq(@expected_embedding)
end
end
describe "#generate_topic_representation_from" do
fab!(:topic) { Fabricate(:topic) }
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
it "creates a vector from a topic and stores it in the database" do
text =
truncation.prepare_text_from(
topic,
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
stub_vector_mapping(text, @expected_embedding)
vector_rep.generate_topic_representation_from(topic)
expect(vector_rep.topic_id_from_representation(@expected_embedding)).to eq(topic.id)
end
end
describe "#asymmetric_topics_similarity_search" do
fab!(:topic) { Fabricate(:topic) }
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
it "finds IDs of similar topics with a given embedding" do
similar_vector = [0.0038494] * vector_rep.dimensions
text =
truncation.prepare_text_from(
topic,
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
stub_vector_mapping(text, @expected_embedding)
vector_rep.generate_topic_representation_from(topic)
expect(
vector_rep.asymmetric_topics_similarity_search(similar_vector, limit: 1, offset: 0),
).to contain_exactly(topic.id)
end
end
end

View File

@ -16,9 +16,10 @@ describe ::TopicsController do
context "when a user is logged on" do
it "includes related topics in payload when configured" do
DiscourseAi::Embeddings::SemanticRelated.stubs(:related_topic_ids_for).returns(
[topic1.id, topic2.id, topic3.id],
)
DiscourseAi::Embeddings::SemanticRelated
.any_instance
.stubs(:related_topic_ids_for)
.returns([topic1.id, topic2.id, topic3.id])
get("#{topic.relative_url}.json")
expect(response.status).to eq(200)