FEATURE: HyDE-powered semantic search. (#136)
* FEATURE: HyDE-powered semantic search. It relies on the new outlet added on discourse/discourse#23390 to display semantic search results in an unobtrusive way. We'll use a HyDE-backed approach for semantic search, which consists on generating an hypothetical document from a given keywords, which gets transformed into a vector and used in a asymmetric similarity topic search. This PR also reorganizes the internals to have less moving parts, maintaining one hierarchy of DAOish classes for vector-related operations like transformations and querying. Completions and vectors created by HyDE will remain cached on Redis for now, but we could later use Postgres instead. * Missing translation and rate limiting --------- Co-authored-by: Roman Rizzi <rizziromanalejandro@gmail.com>
This commit is contained in:
parent
3d83d062a1
commit
2c0f535bab
|
@ -9,7 +9,6 @@ module DiscourseAi
|
|||
|
||||
def search
|
||||
query = params[:q]
|
||||
page = (params[:page] || 1).to_i
|
||||
|
||||
grouped_results =
|
||||
Search::GroupedSearchResults.new(
|
||||
|
@ -19,9 +18,15 @@ module DiscourseAi
|
|||
use_pg_headlines_for_excerpt: false,
|
||||
)
|
||||
|
||||
DiscourseAi::Embeddings::SemanticSearch
|
||||
.new(guardian)
|
||||
.search_for_topics(query, page)
|
||||
semantic_search = DiscourseAi::Embeddings::SemanticSearch.new(guardian)
|
||||
|
||||
if !semantic_search.cached_query?(query)
|
||||
RateLimiter.new(current_user, "semantic-search", 4, 1.minutes).performed!
|
||||
end
|
||||
|
||||
hijack do
|
||||
semantic_search
|
||||
.search_for_topics(query)
|
||||
.each { |topic_post| grouped_results.add(topic_post) }
|
||||
|
||||
render_serialized(grouped_results, GroupedSearchResultSerializer, result: grouped_results)
|
||||
|
@ -29,3 +34,4 @@ module DiscourseAi
|
|||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
{{#if this.searchEnabled}}
|
||||
<div class="semantic-search__container search-results" role="region">
|
||||
<div
|
||||
class="semantic-search__results"
|
||||
{{did-insert this.setup}}
|
||||
{{did-insert this.debouncedSearch}}
|
||||
{{will-destroy this.teardown}}
|
||||
>
|
||||
{{#if this.searching}}
|
||||
<div class="semantic-search__searching">
|
||||
<div class="semantic-search__searching-text">
|
||||
{{i18n "discourse_ai.embeddings.semantic_search_loading"}}
|
||||
</div>
|
||||
<span class="semantic-search__indicator-wave">
|
||||
<span class="semantic-search__indicator-dot">.</span>
|
||||
<span class="semantic-search__indicator-dot">.</span>
|
||||
<span class="semantic-search__indicator-dot">.</span>
|
||||
</span>
|
||||
</div>
|
||||
{{else}}
|
||||
{{#if this.results.length}}
|
||||
<div class="semantic-search__toggle-button-container">
|
||||
<DButton
|
||||
@translatedTitle={{this.collapsedResultsTitle}}
|
||||
@translatedLabel={{this.collapsedResultsTitle}}
|
||||
@action={{fn
|
||||
(mut this.collapsedResults)
|
||||
(not this.collapsedResults)
|
||||
}}
|
||||
@class="btn-flat"
|
||||
@icon={{if this.collapsedResults "chevron-right" "chevron-down"}}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{{#unless this.collapsedResults}}
|
||||
<div class="semantic-search__entries">
|
||||
<SearchResultEntries
|
||||
@posts={{this.results}}
|
||||
@highlightQuery={{this.highlightQuery}}
|
||||
/>
|
||||
</div>
|
||||
{{/unless}}
|
||||
{{else}}
|
||||
<div class="semantic-search__searching">
|
||||
{{i18n "discourse_ai.embeddings.semantic_search_results.none"}}
|
||||
</div>
|
||||
{{/if}}
|
||||
{{/if}}
|
||||
</div>
|
||||
</div>
|
||||
{{/if}}
|
|
@ -0,0 +1,82 @@
|
|||
import Component from "@glimmer/component";
|
||||
import { action, computed } from "@ember/object";
|
||||
import I18n from "I18n";
|
||||
import { tracked } from "@glimmer/tracking";
|
||||
import { ajax } from "discourse/lib/ajax";
|
||||
import { translateResults } from "discourse/lib/search";
|
||||
import discourseDebounce from "discourse-common/lib/debounce";
|
||||
import { inject as service } from "@ember/service";
|
||||
import { bind } from "discourse-common/utils/decorators";
|
||||
import { SEARCH_TYPE_DEFAULT } from "discourse/controllers/full-page-search";
|
||||
|
||||
export default class extends Component {
|
||||
static shouldRender(_args, { siteSettings }) {
|
||||
return siteSettings.ai_embeddings_semantic_search_enabled;
|
||||
}
|
||||
|
||||
@service appEvents;
|
||||
|
||||
@tracked searching = false;
|
||||
@tracked collapsedResults = true;
|
||||
@tracked results = [];
|
||||
|
||||
@computed("args.outletArgs.search")
|
||||
get searchTerm() {
|
||||
return this.args.outletArgs.search;
|
||||
}
|
||||
|
||||
@computed("args.outletArgs.type")
|
||||
get searchEnabled() {
|
||||
return this.args.outletArgs.type === SEARCH_TYPE_DEFAULT;
|
||||
}
|
||||
|
||||
@computed("results")
|
||||
get collapsedResultsTitle() {
|
||||
return I18n.t("discourse_ai.embeddings.semantic_search_results.toggle", {
|
||||
count: this.results.length,
|
||||
});
|
||||
}
|
||||
|
||||
@action
|
||||
setup() {
|
||||
this.appEvents.on(
|
||||
"full-page-search:trigger-search",
|
||||
this,
|
||||
"debouncedSearch"
|
||||
);
|
||||
}
|
||||
|
||||
@action
|
||||
teardown() {
|
||||
this.appEvents.off(
|
||||
"full-page-search:trigger-search",
|
||||
this,
|
||||
"debouncedSearch"
|
||||
);
|
||||
}
|
||||
|
||||
@bind
|
||||
performHyDESearch() {
|
||||
if (!this.searchTerm || !this.searchEnabled || this.searching) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.searching = true;
|
||||
this.collapsedResults = true;
|
||||
this.results = [];
|
||||
|
||||
ajax("/discourse-ai/embeddings/semantic-search", {
|
||||
data: { q: this.searchTerm },
|
||||
})
|
||||
.then(async (results) => {
|
||||
const model = (await translateResults(results)) || {};
|
||||
this.results = model.posts;
|
||||
})
|
||||
.finally(() => (this.searching = false));
|
||||
}
|
||||
|
||||
@action
|
||||
debouncedSearch() {
|
||||
discourseDebounce(this, this.performHyDESearch, 500);
|
||||
}
|
||||
}
|
|
@ -1,63 +0,0 @@
|
|||
import { withPluginApi } from "discourse/lib/plugin-api";
|
||||
import { translateResults, updateRecentSearches } from "discourse/lib/search";
|
||||
import { ajax } from "discourse/lib/ajax";
|
||||
|
||||
const SEMANTIC_SEARCH = "semantic_search";
|
||||
|
||||
function initializeSemanticSearch(api) {
|
||||
api.addFullPageSearchType(
|
||||
"discourse_ai.embeddings.semantic_search",
|
||||
SEMANTIC_SEARCH,
|
||||
(searchController, args) => {
|
||||
if (searchController.currentUser) {
|
||||
updateRecentSearches(searchController.currentUser, args.searchTerm);
|
||||
}
|
||||
|
||||
ajax("/discourse-ai/embeddings/semantic-search", { data: args })
|
||||
.then(async (results) => {
|
||||
const model = (await translateResults(results)) || {};
|
||||
|
||||
if (results.grouped_search_result) {
|
||||
searchController.set("q", results.grouped_search_result.term);
|
||||
}
|
||||
|
||||
if (args.page > 1) {
|
||||
if (model) {
|
||||
searchController.model.posts.pushObjects(model.posts);
|
||||
searchController.model.topics.pushObjects(model.topics);
|
||||
searchController.model.set(
|
||||
"grouped_search_result",
|
||||
results.grouped_search_result
|
||||
);
|
||||
}
|
||||
} else {
|
||||
model.grouped_search_result = results.grouped_search_result;
|
||||
searchController.set("model", model);
|
||||
}
|
||||
searchController.set("error", null);
|
||||
})
|
||||
.catch((e) => {
|
||||
searchController.set("error", e.jqXHR.responseJSON?.message);
|
||||
})
|
||||
.finally(() => {
|
||||
searchController.setProperties({
|
||||
searching: false,
|
||||
loading: false,
|
||||
});
|
||||
});
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
export default {
|
||||
name: "discourse-ai-full-page-semantic-search",
|
||||
|
||||
initialize(container) {
|
||||
const settings = container.lookup("service:site-settings");
|
||||
const semanticSearch = settings.ai_embeddings_semantic_search_enabled;
|
||||
|
||||
if (settings.ai_embeddings_enabled && semanticSearch) {
|
||||
withPluginApi("1.6.0", initializeSemanticSearch);
|
||||
}
|
||||
},
|
||||
};
|
|
@ -0,0 +1,39 @@
|
|||
.semantic-search__container {
|
||||
background: var(--primary-very-low);
|
||||
margin: 1rem 0 1rem 0;
|
||||
|
||||
.semantic-search__results {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: baseline;
|
||||
|
||||
.semantic-search {
|
||||
&__searching-text {
|
||||
display: inline-block;
|
||||
margin-left: 3px;
|
||||
}
|
||||
&__indicator-wave {
|
||||
flex: 0 0 auto;
|
||||
display: inline-flex;
|
||||
}
|
||||
&__indicator-dot {
|
||||
display: inline-block;
|
||||
animation: ai-summary__indicator-wave 1.8s linear infinite;
|
||||
&:nth-child(2) {
|
||||
animation-delay: -1.6s;
|
||||
}
|
||||
&:nth-child(3) {
|
||||
animation-delay: -1.4s;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.semantic-search__entries {
|
||||
margin-top: 10px;
|
||||
}
|
||||
|
||||
.semantic-search__searching {
|
||||
margin-left: 5px;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -34,6 +34,10 @@ en:
|
|||
|
||||
embeddings:
|
||||
semantic_search: "Topics (Semantic)"
|
||||
semantic_search_loading: "Searching for more results using AI"
|
||||
semantic_search_results:
|
||||
toggle: "Found %{count} results using AI"
|
||||
none: "Sorry, our AI search found no matching topics."
|
||||
|
||||
ai_bot:
|
||||
pm_warning: "AI chatbot messages are monitored regularly by moderators."
|
||||
|
|
|
@ -55,6 +55,7 @@ en:
|
|||
ai_embeddings_pg_connection_string: "PostgreSQL connection string for the embeddings module. Needs pgvector extension enabled and a series of tables created. See docs for more info."
|
||||
ai_embeddings_semantic_search_enabled: "Enable full-page semantic search."
|
||||
ai_embeddings_semantic_related_include_closed_topics: "Include closed topics in semantic search results"
|
||||
ai_embeddings_semantic_search_hyde_model: "Model used to expand keywords to get better results during a semantic search"
|
||||
|
||||
ai_summarization_discourse_service_api_endpoint: "URL where the Discourse summarization API is running."
|
||||
ai_summarization_discourse_service_api_key: "API key for the Discourse summarization API."
|
||||
|
|
|
@ -177,6 +177,18 @@ discourse_ai:
|
|||
ai_embeddings_semantic_search_enabled:
|
||||
default: false
|
||||
client: true
|
||||
ai_embeddings_semantic_search_hyde_model:
|
||||
default: "gpt-3.5-turbo"
|
||||
type: enum
|
||||
allow_any: false
|
||||
choices:
|
||||
- Llama2-*-chat-hf
|
||||
- claude-instant-1
|
||||
- claude-2
|
||||
- gpt-3.5-turbo
|
||||
- gpt-4
|
||||
- StableBeluga2
|
||||
- Upstage-Llama-2-*-instruct-v2
|
||||
|
||||
ai_summarization_discourse_service_api_endpoint: ""
|
||||
ai_summarization_discourse_service_api_key:
|
||||
|
|
|
@ -11,13 +11,12 @@ module DiscourseAi
|
|||
return [] if @text.blank?
|
||||
return [] unless SiteSetting.ai_embeddings_enabled
|
||||
|
||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||
|
||||
candidates =
|
||||
::DiscourseAi::Embeddings::SemanticSearch.new(nil).asymmetric_semantic_search(
|
||||
@text,
|
||||
100,
|
||||
0,
|
||||
return_distance: true,
|
||||
)
|
||||
vector_rep.asymmetric_semantic_search(@text, limit: 100, offset: 0, return_distance: true)
|
||||
candidate_ids = candidates.map(&:first)
|
||||
|
||||
::Topic
|
||||
|
|
|
@ -4,16 +4,21 @@ module DiscourseAi
|
|||
module Embeddings
|
||||
class EntryPoint
|
||||
def load_files
|
||||
require_relative "models/base"
|
||||
require_relative "models/all_mpnet_base_v2"
|
||||
require_relative "models/text_embedding_ada_002"
|
||||
require_relative "models/multilingual_e5_large"
|
||||
require_relative "vector_representations/base"
|
||||
require_relative "vector_representations/all_mpnet_base_v2"
|
||||
require_relative "vector_representations/text_embedding_ada_002"
|
||||
require_relative "vector_representations/multilingual_e5_large"
|
||||
require_relative "strategies/truncation"
|
||||
require_relative "manager"
|
||||
require_relative "jobs/regular/generate_embeddings"
|
||||
require_relative "semantic_related"
|
||||
require_relative "semantic_search"
|
||||
require_relative "semantic_topic_query"
|
||||
|
||||
require_relative "hyde_generators/base"
|
||||
require_relative "hyde_generators/openai"
|
||||
require_relative "hyde_generators/anthropic"
|
||||
require_relative "hyde_generators/llama2"
|
||||
require_relative "hyde_generators/llama2_ftos"
|
||||
require_relative "semantic_search"
|
||||
end
|
||||
|
||||
def inject_into(plugin)
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module HydeGenerators
|
||||
class Anthropic < DiscourseAi::Embeddings::HydeGenerators::Base
|
||||
def prompt(search_term)
|
||||
<<~TEXT
|
||||
Given a search term given between <input> tags, generate a forum post about the search term.
|
||||
Respond with the generated post between <ai> tags.
|
||||
|
||||
<input>#{search_term}</input>
|
||||
TEXT
|
||||
end
|
||||
|
||||
def models
|
||||
%w[claude-instant-1 claude-2]
|
||||
end
|
||||
|
||||
def hypothetical_post_from(query)
|
||||
response =
|
||||
::DiscourseAi::Inference::AnthropicCompletions.perform!(
|
||||
prompt(query),
|
||||
SiteSetting.ai_embeddings_semantic_search_hyde_model,
|
||||
).dig(:completion)
|
||||
|
||||
Nokogiri::HTML5.fragment(response).at("ai").text
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,17 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module HydeGenerators
|
||||
class Base
|
||||
def self.current_hyde_model
|
||||
DiscourseAi::Embeddings::HydeGenerators::Base.descendants.find do |generator_klass|
|
||||
generator_klass.new.models.include?(
|
||||
SiteSetting.ai_embeddings_semantic_search_hyde_model,
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,34 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module HydeGenerators
|
||||
class Llama2 < DiscourseAi::Embeddings::HydeGenerators::Base
|
||||
def prompt(search_term)
|
||||
<<~TEXT
|
||||
[INST] <<SYS>>
|
||||
You are a helpful bot
|
||||
You create forum posts about a given topic
|
||||
<</SYS>>
|
||||
|
||||
Topic: #{search_term}
|
||||
[/INST]
|
||||
Here is a forum post about the above topic:
|
||||
TEXT
|
||||
end
|
||||
|
||||
def models
|
||||
["Llama2-*-chat-hf"]
|
||||
end
|
||||
|
||||
def hypothetical_post_from(query)
|
||||
::DiscourseAi::Inference::HuggingFaceTextGeneration.perform!(
|
||||
prompt(query),
|
||||
SiteSetting.ai_embeddings_semantic_search_hyde_model,
|
||||
token_limit: 400,
|
||||
).dig(:generated_text)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,27 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module HydeGenerators
|
||||
class Llama2Ftos < DiscourseAi::Embeddings::HydeGenerators::Llama2
|
||||
def prompt(search_term)
|
||||
<<~TEXT
|
||||
### System:
|
||||
You are a helpful bot
|
||||
You create forum posts about a given topic
|
||||
|
||||
### User:
|
||||
Topic: #{search_term}
|
||||
|
||||
### Assistant:
|
||||
Here is a forum post about the above topic:
|
||||
TEXT
|
||||
end
|
||||
|
||||
def models
|
||||
%w[StableBeluga2 Upstage-Llama-2-*-instruct-v2]
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,30 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module HydeGenerators
|
||||
class OpenAi < DiscourseAi::Embeddings::HydeGenerators::Base
|
||||
def prompt(search_term)
|
||||
[
|
||||
{
|
||||
role: "system",
|
||||
content: "You are a helpful bot. You create forum posts about a given topic.",
|
||||
},
|
||||
{ role: "user", content: "Create a forum post about the topic: #{search_term}" },
|
||||
]
|
||||
end
|
||||
|
||||
def models
|
||||
%w[gpt-3.5-turbo gpt-4]
|
||||
end
|
||||
|
||||
def hypothetical_post_from(query)
|
||||
::DiscourseAi::Inference::OpenAiCompletions.perform!(
|
||||
prompt(query),
|
||||
SiteSetting.ai_embeddings_semantic_search_hyde_model,
|
||||
).dig(:choices, 0, :message, :content)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -11,7 +11,11 @@ module Jobs
|
|||
post = topic.first_post
|
||||
return if post.nil? || post.raw.blank?
|
||||
|
||||
DiscourseAi::Embeddings::Manager.new(topic).generate!
|
||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.find_vector_representation.new
|
||||
|
||||
vector_rep.generate_topic_representation_from(topic, strategy)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -1,64 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
class Manager
|
||||
attr_reader :target, :model, :strategy
|
||||
|
||||
def initialize(target)
|
||||
@target = target
|
||||
@model =
|
||||
DiscourseAi::Embeddings::Models::Base.subclasses.find do
|
||||
_1.name == SiteSetting.ai_embeddings_model
|
||||
end
|
||||
@strategy = DiscourseAi::Embeddings::Strategies::Truncation.new(@target, @model)
|
||||
end
|
||||
|
||||
def generate!
|
||||
@strategy.process!
|
||||
|
||||
# TODO bail here if we already have an embedding with matching version and digest
|
||||
|
||||
@embeddings = @model.generate_embeddings(@strategy.processed_target)
|
||||
|
||||
persist!
|
||||
end
|
||||
|
||||
def persist!
|
||||
begin
|
||||
DB.exec(
|
||||
<<~SQL,
|
||||
INSERT INTO ai_topic_embeddings_#{table_suffix} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
|
||||
VALUES (:topic_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
ON CONFLICT (topic_id)
|
||||
DO UPDATE SET
|
||||
model_version = :model_version,
|
||||
strategy_version = :strategy_version,
|
||||
digest = :digest,
|
||||
embeddings = '[:embeddings]',
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
|
||||
SQL
|
||||
topic_id: @target.id,
|
||||
model_version: @model.version,
|
||||
strategy_version: @strategy.version,
|
||||
digest: @strategy.digest,
|
||||
embeddings: @embeddings,
|
||||
)
|
||||
rescue PG::Error => e
|
||||
Rails.logger.error(
|
||||
"Error #{e} persisting embedding for topic #{topic.id} and model #{model.name}",
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
def table_suffix
|
||||
"#{@model.id}_#{@strategy.id}"
|
||||
end
|
||||
|
||||
def topic_embeddings_table
|
||||
"ai_topic_embeddings_#{table_suffix}"
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,52 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module Models
|
||||
class AllMpnetBaseV2 < Base
|
||||
class << self
|
||||
def id
|
||||
1
|
||||
end
|
||||
|
||||
def version
|
||||
1
|
||||
end
|
||||
|
||||
def name
|
||||
"all-mpnet-base-v2"
|
||||
end
|
||||
|
||||
def dimensions
|
||||
768
|
||||
end
|
||||
|
||||
def max_sequence_length
|
||||
384
|
||||
end
|
||||
|
||||
def pg_function
|
||||
"<#>"
|
||||
end
|
||||
|
||||
def pg_index_type
|
||||
"vector_ip_ops"
|
||||
end
|
||||
|
||||
def generate_embeddings(text)
|
||||
DiscourseAi::Inference::DiscourseClassifier.perform!(
|
||||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
||||
name,
|
||||
text,
|
||||
SiteSetting.ai_embeddings_discourse_service_api_key,
|
||||
)
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,10 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module Models
|
||||
class Base
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,52 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module Models
|
||||
class MultilingualE5Large < Base
|
||||
class << self
|
||||
def id
|
||||
3
|
||||
end
|
||||
|
||||
def version
|
||||
1
|
||||
end
|
||||
|
||||
def name
|
||||
"multilingual-e5-large"
|
||||
end
|
||||
|
||||
def dimensions
|
||||
1024
|
||||
end
|
||||
|
||||
def max_sequence_length
|
||||
512
|
||||
end
|
||||
|
||||
def pg_function
|
||||
"<=>"
|
||||
end
|
||||
|
||||
def pg_index_type
|
||||
"vector_cosine_ops"
|
||||
end
|
||||
|
||||
def generate_embeddings(text)
|
||||
DiscourseAi::Inference::DiscourseClassifier.perform!(
|
||||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
||||
name,
|
||||
"query: #{text}",
|
||||
SiteSetting.ai_embeddings_discourse_service_api_key,
|
||||
)
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,48 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module Models
|
||||
class TextEmbeddingAda002 < Base
|
||||
class << self
|
||||
def id
|
||||
2
|
||||
end
|
||||
|
||||
def version
|
||||
1
|
||||
end
|
||||
|
||||
def name
|
||||
"text-embedding-ada-002"
|
||||
end
|
||||
|
||||
def dimensions
|
||||
1536
|
||||
end
|
||||
|
||||
def max_sequence_length
|
||||
8191
|
||||
end
|
||||
|
||||
def pg_function
|
||||
"<=>"
|
||||
end
|
||||
|
||||
def pg_index_type
|
||||
"vector_cosine_ops"
|
||||
end
|
||||
|
||||
def generate_embeddings(text)
|
||||
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text)
|
||||
response[:data].first[:embedding]
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -5,31 +5,32 @@ module DiscourseAi
|
|||
class SemanticRelated
|
||||
MissingEmbeddingError = Class.new(StandardError)
|
||||
|
||||
class << self
|
||||
def semantic_suggested_key(topic_id)
|
||||
"semantic-suggested-topic-#{topic_id}"
|
||||
end
|
||||
|
||||
def build_semantic_suggested_key(topic_id)
|
||||
"build-semantic-suggested-topic-#{topic_id}"
|
||||
end
|
||||
|
||||
def clear_cache_for(topic)
|
||||
Discourse.cache.delete(semantic_suggested_key(topic.id))
|
||||
Discourse.redis.del(build_semantic_suggested_key(topic.id))
|
||||
def self.clear_cache_for(topic)
|
||||
Discourse.cache.delete("semantic-suggested-topic-#{topic.id}")
|
||||
Discourse.redis.del("build-semantic-suggested-topic-#{topic.id}")
|
||||
end
|
||||
|
||||
def related_topic_ids_for(topic)
|
||||
return [] if SiteSetting.ai_embeddings_semantic_related_topics < 1
|
||||
|
||||
manager = DiscourseAi::Embeddings::Manager.new(topic)
|
||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||
cache_for = results_ttl(topic)
|
||||
|
||||
begin
|
||||
asd =
|
||||
Discourse
|
||||
.cache
|
||||
.fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
|
||||
symmetric_semantic_search(manager)
|
||||
vector_rep
|
||||
.symmetric_topics_similarity_search(topic)
|
||||
.tap do |candidate_ids|
|
||||
# Happens when the topic doesn't have any embeddings
|
||||
# I'd rather not use Exceptions to control the flow, so this should be refactored soon
|
||||
if candidate_ids.empty? || !candidate_ids.include?(topic.id)
|
||||
raise MissingEmbeddingError, "No embeddings found for topic #{topic.id}"
|
||||
end
|
||||
end
|
||||
end
|
||||
rescue MissingEmbeddingError
|
||||
# avoid a flood of jobs when visiting topic
|
||||
|
@ -43,50 +44,6 @@ module DiscourseAi
|
|||
end
|
||||
[]
|
||||
end
|
||||
end
|
||||
|
||||
def symmetric_semantic_search(manager)
|
||||
topic = manager.target
|
||||
candidate_ids = self.query_symmetric_embeddings(manager)
|
||||
|
||||
# Happens when the topic doesn't have any embeddings
|
||||
# I'd rather not use Exceptions to control the flow, so this should be refactored soon
|
||||
if candidate_ids.empty? || !candidate_ids.include?(topic.id)
|
||||
raise MissingEmbeddingError, "No embeddings found for topic #{topic.id}"
|
||||
end
|
||||
|
||||
candidate_ids
|
||||
end
|
||||
|
||||
def query_symmetric_embeddings(manager)
|
||||
topic = manager.target
|
||||
model = manager.model
|
||||
table = manager.topic_embeddings_table
|
||||
begin
|
||||
DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
|
||||
SELECT
|
||||
topic_id
|
||||
FROM
|
||||
#{table}
|
||||
ORDER BY
|
||||
embeddings #{model.pg_function} (
|
||||
SELECT
|
||||
embeddings
|
||||
FROM
|
||||
#{table}
|
||||
WHERE
|
||||
topic_id = :topic_id
|
||||
LIMIT 1
|
||||
)
|
||||
LIMIT 100
|
||||
SQL
|
||||
rescue PG::Error => e
|
||||
Rails.logger.error(
|
||||
"Error #{e} querying embeddings for topic #{topic.id} and model #{model.name}",
|
||||
)
|
||||
raise MissingEmbeddingError
|
||||
end
|
||||
end
|
||||
|
||||
def results_ttl(topic)
|
||||
case topic.created_at
|
||||
|
@ -100,6 +57,15 @@ module DiscourseAi
|
|||
1.week
|
||||
end
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def semantic_suggested_key(topic_id)
|
||||
"semantic-suggested-topic-#{topic_id}"
|
||||
end
|
||||
|
||||
def build_semantic_suggested_key(topic_id)
|
||||
"build-semantic-suggested-topic-#{topic_id}"
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -3,59 +3,66 @@
|
|||
module DiscourseAi
|
||||
module Embeddings
|
||||
class SemanticSearch
|
||||
def self.clear_cache_for(query)
|
||||
digest = OpenSSL::Digest::SHA1.hexdigest(query)
|
||||
|
||||
Discourse.cache.delete("hyde-doc-#{digest}")
|
||||
Discourse.cache.delete("hyde-doc-embedding-#{digest}")
|
||||
end
|
||||
|
||||
def initialize(guardian)
|
||||
@guardian = guardian
|
||||
@manager = DiscourseAi::Embeddings::Manager.new(nil)
|
||||
@model = @manager.model
|
||||
end
|
||||
|
||||
def cached_query?(query)
|
||||
digest = OpenSSL::Digest::SHA1.hexdigest(query)
|
||||
Discourse.cache.read("hyde-doc-embedding-#{digest}").present?
|
||||
end
|
||||
|
||||
def search_for_topics(query, page = 1)
|
||||
limit = Search.per_filter + 1
|
||||
offset = (page - 1) * Search.per_filter
|
||||
max_results_per_page = 50
|
||||
limit = [Search.per_filter, max_results_per_page].min + 1
|
||||
offset = (page - 1) * limit
|
||||
|
||||
candidate_ids = asymmetric_semantic_search(query, limit, offset)
|
||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||
|
||||
digest = OpenSSL::Digest::SHA1.hexdigest(query)
|
||||
|
||||
hypothetical_post =
|
||||
Discourse
|
||||
.cache
|
||||
.fetch("hyde-doc-#{digest}", expires_in: 1.week) do
|
||||
hyde_generator = DiscourseAi::Embeddings::HydeGenerators::Base.current_hyde_model.new
|
||||
hyde_generator.hypothetical_post_from(query)
|
||||
end
|
||||
|
||||
hypothetical_post_embedding =
|
||||
Discourse
|
||||
.cache
|
||||
.fetch("hyde-doc-embedding-#{digest}", expires_in: 1.week) do
|
||||
vector_rep.vector_from(hypothetical_post)
|
||||
end
|
||||
|
||||
candidate_topic_ids =
|
||||
vector_rep.asymmetric_topics_similarity_search(
|
||||
hypothetical_post_embedding,
|
||||
limit: limit,
|
||||
offset: offset,
|
||||
)
|
||||
|
||||
::Post
|
||||
.where(post_type: ::Topic.visible_post_types(guardian.user))
|
||||
.public_posts
|
||||
.where("topics.visible")
|
||||
.where(topic_id: candidate_ids, post_number: 1)
|
||||
.order("array_position(ARRAY#{candidate_ids}, topic_id)")
|
||||
end
|
||||
|
||||
def asymmetric_semantic_search(query, limit, offset, return_distance: false)
|
||||
embedding = model.generate_embeddings(query)
|
||||
table = @manager.topic_embeddings_table
|
||||
|
||||
begin
|
||||
candidate_ids = DB.query(<<~SQL, query_embedding: embedding, limit: limit, offset: offset)
|
||||
SELECT
|
||||
topic_id,
|
||||
embeddings #{@model.pg_function} '[:query_embedding]' AS distance
|
||||
FROM
|
||||
#{table}
|
||||
ORDER BY
|
||||
embeddings #{@model.pg_function} '[:query_embedding]'
|
||||
LIMIT :limit
|
||||
OFFSET :offset
|
||||
SQL
|
||||
rescue PG::Error => e
|
||||
Rails.logger.error(
|
||||
"Error #{e} querying embeddings for model #{model.name} and search #{query}",
|
||||
)
|
||||
raise MissingEmbeddingError
|
||||
end
|
||||
|
||||
if return_distance
|
||||
candidate_ids.map { |c| [c.topic_id, c.distance] }
|
||||
else
|
||||
candidate_ids.map(&:topic_id)
|
||||
end
|
||||
.where(topic_id: candidate_topic_ids, post_number: 1)
|
||||
.order("array_position(ARRAY#{candidate_topic_ids}, topic_id)")
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
attr_reader :model, :guardian
|
||||
attr_reader :guardian
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -14,7 +14,7 @@ class DiscourseAi::Embeddings::SemanticTopicQuery < TopicQuery
|
|||
|
||||
list =
|
||||
create_list(:semantic_related, query_opts) do |topics|
|
||||
candidate_ids = DiscourseAi::Embeddings::SemanticRelated.related_topic_ids_for(topic)
|
||||
candidate_ids = DiscourseAi::Embeddings::SemanticRelated.new.related_topic_ids_for(topic)
|
||||
|
||||
list =
|
||||
topics
|
||||
|
|
|
@ -4,77 +4,57 @@ module DiscourseAi
|
|||
module Embeddings
|
||||
module Strategies
|
||||
class Truncation
|
||||
attr_reader :processed_target, :digest
|
||||
|
||||
def self.id
|
||||
1
|
||||
end
|
||||
|
||||
def id
|
||||
self.class.id
|
||||
1
|
||||
end
|
||||
|
||||
def version
|
||||
1
|
||||
end
|
||||
|
||||
def initialize(target, model)
|
||||
@model = model
|
||||
@target = target
|
||||
@tokenizer = @model.tokenizer
|
||||
@max_length = @model.max_sequence_length - 2
|
||||
@processed_target = nil
|
||||
end
|
||||
|
||||
# Need a better name for this method
|
||||
def process!
|
||||
case @target
|
||||
def prepare_text_from(target, tokenizer, max_length)
|
||||
case target
|
||||
when Topic
|
||||
@processed_target = topic_truncation(@target)
|
||||
topic_truncation(target, tokenizer, max_length)
|
||||
when Post
|
||||
@processed_target = post_truncation(@target)
|
||||
post_truncation(target, tokenizer, max_length)
|
||||
else
|
||||
raise ArgumentError, "Invalid target type"
|
||||
end
|
||||
|
||||
@digest = OpenSSL::Digest::SHA1.hexdigest(@processed_target)
|
||||
end
|
||||
|
||||
def topic_truncation(topic)
|
||||
t = +""
|
||||
private
|
||||
|
||||
t << topic.title
|
||||
t << "\n\n"
|
||||
t << topic.category.name
|
||||
def topic_information(topic)
|
||||
info = +""
|
||||
|
||||
info << topic.title
|
||||
info << "\n\n"
|
||||
info << topic.category.name
|
||||
if SiteSetting.tagging_enabled
|
||||
t << "\n\n"
|
||||
t << topic.tags.pluck(:name).join(", ")
|
||||
info << "\n\n"
|
||||
info << topic.tags.pluck(:name).join(", ")
|
||||
end
|
||||
t << "\n\n"
|
||||
info << "\n\n"
|
||||
end
|
||||
|
||||
def topic_truncation(topic, tokenizer, max_length)
|
||||
text = +topic_information(topic)
|
||||
|
||||
topic.posts.find_each do |post|
|
||||
t << post.raw
|
||||
break if @tokenizer.size(t) >= @max_length #maybe keep a partial counter to speed this up?
|
||||
t << "\n\n"
|
||||
text << post.raw
|
||||
break if tokenizer.size(text) >= max_length #maybe keep a partial counter to speed this up?
|
||||
text << "\n\n"
|
||||
end
|
||||
|
||||
@tokenizer.truncate(t, @max_length)
|
||||
tokenizer.truncate(text, max_length)
|
||||
end
|
||||
|
||||
def post_truncation(post)
|
||||
t = +""
|
||||
def post_truncation(topic, tokenizer, max_length)
|
||||
text = +topic_information(post.topic)
|
||||
text << post.raw
|
||||
|
||||
t << post.topic.title
|
||||
t << "\n\n"
|
||||
t << post.topic.category.name
|
||||
if SiteSetting.tagging_enabled
|
||||
t << "\n\n"
|
||||
t << post.topic.tags.pluck(:name).join(", ")
|
||||
end
|
||||
t << "\n\n"
|
||||
t << post.raw
|
||||
|
||||
@tokenizer.truncate(t, @max_length)
|
||||
tokenizer.truncate(text, max_length)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module VectorRepresentations
|
||||
class AllMpnetBaseV2 < Base
|
||||
def vector_from(text)
|
||||
DiscourseAi::Inference::DiscourseClassifier.perform!(
|
||||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
||||
name,
|
||||
text,
|
||||
SiteSetting.ai_embeddings_discourse_service_api_key,
|
||||
)
|
||||
end
|
||||
|
||||
def name
|
||||
"all-mpnet-base-v2"
|
||||
end
|
||||
|
||||
def dimensions
|
||||
768
|
||||
end
|
||||
|
||||
def max_sequence_length
|
||||
384
|
||||
end
|
||||
|
||||
def id
|
||||
1
|
||||
end
|
||||
|
||||
def version
|
||||
1
|
||||
end
|
||||
|
||||
def pg_function
|
||||
"<#>"
|
||||
end
|
||||
|
||||
def pg_index_type
|
||||
"vector_ip_ops"
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,166 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module VectorRepresentations
|
||||
class Base
|
||||
def self.current_representation(strategy)
|
||||
subclasses.map { _1.new(strategy) }.find { _1.name == SiteSetting.ai_embeddings_model }
|
||||
end
|
||||
|
||||
def initialize(strategy)
|
||||
@strategy = strategy
|
||||
end
|
||||
|
||||
def create_index(lists, probes)
|
||||
index_name = "#{table_name}_search"
|
||||
|
||||
DB.exec(<<~SQL)
|
||||
DROP INDEX IF EXISTS #{index_name};
|
||||
CREATE INDEX IF NOT EXISTS
|
||||
#{index}
|
||||
ON
|
||||
#{table_name}
|
||||
USING
|
||||
ivfflat (embeddings #{pg_index_type})
|
||||
WITH
|
||||
(lists = #{lists})
|
||||
WHERE
|
||||
model_version = #{version} AND
|
||||
strategy_version = #{@strategy.version};
|
||||
SQL
|
||||
end
|
||||
|
||||
def vector_from(text)
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
||||
def generate_topic_representation_from(target, persist: true)
|
||||
text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2)
|
||||
|
||||
vector_from(text).tap do |vector|
|
||||
if persist
|
||||
digest = OpenSSL::Digest::SHA1.hexdigest(text)
|
||||
save_to_db(target, vector, digest)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
def topic_id_from_representation(raw_vector)
|
||||
DB.query_single(<<~SQL, query_embedding: raw_vector).first
|
||||
SELECT
|
||||
topic_id
|
||||
FROM
|
||||
#{table_name}
|
||||
ORDER BY
|
||||
embeddings #{pg_function} '[:query_embedding]'
|
||||
LIMIT 1
|
||||
SQL
|
||||
end
|
||||
|
||||
def asymmetric_topics_similarity_search(raw_vector, limit:, offset:, return_distance: false)
|
||||
results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset)
|
||||
SELECT
|
||||
topic_id,
|
||||
embeddings #{pg_function} '[:query_embedding]' AS distance
|
||||
FROM
|
||||
#{table_name}
|
||||
ORDER BY
|
||||
embeddings #{pg_function} '[:query_embedding]'
|
||||
LIMIT :limit
|
||||
OFFSET :offset
|
||||
SQL
|
||||
|
||||
if return_distance
|
||||
results.map { |r| [r.topic_id, r.distance] }
|
||||
else
|
||||
results.map(&:topic_id)
|
||||
end
|
||||
rescue PG::Error => e
|
||||
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
|
||||
raise MissingEmbeddingError
|
||||
end
|
||||
|
||||
def symmetric_topics_similarity_search(topic)
|
||||
DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
|
||||
SELECT
|
||||
topic_id
|
||||
FROM
|
||||
#{table_name}
|
||||
ORDER BY
|
||||
embeddings #{pg_function} (
|
||||
SELECT
|
||||
embeddings
|
||||
FROM
|
||||
#{table_name}
|
||||
WHERE
|
||||
topic_id = :topic_id
|
||||
LIMIT 1
|
||||
)
|
||||
LIMIT 100
|
||||
SQL
|
||||
rescue PG::Error => e
|
||||
Rails.logger.error(
|
||||
"Error #{e} querying embeddings for topic #{topic.id} and model #{name}",
|
||||
)
|
||||
raise MissingEmbeddingError
|
||||
end
|
||||
|
||||
def table_name
|
||||
"ai_topic_embeddings_#{id}_#{@strategy.id}"
|
||||
end
|
||||
|
||||
def name
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
||||
def dimensions
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
||||
def max_sequence_length
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
||||
def id
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
||||
def pg_function
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
||||
def version
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
||||
protected
|
||||
|
||||
def save_to_db(target, vector, digest)
|
||||
DB.exec(
|
||||
<<~SQL,
|
||||
INSERT INTO #{table_name} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
|
||||
VALUES (:topic_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
ON CONFLICT (topic_id)
|
||||
DO UPDATE SET
|
||||
model_version = :model_version,
|
||||
strategy_version = :strategy_version,
|
||||
digest = :digest,
|
||||
embeddings = '[:embeddings]',
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
SQL
|
||||
topic_id: target.id,
|
||||
model_version: version,
|
||||
strategy_version: @strategy.version,
|
||||
digest: digest,
|
||||
embeddings: vector,
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,50 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module VectorRepresentations
|
||||
class MultilingualE5Large < Base
|
||||
def vector_from(text)
|
||||
DiscourseAi::Inference::DiscourseClassifier.perform!(
|
||||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
||||
name,
|
||||
"query: #{text}",
|
||||
SiteSetting.ai_embeddings_discourse_service_api_key,
|
||||
)
|
||||
end
|
||||
|
||||
def id
|
||||
3
|
||||
end
|
||||
|
||||
def version
|
||||
1
|
||||
end
|
||||
|
||||
def name
|
||||
"multilingual-e5-large"
|
||||
end
|
||||
|
||||
def dimensions
|
||||
1024
|
||||
end
|
||||
|
||||
def max_sequence_length
|
||||
512
|
||||
end
|
||||
|
||||
def pg_function
|
||||
"<=>"
|
||||
end
|
||||
|
||||
def pg_index_type
|
||||
"vector_cosine_ops"
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,46 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module VectorRepresentations
|
||||
class TextEmbeddingAda002 < Base
|
||||
def id
|
||||
2
|
||||
end
|
||||
|
||||
def version
|
||||
1
|
||||
end
|
||||
|
||||
def name
|
||||
"text-embedding-ada-002"
|
||||
end
|
||||
|
||||
def dimensions
|
||||
1536
|
||||
end
|
||||
|
||||
def max_sequence_length
|
||||
8191
|
||||
end
|
||||
|
||||
def pg_function
|
||||
"<=>"
|
||||
end
|
||||
|
||||
def pg_index_type
|
||||
"vector_cosine_ops"
|
||||
end
|
||||
|
||||
def vector_from(text)
|
||||
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text)
|
||||
response[:data].first[:embedding]
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -4,7 +4,7 @@ module ::DiscourseAi
|
|||
module Inference
|
||||
class HuggingFaceTextGeneration
|
||||
CompletionFailed = Class.new(StandardError)
|
||||
TIMEOUT = 60
|
||||
TIMEOUT = 120
|
||||
|
||||
def self.perform!(
|
||||
prompt,
|
||||
|
|
|
@ -4,18 +4,22 @@ desc "Backfill embeddings for all topics"
|
|||
task "ai:embeddings:backfill", [:start_topic] => [:environment] do |_, args|
|
||||
public_categories = Category.where(read_restricted: false).pluck(:id)
|
||||
manager = DiscourseAi::Embeddings::Manager.new(Topic.first)
|
||||
|
||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.find_vector_representation.new(strategy)
|
||||
table_name = vector_rep.table_name
|
||||
|
||||
Topic
|
||||
.joins(
|
||||
"LEFT JOIN #{manager.topic_embeddings_table} ON #{manager.topic_embeddings_table}.topic_id = topics.id",
|
||||
)
|
||||
.where("#{manager.topic_embeddings_table}.topic_id IS NULL")
|
||||
.joins("LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id")
|
||||
.where("#{table_name}.topic_id IS NULL")
|
||||
.where("topics.id >= ?", args[:start_topic].to_i || 0)
|
||||
.where("category_id IN (?)", public_categories)
|
||||
.where(deleted_at: nil)
|
||||
.order("topics.id ASC")
|
||||
.find_each do |t|
|
||||
print "."
|
||||
DiscourseAi::Embeddings::Manager.new(t).generate!
|
||||
vector_rep.generate_topic_representation_from(t)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -28,25 +32,11 @@ task "ai:embeddings:index", [:work_mem] => [:environment] do |_, args|
|
|||
lists = count < 1_000_000 ? count / 1000 : Math.sqrt(count).to_i
|
||||
probes = count < 1_000_000 ? lists / 10 : Math.sqrt(lists).to_i
|
||||
|
||||
manager = DiscourseAi::Embeddings::Manager.new(Topic.first)
|
||||
table = manager.topic_embeddings_table
|
||||
index = "#{table}_search"
|
||||
vector_representation_klass = DiscourseAi::Embeddings::Vectors::Base.find_vector_representation
|
||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
|
||||
DB.exec("SET work_mem TO '#{args[:work_mem] || "1GB"}';")
|
||||
DB.exec(<<~SQL)
|
||||
DROP INDEX IF EXISTS #{index};
|
||||
CREATE INDEX IF NOT EXISTS
|
||||
#{index}
|
||||
ON
|
||||
#{table}
|
||||
USING
|
||||
ivfflat (embeddings #{manager.model.pg_index_type})
|
||||
WITH
|
||||
(lists = #{lists})
|
||||
WHERE
|
||||
model_version = #{manager.model.version} AND
|
||||
strategy_version = #{manager.strategy.version};
|
||||
SQL
|
||||
vector_representation_klass.new(strategy).create_index(lists, probes)
|
||||
DB.exec("RESET work_mem;")
|
||||
DB.exec("SET ivfflat.probes = #{probes};")
|
||||
end
|
||||
|
|
|
@ -17,6 +17,7 @@ register_asset "stylesheets/modules/ai-helper/common/ai-helper.scss"
|
|||
register_asset "stylesheets/modules/ai-bot/common/bot-replies.scss"
|
||||
|
||||
register_asset "stylesheets/modules/embeddings/common/semantic-related-topics.scss"
|
||||
register_asset "stylesheets/modules/embeddings/common/semantic-search.scss"
|
||||
|
||||
module ::DiscourseAi
|
||||
PLUGIN_NAME = "discourse-ai"
|
||||
|
|
|
@ -1,44 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require_relative "../../support/embeddings_generation_stubs"
|
||||
|
||||
RSpec.describe DiscourseAi::Embeddings::Manager do
|
||||
let(:user) { Fabricate(:user) }
|
||||
let(:expected_embedding) do
|
||||
JSON.parse(
|
||||
File.read("#{Rails.root}/plugins/discourse-ai/spec/fixtures/embeddings/embedding.txt"),
|
||||
)
|
||||
end
|
||||
let(:discourse_model) { "all-mpnet-base-v2" }
|
||||
|
||||
before do
|
||||
SiteSetting.discourse_ai_enabled = true
|
||||
SiteSetting.ai_embeddings_enabled = true
|
||||
SiteSetting.ai_embeddings_model = "all-mpnet-base-v2"
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||
Jobs.run_immediately!
|
||||
end
|
||||
|
||||
it "generates embeddings for new topics automatically" do
|
||||
pc =
|
||||
PostCreator.new(
|
||||
user,
|
||||
raw: "this is the new content for my topic",
|
||||
title: "this is my new topic title",
|
||||
)
|
||||
input =
|
||||
"This is my new topic title\n\nUncategorized\n\n\n\nthis is the new content for my topic\n\n"
|
||||
EmbeddingsGenerationStubs.discourse_service(discourse_model, input, expected_embedding)
|
||||
post = pc.create
|
||||
manager = DiscourseAi::Embeddings::Manager.new(post.topic)
|
||||
|
||||
embeddings =
|
||||
DB.query_single(
|
||||
"SELECT embeddings FROM #{manager.topic_embeddings_table} WHERE topic_id = #{post.topic.id}",
|
||||
).first
|
||||
|
||||
expect(embeddings.split(",")[1].to_f).to be_within(0.0001).of(expected_embedding[1])
|
||||
expect(embeddings.split(",")[13].to_f).to be_within(0.0001).of(expected_embedding[13])
|
||||
expect(embeddings.split(",")[135].to_f).to be_within(0.0001).of(expected_embedding[135])
|
||||
end
|
||||
end
|
|
@ -28,97 +28,4 @@ describe DiscourseAi::Embeddings::EntryPoint do
|
|||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "SemanticTopicQuery extension" do
|
||||
describe "#list_semantic_related_topics" do
|
||||
subject(:topic_query) { DiscourseAi::Embeddings::SemanticTopicQuery.new(user) }
|
||||
|
||||
fab!(:target) { Fabricate(:topic) }
|
||||
|
||||
def stub_semantic_search_with(results)
|
||||
DiscourseAi::Embeddings::SemanticRelated.expects(:related_topic_ids_for).returns(results)
|
||||
end
|
||||
|
||||
context "when the semantic search returns an unlisted topic" do
|
||||
fab!(:unlisted_topic) { Fabricate(:topic, visible: false) }
|
||||
|
||||
before { stub_semantic_search_with([unlisted_topic.id]) }
|
||||
|
||||
it "filters it out" do
|
||||
expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
|
||||
end
|
||||
end
|
||||
|
||||
context "when the semantic search returns a private topic" do
|
||||
fab!(:private_topic) { Fabricate(:private_message_topic) }
|
||||
|
||||
before { stub_semantic_search_with([private_topic.id]) }
|
||||
|
||||
it "filters it out" do
|
||||
expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
|
||||
end
|
||||
end
|
||||
|
||||
context "when the semantic search returns a topic from a restricted category" do
|
||||
fab!(:group) { Fabricate(:group) }
|
||||
fab!(:category) { Fabricate(:private_category, group: group) }
|
||||
fab!(:secured_category_topic) { Fabricate(:topic, category: category) }
|
||||
|
||||
before { stub_semantic_search_with([secured_category_topic.id]) }
|
||||
|
||||
it "filters it out" do
|
||||
expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
|
||||
end
|
||||
|
||||
it "doesn't filter it out if the user has access to the category" do
|
||||
group.add(user)
|
||||
|
||||
expect(topic_query.list_semantic_related_topics(target).topics).to contain_exactly(
|
||||
secured_category_topic,
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
context "when the semantic search returns a closed topic and we explicitly exclude them" do
|
||||
fab!(:closed_topic) { Fabricate(:topic, closed: true) }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_embeddings_semantic_related_include_closed_topics = false
|
||||
stub_semantic_search_with([closed_topic.id])
|
||||
end
|
||||
|
||||
it "filters it out" do
|
||||
expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
|
||||
end
|
||||
end
|
||||
|
||||
context "when the semantic search returns public topics" do
|
||||
fab!(:normal_topic_1) { Fabricate(:topic) }
|
||||
fab!(:normal_topic_2) { Fabricate(:topic) }
|
||||
fab!(:normal_topic_3) { Fabricate(:topic) }
|
||||
fab!(:closed_topic) { Fabricate(:topic, closed: true) }
|
||||
|
||||
before do
|
||||
stub_semantic_search_with(
|
||||
[closed_topic.id, normal_topic_1.id, normal_topic_2.id, normal_topic_3.id],
|
||||
)
|
||||
end
|
||||
|
||||
it "filters it out" do
|
||||
expect(topic_query.list_semantic_related_topics(target).topics).to eq(
|
||||
[closed_topic, normal_topic_1, normal_topic_2, normal_topic_3],
|
||||
)
|
||||
end
|
||||
|
||||
it "returns the plugin limit for the number of results" do
|
||||
SiteSetting.ai_embeddings_semantic_related_topics = 2
|
||||
|
||||
expect(topic_query.list_semantic_related_topics(target).topics).to contain_exactly(
|
||||
closed_topic,
|
||||
normal_topic_1,
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -1,24 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require_relative "../../../../support/embeddings_generation_stubs"
|
||||
|
||||
RSpec.describe DiscourseAi::Embeddings::Models::AllMpnetBaseV2 do
|
||||
describe "#generate_embeddings" do
|
||||
let(:input) { "test" }
|
||||
let(:expected_embedding) { [0.0038493, 0.482001] }
|
||||
|
||||
context "when the model uses the discourse service to create embeddings" do
|
||||
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
|
||||
|
||||
let(:discourse_model) { "all-mpnet-base-v2" }
|
||||
|
||||
it "returns an embedding for a given string" do
|
||||
EmbeddingsGenerationStubs.discourse_service(discourse_model, input, expected_embedding)
|
||||
|
||||
embedding = described_class.generate_embeddings(input)
|
||||
|
||||
expect(embedding).to contain_exactly(*expected_embedding)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,22 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require_relative "../../../../support/embeddings_generation_stubs"
|
||||
|
||||
RSpec.describe DiscourseAi::Embeddings::Models::TextEmbeddingAda002 do
|
||||
describe "#generate_embeddings" do
|
||||
let(:input) { "test" }
|
||||
let(:expected_embedding) { [0.0038493, 0.482001] }
|
||||
|
||||
context "when the model uses OpenAI to create embeddings" do
|
||||
let(:openai_model) { "text-embedding-ada-002" }
|
||||
|
||||
it "returns an embedding for a given string" do
|
||||
EmbeddingsGenerationStubs.openai_service(openai_model, input, expected_embedding)
|
||||
|
||||
embedding = described_class.generate_embeddings(input)
|
||||
|
||||
expect(embedding).to contain_exactly(*expected_embedding)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -3,6 +3,8 @@
|
|||
require "rails_helper"
|
||||
|
||||
describe DiscourseAi::Embeddings::SemanticRelated do
|
||||
subject(:semantic_related) { described_class.new }
|
||||
|
||||
fab!(:target) { Fabricate(:topic) }
|
||||
fab!(:normal_topic_1) { Fabricate(:topic) }
|
||||
fab!(:normal_topic_2) { Fabricate(:topic) }
|
||||
|
@ -25,13 +27,13 @@ describe DiscourseAi::Embeddings::SemanticRelated do
|
|||
results = nil
|
||||
|
||||
expect_enqueued_with(job: :generate_embeddings, args: { topic_id: topic.id }) do
|
||||
results = described_class.related_topic_ids_for(topic)
|
||||
results = semantic_related.related_topic_ids_for(topic)
|
||||
end
|
||||
|
||||
expect(results).to eq([])
|
||||
|
||||
expect_not_enqueued_with(job: :generate_embeddings, args: { topic_id: topic.id }) do
|
||||
results = described_class.related_topic_ids_for(topic)
|
||||
results = semantic_related.related_topic_ids_for(topic)
|
||||
end
|
||||
|
||||
expect(results).to eq([])
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require_relative "../../../support/embeddings_generation_stubs"
|
||||
require_relative "../../../support/openai_completions_inference_stubs"
|
||||
|
||||
RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
|
||||
fab!(:post) { Fabricate(:post) }
|
||||
fab!(:user) { Fabricate(:user) }
|
||||
|
@ -8,10 +11,28 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
|
|||
let(:subject) { described_class.new(Guardian.new(user)) }
|
||||
|
||||
describe "#search_for_topics" do
|
||||
let(:hypothetical_post) { "This is an hypothetical post generated from the keyword test_query" }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||
|
||||
prompt = DiscourseAi::Embeddings::HydeGenerators::OpenAi.new.prompt(query)
|
||||
OpenAiCompletionsInferenceStubs.stub_response(prompt, hypothetical_post)
|
||||
|
||||
hyde_embedding = [0.049382, 0.9999]
|
||||
EmbeddingsGenerationStubs.discourse_service(
|
||||
SiteSetting.ai_embeddings_model,
|
||||
hypothetical_post,
|
||||
hyde_embedding,
|
||||
)
|
||||
end
|
||||
|
||||
after { described_class.clear_cache_for(query) }
|
||||
|
||||
def stub_candidate_ids(candidate_ids)
|
||||
DiscourseAi::Embeddings::SemanticSearch
|
||||
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2
|
||||
.any_instance
|
||||
.expects(:asymmetric_semantic_search)
|
||||
.expects(:asymmetric_topics_similarity_search)
|
||||
.returns(candidate_ids)
|
||||
end
|
||||
|
||||
|
|
|
@ -12,9 +12,14 @@ describe DiscourseAi::Embeddings::EntryPoint do
|
|||
fab!(:target) { Fabricate(:topic) }
|
||||
|
||||
def stub_semantic_search_with(results)
|
||||
DiscourseAi::Embeddings::SemanticRelated.expects(:related_topic_ids_for).returns(results)
|
||||
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2
|
||||
.any_instance
|
||||
.expects(:symmetric_topics_similarity_search)
|
||||
.returns(results.concat([target.id]))
|
||||
end
|
||||
|
||||
after { DiscourseAi::Embeddings::SemanticRelated.clear_cache_for(target) }
|
||||
|
||||
context "when the semantic search returns an unlisted topic" do
|
||||
fab!(:unlisted_topic) { Fabricate(:topic, visible: false) }
|
||||
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
|
||||
describe "#process!" do
|
||||
context "when the model uses OpenAI to create embeddings" do
|
||||
subject(:truncation) { described_class.new }
|
||||
|
||||
describe "#prepare_text_from" do
|
||||
context "when using vector from OpenAI" do
|
||||
before { SiteSetting.max_post_length = 100_000 }
|
||||
|
||||
fab!(:topic) { Fabricate(:topic) }
|
||||
|
@ -18,13 +20,15 @@ RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
|
|||
end
|
||||
fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }
|
||||
|
||||
let(:model) { DiscourseAi::Embeddings::Models::Base.descendants.sample(1).first }
|
||||
let(:truncation) { described_class.new(topic, model) }
|
||||
let(:model) do
|
||||
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new(truncation)
|
||||
end
|
||||
|
||||
it "truncates a topic" do
|
||||
truncation.process!
|
||||
prepared_text =
|
||||
truncation.prepare_text_from(topic, model.tokenizer, model.max_sequence_length)
|
||||
|
||||
expect(model.tokenizer.size(truncation.processed_target)).to be <= model.max_sequence_length
|
||||
expect(model.tokenizer.size(prepared_text)).to be <= model.max_sequence_length
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require_relative "../../../../support/embeddings_generation_stubs"
|
||||
require_relative "vector_rep_shared_examples"
|
||||
|
||||
RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2 do
|
||||
subject(:vector_rep) { described_class.new(truncation) }
|
||||
|
||||
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
|
||||
|
||||
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embedding using with vector representation"
|
||||
end
|
|
@ -0,0 +1,22 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require_relative "../../../../support/embeddings_generation_stubs"
|
||||
require_relative "vector_rep_shared_examples"
|
||||
|
||||
RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large do
|
||||
subject(:vector_rep) { described_class.new(truncation) }
|
||||
|
||||
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
|
||||
|
||||
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.discourse_service(
|
||||
vector_rep.name,
|
||||
"query: #{text}",
|
||||
expected_embedding,
|
||||
)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embedding using with vector representation"
|
||||
end
|
|
@ -0,0 +1,16 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require_relative "../../../../support/embeddings_generation_stubs"
|
||||
require_relative "vector_rep_shared_examples"
|
||||
|
||||
RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002 do
|
||||
subject(:vector_rep) { described_class.new(truncation) }
|
||||
|
||||
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.openai_service(vector_rep.name, text, expected_embedding)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embedding using with vector representation"
|
||||
end
|
|
@ -0,0 +1,54 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.shared_examples "generates and store embedding using with vector representation" do
|
||||
before { @expected_embedding = [0.0038493] * vector_rep.dimensions }
|
||||
|
||||
describe "#vector_from" do
|
||||
it "creates a vector from a given string" do
|
||||
text = "This is a piece of text"
|
||||
stub_vector_mapping(text, @expected_embedding)
|
||||
|
||||
expect(vector_rep.vector_from(text)).to eq(@expected_embedding)
|
||||
end
|
||||
end
|
||||
|
||||
describe "#generate_topic_representation_from" do
|
||||
fab!(:topic) { Fabricate(:topic) }
|
||||
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
|
||||
|
||||
it "creates a vector from a topic and stores it in the database" do
|
||||
text =
|
||||
truncation.prepare_text_from(
|
||||
topic,
|
||||
vector_rep.tokenizer,
|
||||
vector_rep.max_sequence_length - 2,
|
||||
)
|
||||
stub_vector_mapping(text, @expected_embedding)
|
||||
|
||||
vector_rep.generate_topic_representation_from(topic)
|
||||
|
||||
expect(vector_rep.topic_id_from_representation(@expected_embedding)).to eq(topic.id)
|
||||
end
|
||||
end
|
||||
|
||||
describe "#asymmetric_topics_similarity_search" do
|
||||
fab!(:topic) { Fabricate(:topic) }
|
||||
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
|
||||
|
||||
it "finds IDs of similar topics with a given embedding" do
|
||||
similar_vector = [0.0038494] * vector_rep.dimensions
|
||||
text =
|
||||
truncation.prepare_text_from(
|
||||
topic,
|
||||
vector_rep.tokenizer,
|
||||
vector_rep.max_sequence_length - 2,
|
||||
)
|
||||
stub_vector_mapping(text, @expected_embedding)
|
||||
vector_rep.generate_topic_representation_from(topic)
|
||||
|
||||
expect(
|
||||
vector_rep.asymmetric_topics_similarity_search(similar_vector, limit: 1, offset: 0),
|
||||
).to contain_exactly(topic.id)
|
||||
end
|
||||
end
|
||||
end
|
|
@ -16,9 +16,10 @@ describe ::TopicsController do
|
|||
|
||||
context "when a user is logged on" do
|
||||
it "includes related topics in payload when configured" do
|
||||
DiscourseAi::Embeddings::SemanticRelated.stubs(:related_topic_ids_for).returns(
|
||||
[topic1.id, topic2.id, topic3.id],
|
||||
)
|
||||
DiscourseAi::Embeddings::SemanticRelated
|
||||
.any_instance
|
||||
.stubs(:related_topic_ids_for)
|
||||
.returns([topic1.id, topic2.id, topic3.id])
|
||||
|
||||
get("#{topic.relative_url}.json")
|
||||
expect(response.status).to eq(200)
|
||||
|
|
Loading…
Reference in New Issue