FEATURE: HyDE-powered semantic search. (#136)
* FEATURE: HyDE-powered semantic search. It relies on the new outlet added in discourse/discourse#23390 to display semantic search results in an unobtrusive way. We use a HyDE-backed approach for semantic search, which consists of generating a hypothetical document from the given keywords; the document is then transformed into a vector and used in an asymmetric similarity topic search. This PR also reorganizes the internals to have fewer moving parts, maintaining one hierarchy of DAOish classes for vector-related operations like transformations and querying. Completions and vectors created by HyDE will remain cached in Redis for now, but we could later use Postgres instead.

* Missing translation and rate limiting

---------

Co-authored-by: Roman Rizzi <rizziromanalejandro@gmail.com>
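At a high level, the flow is: keywords → hypothetical document → embedding → similarity search. A minimal sketch using the classes introduced in this commit (not the exact call sites; caching and error handling omitted):

```ruby
# Sketch of the HyDE-backed search pipeline wired up in this commit.
def hyde_search(query)
  # 1. Expand the search keywords into a hypothetical forum post via an LLM.
  hyde_generator = DiscourseAi::Embeddings::HydeGenerators::Base.current_hyde_model.new
  hypothetical_post = hyde_generator.hypothetical_post_from(query)

  # 2. Embed that document with the configured vector representation.
  strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
  vector_rep =
    DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
  embedding = vector_rep.vector_from(hypothetical_post)

  # 3. Use the vector in an asymmetric similarity search over topic embeddings.
  vector_rep.asymmetric_topics_similarity_search(embedding, limit: 51, offset: 0)
end
```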
parent 3d83d062a1
commit 2c0f535bab
@@ -9,7 +9,6 @@ module DiscourseAi
     def search
       query = params[:q]
-      page = (params[:page] || 1).to_i

       grouped_results =
         Search::GroupedSearchResults.new(
@@ -19,12 +18,19 @@ module DiscourseAi
           use_pg_headlines_for_excerpt: false,
         )

-      DiscourseAi::Embeddings::SemanticSearch
-        .new(guardian)
-        .search_for_topics(query, page)
-        .each { |topic_post| grouped_results.add(topic_post) }
+      semantic_search = DiscourseAi::Embeddings::SemanticSearch.new(guardian)

-      render_serialized(grouped_results, GroupedSearchResultSerializer, result: grouped_results)
+      if !semantic_search.cached_query?(query)
+        RateLimiter.new(current_user, "semantic-search", 4, 1.minutes).performed!
+      end
+
+      hijack do
+        semantic_search
+          .search_for_topics(query)
+          .each { |topic_post| grouped_results.add(topic_post) }
+
+        render_serialized(grouped_results, GroupedSearchResultSerializer, result: grouped_results)
+      end
     end
   end
 end
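Only uncached queries count against the rate limit. A sketch of the limiter semantics relied on above, assuming Discourse's standard `RateLimiter` behavior (4 uncached semantic searches per user per minute):

```ruby
# Illustration only: the fifth uncached search within a minute raises.
limiter = RateLimiter.new(current_user, "semantic-search", 4, 1.minutes)

4.times { limiter.performed! } # increments the per-user counter
limiter.performed!             # => raises RateLimiter::LimitExceeded
```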
@@ -0,0 +1,51 @@
+{{#if this.searchEnabled}}
+  <div class="semantic-search__container search-results" role="region">
+    <div
+      class="semantic-search__results"
+      {{did-insert this.setup}}
+      {{did-insert this.debouncedSearch}}
+      {{will-destroy this.teardown}}
+    >
+      {{#if this.searching}}
+        <div class="semantic-search__searching">
+          <div class="semantic-search__searching-text">
+            {{i18n "discourse_ai.embeddings.semantic_search_loading"}}
+          </div>
+          <span class="semantic-search__indicator-wave">
+            <span class="semantic-search__indicator-dot">.</span>
+            <span class="semantic-search__indicator-dot">.</span>
+            <span class="semantic-search__indicator-dot">.</span>
+          </span>
+        </div>
+      {{else}}
+        {{#if this.results.length}}
+          <div class="semantic-search__toggle-button-container">
+            <DButton
+              @translatedTitle={{this.collapsedResultsTitle}}
+              @translatedLabel={{this.collapsedResultsTitle}}
+              @action={{fn
+                (mut this.collapsedResults)
+                (not this.collapsedResults)
+              }}
+              @class="btn-flat"
+              @icon={{if this.collapsedResults "chevron-right" "chevron-down"}}
+            />
+          </div>
+
+          {{#unless this.collapsedResults}}
+            <div class="semantic-search__entries">
+              <SearchResultEntries
+                @posts={{this.results}}
+                @highlightQuery={{this.highlightQuery}}
+              />
+            </div>
+          {{/unless}}
+        {{else}}
+          <div class="semantic-search__searching">
+            {{i18n "discourse_ai.embeddings.semantic_search_results.none"}}
+          </div>
+        {{/if}}
+      {{/if}}
+    </div>
+  </div>
+{{/if}}
@@ -0,0 +1,82 @@
+import Component from "@glimmer/component";
+import { action, computed } from "@ember/object";
+import I18n from "I18n";
+import { tracked } from "@glimmer/tracking";
+import { ajax } from "discourse/lib/ajax";
+import { translateResults } from "discourse/lib/search";
+import discourseDebounce from "discourse-common/lib/debounce";
+import { inject as service } from "@ember/service";
+import { bind } from "discourse-common/utils/decorators";
+import { SEARCH_TYPE_DEFAULT } from "discourse/controllers/full-page-search";
+
+export default class extends Component {
+  static shouldRender(_args, { siteSettings }) {
+    return siteSettings.ai_embeddings_semantic_search_enabled;
+  }
+
+  @service appEvents;
+
+  @tracked searching = false;
+  @tracked collapsedResults = true;
+  @tracked results = [];
+
+  @computed("args.outletArgs.search")
+  get searchTerm() {
+    return this.args.outletArgs.search;
+  }
+
+  @computed("args.outletArgs.type")
+  get searchEnabled() {
+    return this.args.outletArgs.type === SEARCH_TYPE_DEFAULT;
+  }
+
+  @computed("results")
+  get collapsedResultsTitle() {
+    return I18n.t("discourse_ai.embeddings.semantic_search_results.toggle", {
+      count: this.results.length,
+    });
+  }
+
+  @action
+  setup() {
+    this.appEvents.on(
+      "full-page-search:trigger-search",
+      this,
+      "debouncedSearch"
+    );
+  }
+
+  @action
+  teardown() {
+    this.appEvents.off(
+      "full-page-search:trigger-search",
+      this,
+      "debouncedSearch"
+    );
+  }
+
+  @bind
+  performHyDESearch() {
+    if (!this.searchTerm || !this.searchEnabled || this.searching) {
+      return;
+    }
+
+    this.searching = true;
+    this.collapsedResults = true;
+    this.results = [];
+
+    ajax("/discourse-ai/embeddings/semantic-search", {
+      data: { q: this.searchTerm },
+    })
+      .then(async (results) => {
+        const model = (await translateResults(results)) || {};
+        this.results = model.posts;
+      })
+      .finally(() => (this.searching = false));
+  }
+
+  @action
+  debouncedSearch() {
+    discourseDebounce(this, this.performHyDESearch, 500);
+  }
+}
@@ -1,63 +0,0 @@
-import { withPluginApi } from "discourse/lib/plugin-api";
-import { translateResults, updateRecentSearches } from "discourse/lib/search";
-import { ajax } from "discourse/lib/ajax";
-
-const SEMANTIC_SEARCH = "semantic_search";
-
-function initializeSemanticSearch(api) {
-  api.addFullPageSearchType(
-    "discourse_ai.embeddings.semantic_search",
-    SEMANTIC_SEARCH,
-    (searchController, args) => {
-      if (searchController.currentUser) {
-        updateRecentSearches(searchController.currentUser, args.searchTerm);
-      }
-
-      ajax("/discourse-ai/embeddings/semantic-search", { data: args })
-        .then(async (results) => {
-          const model = (await translateResults(results)) || {};
-
-          if (results.grouped_search_result) {
-            searchController.set("q", results.grouped_search_result.term);
-          }
-
-          if (args.page > 1) {
-            if (model) {
-              searchController.model.posts.pushObjects(model.posts);
-              searchController.model.topics.pushObjects(model.topics);
-              searchController.model.set(
-                "grouped_search_result",
-                results.grouped_search_result
-              );
-            }
-          } else {
-            model.grouped_search_result = results.grouped_search_result;
-            searchController.set("model", model);
-          }
-          searchController.set("error", null);
-        })
-        .catch((e) => {
-          searchController.set("error", e.jqXHR.responseJSON?.message);
-        })
-        .finally(() => {
-          searchController.setProperties({
-            searching: false,
-            loading: false,
-          });
-        });
-    }
-  );
-}
-
-export default {
-  name: "discourse-ai-full-page-semantic-search",
-
-  initialize(container) {
-    const settings = container.lookup("service:site-settings");
-    const semanticSearch = settings.ai_embeddings_semantic_search_enabled;
-
-    if (settings.ai_embeddings_enabled && semanticSearch) {
-      withPluginApi("1.6.0", initializeSemanticSearch);
-    }
-  },
-};
@@ -0,0 +1,39 @@
+.semantic-search__container {
+  background: var(--primary-very-low);
+  margin: 1rem 0 1rem 0;
+
+  .semantic-search__results {
+    display: flex;
+    flex-direction: column;
+    align-items: baseline;
+
+    .semantic-search {
+      &__searching-text {
+        display: inline-block;
+        margin-left: 3px;
+      }
+      &__indicator-wave {
+        flex: 0 0 auto;
+        display: inline-flex;
+      }
+      &__indicator-dot {
+        display: inline-block;
+        animation: ai-summary__indicator-wave 1.8s linear infinite;
+        &:nth-child(2) {
+          animation-delay: -1.6s;
+        }
+        &:nth-child(3) {
+          animation-delay: -1.4s;
+        }
+      }
+    }
+
+    .semantic-search__entries {
+      margin-top: 10px;
+    }
+
+    .semantic-search__searching {
+      margin-left: 5px;
+    }
+  }
+}
@@ -34,6 +34,10 @@ en:
     embeddings:
       semantic_search: "Topics (Semantic)"
+      semantic_search_loading: "Searching for more results using AI"
+      semantic_search_results:
+        toggle: "Found %{count} results using AI"
+        none: "Sorry, our AI search found no matching topics."

     ai_bot:
       pm_warning: "AI chatbot messages are monitored regularly by moderators."
@@ -55,6 +55,7 @@ en:
     ai_embeddings_pg_connection_string: "PostgreSQL connection string for the embeddings module. Needs pgvector extension enabled and a series of tables created. See docs for more info."
     ai_embeddings_semantic_search_enabled: "Enable full-page semantic search."
     ai_embeddings_semantic_related_include_closed_topics: "Include closed topics in semantic search results"
+    ai_embeddings_semantic_search_hyde_model: "Model used to expand keywords to get better results during a semantic search"

     ai_summarization_discourse_service_api_endpoint: "URL where the Discourse summarization API is running."
     ai_summarization_discourse_service_api_key: "API key for the Discourse summarization API."
@@ -177,6 +177,18 @@ discourse_ai:
   ai_embeddings_semantic_search_enabled:
     default: false
     client: true
+  ai_embeddings_semantic_search_hyde_model:
+    default: "gpt-3.5-turbo"
+    type: enum
+    allow_any: false
+    choices:
+      - Llama2-*-chat-hf
+      - claude-instant-1
+      - claude-2
+      - gpt-3.5-turbo
+      - gpt-4
+      - StableBeluga2
+      - Upstage-Llama-2-*-instruct-v2
   ai_summarization_discourse_service_api_endpoint: ""
   ai_summarization_discourse_service_api_key:
@@ -11,13 +11,12 @@ module DiscourseAi
       return [] if @text.blank?
       return [] unless SiteSetting.ai_embeddings_enabled

+      strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
+      vector_rep =
+        DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
+
       candidates =
-        ::DiscourseAi::Embeddings::SemanticSearch.new(nil).asymmetric_semantic_search(
-          @text,
-          100,
-          0,
-          return_distance: true,
-        )
+        vector_rep.asymmetric_semantic_search(@text, limit: 100, offset: 0, return_distance: true)
       candidate_ids = candidates.map(&:first)

       ::Topic
@@ -4,16 +4,21 @@ module DiscourseAi
   module Embeddings
     class EntryPoint
      def load_files
-        require_relative "models/base"
-        require_relative "models/all_mpnet_base_v2"
-        require_relative "models/text_embedding_ada_002"
-        require_relative "models/multilingual_e5_large"
+        require_relative "vector_representations/base"
+        require_relative "vector_representations/all_mpnet_base_v2"
+        require_relative "vector_representations/text_embedding_ada_002"
+        require_relative "vector_representations/multilingual_e5_large"
         require_relative "strategies/truncation"
-        require_relative "manager"
         require_relative "jobs/regular/generate_embeddings"
         require_relative "semantic_related"
-        require_relative "semantic_search"
         require_relative "semantic_topic_query"
+
+        require_relative "hyde_generators/base"
+        require_relative "hyde_generators/openai"
+        require_relative "hyde_generators/anthropic"
+        require_relative "hyde_generators/llama2"
+        require_relative "hyde_generators/llama2_ftos"
+        require_relative "semantic_search"
       end

       def inject_into(plugin)
@@ -0,0 +1,32 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Embeddings
+    module HydeGenerators
+      class Anthropic < DiscourseAi::Embeddings::HydeGenerators::Base
+        def prompt(search_term)
+          <<~TEXT
+            Given a search term given between <input> tags, generate a forum post about the search term.
+            Respond with the generated post between <ai> tags.
+
+            <input>#{search_term}</input>
+          TEXT
+        end
+
+        def models
+          %w[claude-instant-1 claude-2]
+        end
+
+        def hypothetical_post_from(query)
+          response =
+            ::DiscourseAi::Inference::AnthropicCompletions.perform!(
+              prompt(query),
+              SiteSetting.ai_embeddings_semantic_search_hyde_model,
+            ).dig(:completion)
+
+          Nokogiri::HTML5.fragment(response).at("ai").text
+        end
+      end
+    end
+  end
+end
@@ -0,0 +1,17 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Embeddings
+    module HydeGenerators
+      class Base
+        def self.current_hyde_model
+          DiscourseAi::Embeddings::HydeGenerators::Base.descendants.find do |generator_klass|
+            generator_klass.new.models.include?(
+              SiteSetting.ai_embeddings_semantic_search_hyde_model,
+            )
+          end
+        end
+      end
+    end
+  end
+end
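`current_hyde_model` picks the generator whose `models` list contains the configured site setting. A hypothetical console session illustrating the dispatch, based on the `models` lists defined in this commit:

```ruby
SiteSetting.ai_embeddings_semantic_search_hyde_model = "claude-2"
DiscourseAi::Embeddings::HydeGenerators::Base.current_hyde_model
# => DiscourseAi::Embeddings::HydeGenerators::Anthropic

SiteSetting.ai_embeddings_semantic_search_hyde_model = "gpt-4"
DiscourseAi::Embeddings::HydeGenerators::Base.current_hyde_model
# => DiscourseAi::Embeddings::HydeGenerators::OpenAi
```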
@@ -0,0 +1,34 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Embeddings
+    module HydeGenerators
+      class Llama2 < DiscourseAi::Embeddings::HydeGenerators::Base
+        def prompt(search_term)
+          <<~TEXT
+            [INST] <<SYS>>
+            You are a helpful bot
+            You create forum posts about a given topic
+            <</SYS>>
+
+            Topic: #{search_term}
+            [/INST]
+            Here is a forum post about the above topic:
+          TEXT
+        end
+
+        def models
+          ["Llama2-*-chat-hf"]
+        end
+
+        def hypothetical_post_from(query)
+          ::DiscourseAi::Inference::HuggingFaceTextGeneration.perform!(
+            prompt(query),
+            SiteSetting.ai_embeddings_semantic_search_hyde_model,
+            token_limit: 400,
+          ).dig(:generated_text)
+        end
+      end
+    end
+  end
+end
@@ -0,0 +1,27 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Embeddings
+    module HydeGenerators
+      class Llama2Ftos < DiscourseAi::Embeddings::HydeGenerators::Llama2
+        def prompt(search_term)
+          <<~TEXT
+            ### System:
+            You are a helpful bot
+            You create forum posts about a given topic
+
+            ### User:
+            Topic: #{search_term}
+
+            ### Assistant:
+            Here is a forum post about the above topic:
+          TEXT
+        end
+
+        def models
+          %w[StableBeluga2 Upstage-Llama-2-*-instruct-v2]
+        end
+      end
+    end
+  end
+end
@@ -0,0 +1,30 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Embeddings
+    module HydeGenerators
+      class OpenAi < DiscourseAi::Embeddings::HydeGenerators::Base
+        def prompt(search_term)
+          [
+            {
+              role: "system",
+              content: "You are a helpful bot. You create forum posts about a given topic.",
+            },
+            { role: "user", content: "Create a forum post about the topic: #{search_term}" },
+          ]
+        end
+
+        def models
+          %w[gpt-3.5-turbo gpt-4]
+        end
+
+        def hypothetical_post_from(query)
+          ::DiscourseAi::Inference::OpenAiCompletions.perform!(
+            prompt(query),
+            SiteSetting.ai_embeddings_semantic_search_hyde_model,
+          ).dig(:choices, 0, :message, :content)
+        end
+      end
+    end
+  end
+end
@@ -11,7 +11,11 @@ module Jobs
       post = topic.first_post
       return if post.nil? || post.raw.blank?

-      DiscourseAi::Embeddings::Manager.new(topic).generate!
+      strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
+      vector_rep =
+        DiscourseAi::Embeddings::VectorRepresentations::Base.find_vector_representation.new
+
+      vector_rep.generate_topic_representation_from(topic, strategy)
     end
   end
 end
@@ -1,64 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
-  module Embeddings
-    class Manager
-      attr_reader :target, :model, :strategy
-
-      def initialize(target)
-        @target = target
-        @model =
-          DiscourseAi::Embeddings::Models::Base.subclasses.find do
-            _1.name == SiteSetting.ai_embeddings_model
-          end
-        @strategy = DiscourseAi::Embeddings::Strategies::Truncation.new(@target, @model)
-      end
-
-      def generate!
-        @strategy.process!
-
-        # TODO bail here if we already have an embedding with matching version and digest
-
-        @embeddings = @model.generate_embeddings(@strategy.processed_target)
-
-        persist!
-      end
-
-      def persist!
-        begin
-          DB.exec(
-            <<~SQL,
-              INSERT INTO ai_topic_embeddings_#{table_suffix} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
-              VALUES (:topic_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
-              ON CONFLICT (topic_id)
-              DO UPDATE SET
-                model_version = :model_version,
-                strategy_version = :strategy_version,
-                digest = :digest,
-                embeddings = '[:embeddings]',
-                updated_at = CURRENT_TIMESTAMP
-
-            SQL
-            topic_id: @target.id,
-            model_version: @model.version,
-            strategy_version: @strategy.version,
-            digest: @strategy.digest,
-            embeddings: @embeddings,
-          )
-        rescue PG::Error => e
-          Rails.logger.error(
-            "Error #{e} persisting embedding for topic #{topic.id} and model #{model.name}",
-          )
-        end
-      end
-
-      def table_suffix
-        "#{@model.id}_#{@strategy.id}"
-      end
-
-      def topic_embeddings_table
-        "ai_topic_embeddings_#{table_suffix}"
-      end
-    end
-  end
-end
@@ -1,52 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
-  module Embeddings
-    module Models
-      class AllMpnetBaseV2 < Base
-        class << self
-          def id
-            1
-          end
-
-          def version
-            1
-          end
-
-          def name
-            "all-mpnet-base-v2"
-          end
-
-          def dimensions
-            768
-          end
-
-          def max_sequence_length
-            384
-          end
-
-          def pg_function
-            "<#>"
-          end
-
-          def pg_index_type
-            "vector_ip_ops"
-          end
-
-          def generate_embeddings(text)
-            DiscourseAi::Inference::DiscourseClassifier.perform!(
-              "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
-              name,
-              text,
-              SiteSetting.ai_embeddings_discourse_service_api_key,
-            )
-          end
-
-          def tokenizer
-            DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
-          end
-        end
-      end
-    end
-  end
-end
@@ -1,10 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
-  module Embeddings
-    module Models
-      class Base
-      end
-    end
-  end
-end
@@ -1,52 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
-  module Embeddings
-    module Models
-      class MultilingualE5Large < Base
-        class << self
-          def id
-            3
-          end
-
-          def version
-            1
-          end
-
-          def name
-            "multilingual-e5-large"
-          end
-
-          def dimensions
-            1024
-          end
-
-          def max_sequence_length
-            512
-          end
-
-          def pg_function
-            "<=>"
-          end
-
-          def pg_index_type
-            "vector_cosine_ops"
-          end
-
-          def generate_embeddings(text)
-            DiscourseAi::Inference::DiscourseClassifier.perform!(
-              "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
-              name,
-              "query: #{text}",
-              SiteSetting.ai_embeddings_discourse_service_api_key,
-            )
-          end
-
-          def tokenizer
-            DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
-          end
-        end
-      end
-    end
-  end
-end
@@ -1,48 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
-  module Embeddings
-    module Models
-      class TextEmbeddingAda002 < Base
-        class << self
-          def id
-            2
-          end
-
-          def version
-            1
-          end
-
-          def name
-            "text-embedding-ada-002"
-          end
-
-          def dimensions
-            1536
-          end
-
-          def max_sequence_length
-            8191
-          end
-
-          def pg_function
-            "<=>"
-          end
-
-          def pg_index_type
-            "vector_cosine_ops"
-          end
-
-          def generate_embeddings(text)
-            response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text)
-            response[:data].first[:embedding]
-          end
-
-          def tokenizer
-            DiscourseAi::Tokenizer::OpenAiTokenizer
-          end
-        end
-      end
-    end
-  end
-end
@@ -5,101 +5,67 @@ module DiscourseAi
     class SemanticRelated
       MissingEmbeddingError = Class.new(StandardError)

-      class << self
-        def semantic_suggested_key(topic_id)
-          "semantic-suggested-topic-#{topic_id}"
-        end
-
-        def build_semantic_suggested_key(topic_id)
-          "build-semantic-suggested-topic-#{topic_id}"
-        end
-
-        def clear_cache_for(topic)
-          Discourse.cache.delete(semantic_suggested_key(topic.id))
-          Discourse.redis.del(build_semantic_suggested_key(topic.id))
-        end
-
-        def related_topic_ids_for(topic)
-          return [] if SiteSetting.ai_embeddings_semantic_related_topics < 1
-
-          manager = DiscourseAi::Embeddings::Manager.new(topic)
-          cache_for = results_ttl(topic)
-
-          begin
-            Discourse
-              .cache
-              .fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
-                symmetric_semantic_search(manager)
-              end
-          rescue MissingEmbeddingError
-            # avoid a flood of jobs when visiting topic
-            if Discourse.redis.set(
-                 build_semantic_suggested_key(topic.id),
-                 "queued",
-                 ex: 15.minutes.to_i,
-                 nx: true,
-               )
-              Jobs.enqueue(:generate_embeddings, topic_id: topic.id)
-            end
-            []
-          end
-        end
-
-        def symmetric_semantic_search(manager)
-          topic = manager.target
-          candidate_ids = self.query_symmetric_embeddings(manager)
-
-          # Happens when the topic doesn't have any embeddings
-          # I'd rather not use Exceptions to control the flow, so this should be refactored soon
-          if candidate_ids.empty? || !candidate_ids.include?(topic.id)
-            raise MissingEmbeddingError, "No embeddings found for topic #{topic.id}"
-          end
-
-          candidate_ids
-        end
-
-        def query_symmetric_embeddings(manager)
-          topic = manager.target
-          model = manager.model
-          table = manager.topic_embeddings_table
-          begin
-            DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
-              SELECT
-                topic_id
-              FROM
-                #{table}
-              ORDER BY
-                embeddings #{model.pg_function} (
-                  SELECT
-                    embeddings
-                  FROM
-                    #{table}
-                  WHERE
-                    topic_id = :topic_id
-                  LIMIT 1
-                )
-              LIMIT 100
-            SQL
-          rescue PG::Error => e
-            Rails.logger.error(
-              "Error #{e} querying embeddings for topic #{topic.id} and model #{model.name}",
-            )
-            raise MissingEmbeddingError
-          end
-        end
-
-        def results_ttl(topic)
-          case topic.created_at
-          when 6.hour.ago..Time.now
-            15.minutes
-          when 3.day.ago..6.hour.ago
-            1.hour
-          when 15.days.ago..3.day.ago
-            12.hours
-          else
-            1.week
-          end
-        end
-      end
+      def self.clear_cache_for(topic)
+        Discourse.cache.delete("semantic-suggested-topic-#{topic.id}")
+        Discourse.redis.del("build-semantic-suggested-topic-#{topic.id}")
+      end
+
+      def related_topic_ids_for(topic)
+        return [] if SiteSetting.ai_embeddings_semantic_related_topics < 1
+
+        strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
+        vector_rep =
+          DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
+        cache_for = results_ttl(topic)
+
+        asd =
+          Discourse
+            .cache
+            .fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
+              vector_rep
+                .symmetric_topics_similarity_search(topic)
+                .tap do |candidate_ids|
+                  # Happens when the topic doesn't have any embeddings
+                  # I'd rather not use Exceptions to control the flow, so this should be refactored soon
+                  if candidate_ids.empty? || !candidate_ids.include?(topic.id)
+                    raise MissingEmbeddingError, "No embeddings found for topic #{topic.id}"
+                  end
+                end
+            end
+      rescue MissingEmbeddingError
+        # avoid a flood of jobs when visiting topic
+        if Discourse.redis.set(
+             build_semantic_suggested_key(topic.id),
+             "queued",
+             ex: 15.minutes.to_i,
+             nx: true,
+           )
+          Jobs.enqueue(:generate_embeddings, topic_id: topic.id)
+        end
+        []
+      end
+
+      def results_ttl(topic)
+        case topic.created_at
+        when 6.hour.ago..Time.now
+          15.minutes
+        when 3.day.ago..6.hour.ago
+          1.hour
+        when 15.days.ago..3.day.ago
+          12.hours
+        else
+          1.week
+        end
+      end
+
+      private
+
+      def semantic_suggested_key(topic_id)
+        "semantic-suggested-topic-#{topic_id}"
+      end
+
+      def build_semantic_suggested_key(topic_id)
+        "build-semantic-suggested-topic-#{topic_id}"
+      end
     end
   end
 end
@@ -3,59 +3,66 @@
 module DiscourseAi
   module Embeddings
     class SemanticSearch
+      def self.clear_cache_for(query)
+        digest = OpenSSL::Digest::SHA1.hexdigest(query)
+
+        Discourse.cache.delete("hyde-doc-#{digest}")
+        Discourse.cache.delete("hyde-doc-embedding-#{digest}")
+      end
+
       def initialize(guardian)
         @guardian = guardian
-        @manager = DiscourseAi::Embeddings::Manager.new(nil)
-        @model = @manager.model
+      end
+
+      def cached_query?(query)
+        digest = OpenSSL::Digest::SHA1.hexdigest(query)
+        Discourse.cache.read("hyde-doc-embedding-#{digest}").present?
       end

       def search_for_topics(query, page = 1)
-        limit = Search.per_filter + 1
-        offset = (page - 1) * Search.per_filter
+        max_results_per_page = 50
+        limit = [Search.per_filter, max_results_per_page].min + 1
+        offset = (page - 1) * limit

-        candidate_ids = asymmetric_semantic_search(query, limit, offset)
+        strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
+        vector_rep =
+          DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
+
+        digest = OpenSSL::Digest::SHA1.hexdigest(query)
+
+        hypothetical_post =
+          Discourse
+            .cache
+            .fetch("hyde-doc-#{digest}", expires_in: 1.week) do
+              hyde_generator = DiscourseAi::Embeddings::HydeGenerators::Base.current_hyde_model.new
+              hyde_generator.hypothetical_post_from(query)
+            end
+
+        hypothetical_post_embedding =
+          Discourse
+            .cache
+            .fetch("hyde-doc-embedding-#{digest}", expires_in: 1.week) do
+              vector_rep.vector_from(hypothetical_post)
+            end
+
+        candidate_topic_ids =
+          vector_rep.asymmetric_topics_similarity_search(
+            hypothetical_post_embedding,
+            limit: limit,
+            offset: offset,
+          )

         ::Post
           .where(post_type: ::Topic.visible_post_types(guardian.user))
           .public_posts
           .where("topics.visible")
-          .where(topic_id: candidate_ids, post_number: 1)
-          .order("array_position(ARRAY#{candidate_ids}, topic_id)")
+          .where(topic_id: candidate_topic_ids, post_number: 1)
+          .order("array_position(ARRAY#{candidate_topic_ids}, topic_id)")
       end

-      def asymmetric_semantic_search(query, limit, offset, return_distance: false)
-        embedding = model.generate_embeddings(query)
-        table = @manager.topic_embeddings_table
-
-        begin
-          candidate_ids = DB.query(<<~SQL, query_embedding: embedding, limit: limit, offset: offset)
-            SELECT
-              topic_id,
-              embeddings #{@model.pg_function} '[:query_embedding]' AS distance
-            FROM
-              #{table}
-            ORDER BY
-              embeddings #{@model.pg_function} '[:query_embedding]'
-            LIMIT :limit
-            OFFSET :offset
-          SQL
-        rescue PG::Error => e
-          Rails.logger.error(
-            "Error #{e} querying embeddings for model #{model.name} and search #{query}",
-          )
-          raise MissingEmbeddingError
-        end
-
-        if return_distance
-          candidate_ids.map { |c| [c.topic_id, c.distance] }
-        else
-          candidate_ids.map(&:topic_id)
-        end
-      end
-
       private

-      attr_reader :model, :guardian
+      attr_reader :guardian
     end
   end
 end
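Both the generated document and its embedding are cached for a week under a SHA1 digest of the raw query, which is also what `cached_query?` checks. An illustrative lookup:

```ruby
query = "how do I reset my password"
digest = OpenSSL::Digest::SHA1.hexdigest(query)

Discourse.cache.read("hyde-doc-#{digest}")           # the generated hypothetical post, if cached
Discourse.cache.read("hyde-doc-embedding-#{digest}") # its embedding vector, if cached
```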
@@ -14,7 +14,7 @@ class DiscourseAi::Embeddings::SemanticTopicQuery < TopicQuery

     list =
       create_list(:semantic_related, query_opts) do |topics|
-        candidate_ids = DiscourseAi::Embeddings::SemanticRelated.related_topic_ids_for(topic)
+        candidate_ids = DiscourseAi::Embeddings::SemanticRelated.new.related_topic_ids_for(topic)

         list =
           topics
@@ -4,77 +4,57 @@ module DiscourseAi
   module Embeddings
     module Strategies
       class Truncation
-        attr_reader :processed_target, :digest
-
-        def self.id
-          1
-        end
-
         def id
-          self.class.id
+          1
         end

         def version
           1
         end

-        def initialize(target, model)
-          @model = model
-          @target = target
-          @tokenizer = @model.tokenizer
-          @max_length = @model.max_sequence_length - 2
-          @processed_target = nil
-        end
-
-        # Need a better name for this method
-        def process!
-          case @target
+        def prepare_text_from(target, tokenizer, max_length)
+          case target
           when Topic
-            @processed_target = topic_truncation(@target)
+            topic_truncation(target, tokenizer, max_length)
           when Post
-            @processed_target = post_truncation(@target)
+            post_truncation(target, tokenizer, max_length)
           else
             raise ArgumentError, "Invalid target type"
           end
-
-          @digest = OpenSSL::Digest::SHA1.hexdigest(@processed_target)
         end

-        def topic_truncation(topic)
-          t = +""
+        private

-          t << topic.title
-          t << "\n\n"
-          t << topic.category.name
+        def topic_information(topic)
+          info = +""
+
+          info << topic.title
+          info << "\n\n"
+          info << topic.category.name
           if SiteSetting.tagging_enabled
-            t << "\n\n"
-            t << topic.tags.pluck(:name).join(", ")
+            info << "\n\n"
+            info << topic.tags.pluck(:name).join(", ")
           end
-          t << "\n\n"
+          info << "\n\n"
+        end
+
+        def topic_truncation(topic, tokenizer, max_length)
+          text = +topic_information(topic)

           topic.posts.find_each do |post|
-            t << post.raw
-            break if @tokenizer.size(t) >= @max_length #maybe keep a partial counter to speed this up?
-            t << "\n\n"
+            text << post.raw
+            break if tokenizer.size(text) >= max_length #maybe keep a partial counter to speed this up?
+            text << "\n\n"
           end

-          @tokenizer.truncate(t, @max_length)
+          tokenizer.truncate(text, max_length)
         end

-        def post_truncation(post)
-          t = +""
-
-          t << post.topic.title
-          t << "\n\n"
-          t << post.topic.category.name
-          if SiteSetting.tagging_enabled
-            t << "\n\n"
-            t << post.topic.tags.pluck(:name).join(", ")
-          end
-          t << "\n\n"
-          t << post.raw
+        def post_truncation(topic, tokenizer, max_length)
+          text = +topic_information(post.topic)
+          text << post.raw

-          @tokenizer.truncate(t, @max_length)
+          tokenizer.truncate(text, max_length)
         end
       end
     end
   end
 end
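The strategy concatenates title, category, optional tag list, and post bodies, separated by blank lines, then truncates to the tokenizer's budget. A sketch of the output for an untagged topic (mirroring the fixture from the removed Manager spec; `topic` is a placeholder):

```ruby
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
tokenizer = DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer

strategy.prepare_text_from(topic, tokenizer, 382) # 384-token model, minus 2 special tokens
# => "This is my new topic title\n\nUncategorized\n\n\n\nthis is the new content for my topic\n\n"
```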
@@ -0,0 +1,50 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Embeddings
+    module VectorRepresentations
+      class AllMpnetBaseV2 < Base
+        def vector_from(text)
+          DiscourseAi::Inference::DiscourseClassifier.perform!(
+            "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
+            name,
+            text,
+            SiteSetting.ai_embeddings_discourse_service_api_key,
+          )
+        end
+
+        def name
+          "all-mpnet-base-v2"
+        end
+
+        def dimensions
+          768
+        end
+
+        def max_sequence_length
+          384
+        end
+
+        def id
+          1
+        end
+
+        def version
+          1
+        end
+
+        def pg_function
+          "<#>"
+        end
+
+        def pg_index_type
+          "vector_ip_ops"
+        end
+
+        def tokenizer
+          DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
+        end
+      end
+    end
+  end
+end
@@ -0,0 +1,166 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Embeddings
+    module VectorRepresentations
+      class Base
+        def self.current_representation(strategy)
+          subclasses.map { _1.new(strategy) }.find { _1.name == SiteSetting.ai_embeddings_model }
+        end
+
+        def initialize(strategy)
+          @strategy = strategy
+        end
+
+        def create_index(lists, probes)
+          index_name = "#{table_name}_search"
+
+          DB.exec(<<~SQL)
+            DROP INDEX IF EXISTS #{index_name};
+            CREATE INDEX IF NOT EXISTS
+              #{index}
+            ON
+              #{table_name}
+            USING
+              ivfflat (embeddings #{pg_index_type})
+            WITH
+              (lists = #{lists})
+            WHERE
+              model_version = #{version} AND
+              strategy_version = #{@strategy.version};
+          SQL
+        end
+
+        def vector_from(text)
+          raise NotImplementedError
+        end
+
+        def generate_topic_representation_from(target, persist: true)
+          text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2)
+
+          vector_from(text).tap do |vector|
+            if persist
+              digest = OpenSSL::Digest::SHA1.hexdigest(text)
+              save_to_db(target, vector, digest)
+            end
+          end
+        end
+
+        def topic_id_from_representation(raw_vector)
+          DB.query_single(<<~SQL, query_embedding: raw_vector).first
+            SELECT
+              topic_id
+            FROM
+              #{table_name}
+            ORDER BY
+              embeddings #{pg_function} '[:query_embedding]'
+            LIMIT 1
+          SQL
+        end
+
+        def asymmetric_topics_similarity_search(raw_vector, limit:, offset:, return_distance: false)
+          results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset)
+            SELECT
+              topic_id,
+              embeddings #{pg_function} '[:query_embedding]' AS distance
+            FROM
+              #{table_name}
+            ORDER BY
+              embeddings #{pg_function} '[:query_embedding]'
+            LIMIT :limit
+            OFFSET :offset
+          SQL
+
+          if return_distance
+            results.map { |r| [r.topic_id, r.distance] }
+          else
+            results.map(&:topic_id)
+          end
+        rescue PG::Error => e
+          Rails.logger.error("Error #{e} querying embeddings for model #{name}")
+          raise MissingEmbeddingError
+        end
+
+        def symmetric_topics_similarity_search(topic)
+          DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
+            SELECT
+              topic_id
+            FROM
+              #{table_name}
+            ORDER BY
+              embeddings #{pg_function} (
+                SELECT
+                  embeddings
+                FROM
+                  #{table_name}
+                WHERE
+                  topic_id = :topic_id
+                LIMIT 1
+              )
+            LIMIT 100
+          SQL
+        rescue PG::Error => e
+          Rails.logger.error(
+            "Error #{e} querying embeddings for topic #{topic.id} and model #{name}",
+          )
+          raise MissingEmbeddingError
+        end
+
+        def table_name
+          "ai_topic_embeddings_#{id}_#{@strategy.id}"
+        end
+
+        def name
+          raise NotImplementedError
+        end
+
+        def dimensions
+          raise NotImplementedError
+        end
+
+        def max_sequence_length
+          raise NotImplementedError
+        end
+
+        def id
+          raise NotImplementedError
+        end
+
+        def pg_function
+          raise NotImplementedError
+        end
+
+        def version
+          raise NotImplementedError
+        end
+
+        def tokenizer
+          raise NotImplementedError
+        end
+
+        protected
+
+        def save_to_db(target, vector, digest)
+          DB.exec(
+            <<~SQL,
+              INSERT INTO #{table_name} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
+              VALUES (:topic_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
+              ON CONFLICT (topic_id)
+              DO UPDATE SET
+                model_version = :model_version,
+                strategy_version = :strategy_version,
+                digest = :digest,
+                embeddings = '[:embeddings]',
+                updated_at = CURRENT_TIMESTAMP
+            SQL
+            topic_id: target.id,
+            model_version: version,
+            strategy_version: @strategy.version,
+            digest: digest,
+            embeddings: vector,
+          )
+        end
+      end
+    end
+  end
+end
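Each representation/strategy pair writes to its own table, named from the two ids. For example, with the ids defined in this commit:

```ruby
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new                             # id 1
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new(strategy)  # id 1

vector_rep.table_name # => "ai_topic_embeddings_1_1"
```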
@@ -0,0 +1,50 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Embeddings
+    module VectorRepresentations
+      class MultilingualE5Large < Base
+        def vector_from(text)
+          DiscourseAi::Inference::DiscourseClassifier.perform!(
+            "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
+            name,
+            "query: #{text}",
+            SiteSetting.ai_embeddings_discourse_service_api_key,
+          )
+        end
+
+        def id
+          3
+        end
+
+        def version
+          1
+        end
+
+        def name
+          "multilingual-e5-large"
+        end
+
+        def dimensions
+          1024
+        end
+
+        def max_sequence_length
+          512
+        end
+
+        def pg_function
+          "<=>"
+        end
+
+        def pg_index_type
+          "vector_cosine_ops"
+        end
+
+        def tokenizer
+          DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
+        end
+      end
+    end
+  end
+end
@@ -0,0 +1,46 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Embeddings
+    module VectorRepresentations
+      class TextEmbeddingAda002 < Base
+        def id
+          2
+        end
+
+        def version
+          1
+        end
+
+        def name
+          "text-embedding-ada-002"
+        end
+
+        def dimensions
+          1536
+        end
+
+        def max_sequence_length
+          8191
+        end
+
+        def pg_function
+          "<=>"
+        end
+
+        def pg_index_type
+          "vector_cosine_ops"
+        end
+
+        def vector_from(text)
+          response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text)
+          response[:data].first[:embedding]
+        end
+
+        def tokenizer
+          DiscourseAi::Tokenizer::OpenAiTokenizer
+        end
+      end
+    end
+  end
+end
@@ -4,7 +4,7 @@ module ::DiscourseAi
   module Inference
     class HuggingFaceTextGeneration
      CompletionFailed = Class.new(StandardError)
-      TIMEOUT = 60
+      TIMEOUT = 120

      def self.perform!(
        prompt,
@@ -4,18 +4,22 @@ desc "Backfill embeddings for all topics"
 task "ai:embeddings:backfill", [:start_topic] => [:environment] do |_, args|
   public_categories = Category.where(read_restricted: false).pluck(:id)
   manager = DiscourseAi::Embeddings::Manager.new(Topic.first)
+
+  strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
+  vector_rep =
+    DiscourseAi::Embeddings::VectorRepresentations::Base.find_vector_representation.new(strategy)
+  table_name = vector_rep.table_name
+
   Topic
-    .joins(
-      "LEFT JOIN #{manager.topic_embeddings_table} ON #{manager.topic_embeddings_table}.topic_id = topics.id",
-    )
-    .where("#{manager.topic_embeddings_table}.topic_id IS NULL")
+    .joins("LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id")
+    .where("#{table_name}.topic_id IS NULL")
     .where("topics.id >= ?", args[:start_topic].to_i || 0)
     .where("category_id IN (?)", public_categories)
     .where(deleted_at: nil)
     .order("topics.id ASC")
     .find_each do |t|
       print "."
-      DiscourseAi::Embeddings::Manager.new(t).generate!
+      vector_rep.generate_topic_representation_from(t)
     end
 end

@@ -28,25 +32,11 @@ task "ai:embeddings:index", [:work_mem] => [:environment] do |_, args|
   lists = count < 1_000_000 ? count / 1000 : Math.sqrt(count).to_i
   probes = count < 1_000_000 ? lists / 10 : Math.sqrt(lists).to_i

-  manager = DiscourseAi::Embeddings::Manager.new(Topic.first)
-  table = manager.topic_embeddings_table
-  index = "#{table}_search"
+  vector_representation_klass = DiscourseAi::Embeddings::Vectors::Base.find_vector_representation
+  strategy = DiscourseAi::Embeddings::Strategies::Truncation.new

   DB.exec("SET work_mem TO '#{args[:work_mem] || "1GB"}';")
-  DB.exec(<<~SQL)
-    DROP INDEX IF EXISTS #{index};
-    CREATE INDEX IF NOT EXISTS
-      #{index}
-    ON
-      #{table}
-    USING
-      ivfflat (embeddings #{manager.model.pg_index_type})
-    WITH
-      (lists = #{lists})
-    WHERE
-      model_version = #{manager.model.version} AND
-      strategy_version = #{manager.strategy.version};
-  SQL
+  vector_representation_klass.new(strategy).create_index(lists, probes)
   DB.exec("RESET work_mem;")
   DB.exec("SET ivfflat.probes = #{probes};")
 end
@@ -17,6 +17,7 @@ register_asset "stylesheets/modules/ai-helper/common/ai-helper.scss"
 register_asset "stylesheets/modules/ai-bot/common/bot-replies.scss"

 register_asset "stylesheets/modules/embeddings/common/semantic-related-topics.scss"
+register_asset "stylesheets/modules/embeddings/common/semantic-search.scss"

 module ::DiscourseAi
   PLUGIN_NAME = "discourse-ai"
@@ -1,44 +0,0 @@
-# frozen_string_literal: true
-
-require_relative "../../support/embeddings_generation_stubs"
-
-RSpec.describe DiscourseAi::Embeddings::Manager do
-  let(:user) { Fabricate(:user) }
-  let(:expected_embedding) do
-    JSON.parse(
-      File.read("#{Rails.root}/plugins/discourse-ai/spec/fixtures/embeddings/embedding.txt"),
-    )
-  end
-  let(:discourse_model) { "all-mpnet-base-v2" }
-
-  before do
-    SiteSetting.discourse_ai_enabled = true
-    SiteSetting.ai_embeddings_enabled = true
-    SiteSetting.ai_embeddings_model = "all-mpnet-base-v2"
-    SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
-    Jobs.run_immediately!
-  end
-
-  it "generates embeddings for new topics automatically" do
-    pc =
-      PostCreator.new(
-        user,
-        raw: "this is the new content for my topic",
-        title: "this is my new topic title",
-      )
-    input =
-      "This is my new topic title\n\nUncategorized\n\n\n\nthis is the new content for my topic\n\n"
-    EmbeddingsGenerationStubs.discourse_service(discourse_model, input, expected_embedding)
-    post = pc.create
-    manager = DiscourseAi::Embeddings::Manager.new(post.topic)
-
-    embeddings =
-      DB.query_single(
-        "SELECT embeddings FROM #{manager.topic_embeddings_table} WHERE topic_id = #{post.topic.id}",
-      ).first
-
-    expect(embeddings.split(",")[1].to_f).to be_within(0.0001).of(expected_embedding[1])
-    expect(embeddings.split(",")[13].to_f).to be_within(0.0001).of(expected_embedding[13])
-    expect(embeddings.split(",")[135].to_f).to be_within(0.0001).of(expected_embedding[135])
-  end
-end
@ -28,97 +28,4 @@ describe DiscourseAi::Embeddings::EntryPoint do
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
describe "SemanticTopicQuery extension" do
|
|
||||||
describe "#list_semantic_related_topics" do
|
|
||||||
subject(:topic_query) { DiscourseAi::Embeddings::SemanticTopicQuery.new(user) }
|
|
||||||
|
|
||||||
fab!(:target) { Fabricate(:topic) }
|
|
||||||
|
|
||||||
def stub_semantic_search_with(results)
|
|
||||||
DiscourseAi::Embeddings::SemanticRelated.expects(:related_topic_ids_for).returns(results)
|
|
||||||
end
|
|
||||||
|
|
||||||
context "when the semantic search returns an unlisted topic" do
|
|
||||||
fab!(:unlisted_topic) { Fabricate(:topic, visible: false) }
|
|
||||||
|
|
||||||
before { stub_semantic_search_with([unlisted_topic.id]) }
|
|
||||||
|
|
||||||
it "filters it out" do
|
|
||||||
expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
context "when the semantic search returns a private topic" do
|
|
||||||
fab!(:private_topic) { Fabricate(:private_message_topic) }
|
|
||||||
|
|
||||||
before { stub_semantic_search_with([private_topic.id]) }
|
|
||||||
|
|
||||||
it "filters it out" do
|
|
||||||
expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
context "when the semantic search returns a topic from a restricted category" do
|
|
||||||
fab!(:group) { Fabricate(:group) }
|
|
||||||
fab!(:category) { Fabricate(:private_category, group: group) }
|
|
||||||
fab!(:secured_category_topic) { Fabricate(:topic, category: category) }
|
|
||||||
|
|
||||||
before { stub_semantic_search_with([secured_category_topic.id]) }
|
|
||||||
|
|
||||||
it "filters it out" do
|
|
||||||
expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
|
|
||||||
end
|
|
||||||
|
|
||||||
it "doesn't filter it out if the user has access to the category" do
|
|
||||||
group.add(user)
|
|
||||||
|
|
||||||
expect(topic_query.list_semantic_related_topics(target).topics).to contain_exactly(
|
|
||||||
secured_category_topic,
|
|
||||||
)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
context "when the semantic search returns a closed topic and we explicitly exclude them" do
|
|
||||||
fab!(:closed_topic) { Fabricate(:topic, closed: true) }
|
|
||||||
|
|
||||||
before do
|
|
||||||
SiteSetting.ai_embeddings_semantic_related_include_closed_topics = false
|
|
||||||
stub_semantic_search_with([closed_topic.id])
|
|
||||||
end
|
|
||||||
|
|
||||||
it "filters it out" do
|
|
||||||
expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
context "when the semantic search returns public topics" do
|
|
||||||
fab!(:normal_topic_1) { Fabricate(:topic) }
|
|
||||||
fab!(:normal_topic_2) { Fabricate(:topic) }
|
|
||||||
fab!(:normal_topic_3) { Fabricate(:topic) }
|
|
||||||
fab!(:closed_topic) { Fabricate(:topic, closed: true) }
|
|
||||||
|
|
||||||
before do
|
|
||||||
stub_semantic_search_with(
|
|
||||||
[closed_topic.id, normal_topic_1.id, normal_topic_2.id, normal_topic_3.id],
|
|
||||||
)
|
|
||||||
end
|
|
||||||
|
|
||||||
it "filters it out" do
|
|
||||||
expect(topic_query.list_semantic_related_topics(target).topics).to eq(
|
|
||||||
[closed_topic, normal_topic_1, normal_topic_2, normal_topic_3],
|
|
||||||
)
|
|
||||||
end
|
|
||||||
|
|
||||||
it "returns the plugin limit for the number of results" do
|
|
||||||
SiteSetting.ai_embeddings_semantic_related_topics = 2
|
|
||||||
|
|
||||||
expect(topic_query.list_semantic_related_topics(target).topics).to contain_exactly(
|
|
||||||
closed_topic,
|
|
||||||
normal_topic_1,
|
|
||||||
)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
end
@@ -1,24 +0,0 @@
-# frozen_string_literal: true
-
-require_relative "../../../../support/embeddings_generation_stubs"
-
-RSpec.describe DiscourseAi::Embeddings::Models::AllMpnetBaseV2 do
-  describe "#generate_embeddings" do
-    let(:input) { "test" }
-    let(:expected_embedding) { [0.0038493, 0.482001] }
-
-    context "when the model uses the discourse service to create embeddings" do
-      before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
-
-      let(:discourse_model) { "all-mpnet-base-v2" }
-
-      it "returns an embedding for a given string" do
-        EmbeddingsGenerationStubs.discourse_service(discourse_model, input, expected_embedding)
-
-        embedding = described_class.generate_embeddings(input)
-
-        expect(embedding).to contain_exactly(*expected_embedding)
-      end
-    end
-  end
-end
@@ -1,22 +0,0 @@
-# frozen_string_literal: true
-
-require_relative "../../../../support/embeddings_generation_stubs"
-
-RSpec.describe DiscourseAi::Embeddings::Models::TextEmbeddingAda002 do
-  describe "#generate_embeddings" do
-    let(:input) { "test" }
-    let(:expected_embedding) { [0.0038493, 0.482001] }
-
-    context "when the model uses OpenAI to create embeddings" do
-      let(:openai_model) { "text-embedding-ada-002" }
-
-      it "returns an embedding for a given string" do
-        EmbeddingsGenerationStubs.openai_service(openai_model, input, expected_embedding)
-
-        embedding = described_class.generate_embeddings(input)
-
-        expect(embedding).to contain_exactly(*expected_embedding)
-      end
-    end
-  end
-end
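Both deleted model specs assert embeddings with contain_exactly(*expected_embedding), which ignores element order, while the shared examples added later in this diff assert vectors with eq, which is order-sensitive. A quick standalone illustration of the difference between the two matchers:

    RSpec.describe "matcher semantics" do
      it "contain_exactly ignores order, eq does not" do
        expect([0.482001, 0.0038493]).to contain_exactly(0.0038493, 0.482001)
        expect([0.0038493, 0.482001]).to eq([0.0038493, 0.482001])
      end
    end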
@@ -3,6 +3,8 @@
 require "rails_helper"

 describe DiscourseAi::Embeddings::SemanticRelated do
+  subject(:semantic_related) { described_class.new }
+
   fab!(:target) { Fabricate(:topic) }
   fab!(:normal_topic_1) { Fabricate(:topic) }
   fab!(:normal_topic_2) { Fabricate(:topic) }
@@ -25,13 +27,13 @@ describe DiscourseAi::Embeddings::SemanticRelated do
     results = nil

     expect_enqueued_with(job: :generate_embeddings, args: { topic_id: topic.id }) do
-      results = described_class.related_topic_ids_for(topic)
+      results = semantic_related.related_topic_ids_for(topic)
     end

     expect(results).to eq([])

     expect_not_enqueued_with(job: :generate_embeddings, args: { topic_id: topic.id }) do
-      results = described_class.related_topic_ids_for(topic)
+      results = semantic_related.related_topic_ids_for(topic)
     end

     expect(results).to eq([])
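related_topic_ids_for moves from a class method to an instance method, so callers now build a SemanticRelated object first. A hedged sketch of the call-site change, assuming the plugin's autoloaded classes:

    # Before: class-level API
    # ids = DiscourseAi::Embeddings::SemanticRelated.related_topic_ids_for(topic)

    # After: instance-level API, matching the updated spec above
    semantic_related = DiscourseAi::Embeddings::SemanticRelated.new
    ids = semantic_related.related_topic_ids_for(topic)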
@@ -1,5 +1,8 @@
 # frozen_string_literal: true

+require_relative "../../../support/embeddings_generation_stubs"
+require_relative "../../../support/openai_completions_inference_stubs"
+
 RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
   fab!(:post) { Fabricate(:post) }
   fab!(:user) { Fabricate(:user) }
@@ -8,10 +11,28 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
   let(:subject) { described_class.new(Guardian.new(user)) }

   describe "#search_for_topics" do
+    let(:hypothetical_post) { "This is an hypothetical post generated from the keyword test_query" }
+
+    before do
+      SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
+
+      prompt = DiscourseAi::Embeddings::HydeGenerators::OpenAi.new.prompt(query)
+      OpenAiCompletionsInferenceStubs.stub_response(prompt, hypothetical_post)
+
+      hyde_embedding = [0.049382, 0.9999]
+      EmbeddingsGenerationStubs.discourse_service(
+        SiteSetting.ai_embeddings_model,
+        hypothetical_post,
+        hyde_embedding,
+      )
+    end
+
+    after { described_class.clear_cache_for(query) }
+
     def stub_candidate_ids(candidate_ids)
-      DiscourseAi::Embeddings::SemanticSearch
+      DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2
         .any_instance
-        .expects(:asymmetric_semantic_search)
+        .expects(:asymmetric_topics_similarity_search)
         .returns(candidate_ids)
     end

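The new before block stubs every stage of the HyDE pipeline: the LLM completion that expands the keyword query into a hypothetical post, and the embedding of that post. Roughly, the flow under test looks like the sketch below; class and method names come from this diff, but fetch_completion and the in-scope query and vector_rep variables are assumed placeholders, not the plugin's literal code:

    # Hedged sketch of the HyDE search flow this spec exercises.
    query = "test_query"

    # 1. Expand the keyword query into a hypothetical post via an LLM completion
    #    (completions are cached in Redis per the commit message).
    prompt = DiscourseAi::Embeddings::HydeGenerators::OpenAi.new.prompt(query)
    hypothetical_post = fetch_completion(prompt) # fetch_completion is an assumed helper

    # 2. Turn the hypothetical post into a vector.
    embedding = vector_rep.vector_from(hypothetical_post)

    # 3. Feed that vector to the asymmetric topic similarity search.
    candidate_ids = vector_rep.asymmetric_topics_similarity_search(embedding, limit: 100, offset: 0)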
@@ -12,9 +12,14 @@ describe DiscourseAi::Embeddings::EntryPoint do
       fab!(:target) { Fabricate(:topic) }

       def stub_semantic_search_with(results)
-        DiscourseAi::Embeddings::SemanticRelated.expects(:related_topic_ids_for).returns(results)
+        DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2
+          .any_instance
+          .expects(:symmetric_topics_similarity_search)
+          .returns(results.concat([target.id]))
       end

+      after { DiscourseAi::Embeddings::SemanticRelated.clear_cache_for(target) }
+
       context "when the semantic search returns an unlisted topic" do
         fab!(:unlisted_topic) { Fabricate(:topic, visible: false) }

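The stub moves from a class method to an instance method, so the helper switches to Mocha's any_instance; the new version also appends target.id, presumably because a symmetric similarity search over topic embeddings returns the source topic itself. A tiny self-contained sketch of the two stubbing shapes, with an invented Toy class:

    require "mocha/api"
    include Mocha::API

    class Toy
      def self.lookup = :real
      def search = :real
    end

    mocha_setup

    # Class-level expectation, the shape the old helper stubbed:
    Toy.expects(:lookup).returns(:stubbed)
    Toy.lookup # => :stubbed

    # Any-instance expectation, the shape the new helper stubs:
    Toy.any_instance.expects(:search).returns(:stubbed)
    Toy.new.search # => :stubbed

    mocha_verify
    mocha_teardown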
@@ -1,8 +1,10 @@
 # frozen_string_literal: true

 RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
-  describe "#process!" do
-    context "when the model uses OpenAI to create embeddings" do
+  subject(:truncation) { described_class.new }
+
+  describe "#prepare_text_from" do
+    context "when using vector from OpenAI" do
       before { SiteSetting.max_post_length = 100_000 }

       fab!(:topic) { Fabricate(:topic) }
@@ -18,13 +20,15 @@ RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
       end
       fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }

-      let(:model) { DiscourseAi::Embeddings::Models::Base.descendants.sample(1).first }
-      let(:truncation) { described_class.new(topic, model) }
+      let(:model) do
+        DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new(truncation)
+      end

       it "truncates a topic" do
-        truncation.process!
+        prepared_text =
+          truncation.prepare_text_from(topic, model.tokenizer, model.max_sequence_length)

-        expect(model.tokenizer.size(truncation.processed_target)).to be <= model.max_sequence_length
+        expect(model.tokenizer.size(prepared_text)).to be <= model.max_sequence_length
       end
     end
   end
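The truncation strategy is now stateless: instead of mutating internal state through process! and processed_target, prepare_text_from(topic, tokenizer, max_sequence_length) simply returns text that fits the model window. A hedged sketch of that contract with a toy word-level tokenizer; the real classes tokenize with model-specific vocabularies:

    class WordTokenizer
      def tokenize(text) = text.split
      def size(text) = tokenize(text).length
    end

    # Toy stand-in for Strategies::Truncation#prepare_text_from.
    def prepare_text_from(text, tokenizer, max_sequence_length)
      tokenizer.tokenize(text).first(max_sequence_length).join(" ")
    end

    tokenizer = WordTokenizer.new
    prepared = prepare_text_from("Surfin' bird\n" * 800, tokenizer, 384)
    tokenizer.size(prepared) <= 384 # => true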
@@ -0,0 +1,18 @@
+# frozen_string_literal: true
+
+require_relative "../../../../support/embeddings_generation_stubs"
+require_relative "vector_rep_shared_examples"
+
+RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2 do
+  subject(:vector_rep) { described_class.new(truncation) }
+
+  let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
+
+  before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
+
+  def stub_vector_mapping(text, expected_embedding)
+    EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding)
+  end
+
+  it_behaves_like "generates and store embedding using with vector representation"
+end
@@ -0,0 +1,22 @@
+# frozen_string_literal: true
+
+require_relative "../../../../support/embeddings_generation_stubs"
+require_relative "vector_rep_shared_examples"
+
+RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large do
+  subject(:vector_rep) { described_class.new(truncation) }
+
+  let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
+
+  before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
+
+  def stub_vector_mapping(text, expected_embedding)
+    EmbeddingsGenerationStubs.discourse_service(
+      vector_rep.name,
+      "query: #{text}",
+      expected_embedding,
+    )
+  end
+
+  it_behaves_like "generates and store embedding using with vector representation"
+end
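Unlike the other representations, MultilingualE5Large stubs the request with a "query: " prefix: E5-family models expect inputs labeled with their role. A minimal illustration of that convention; the helper is invented, and how the Discourse embeddings service consumes the prefix is an assumption:

    def e5_input(text, role: "query")
      "#{role}: #{text}" # "query: ..." for searches, "passage: ..." for documents
    end

    e5_input("surfin bird")                   # => "query: surfin bird"
    e5_input("Surfin' bird", role: "passage") # => "passage: Surfin' bird"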
@@ -0,0 +1,16 @@
+# frozen_string_literal: true
+
+require_relative "../../../../support/embeddings_generation_stubs"
+require_relative "vector_rep_shared_examples"
+
+RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002 do
+  subject(:vector_rep) { described_class.new(truncation) }
+
+  let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
+
+  def stub_vector_mapping(text, expected_embedding)
+    EmbeddingsGenerationStubs.openai_service(vector_rep.name, text, expected_embedding)
+  end
+
+  it_behaves_like "generates and store embedding using with vector representation"
+end
@@ -0,0 +1,54 @@
+# frozen_string_literal: true
+
+RSpec.shared_examples "generates and store embedding using with vector representation" do
+  before { @expected_embedding = [0.0038493] * vector_rep.dimensions }
+
+  describe "#vector_from" do
+    it "creates a vector from a given string" do
+      text = "This is a piece of text"
+      stub_vector_mapping(text, @expected_embedding)
+
+      expect(vector_rep.vector_from(text)).to eq(@expected_embedding)
+    end
+  end
+
+  describe "#generate_topic_representation_from" do
+    fab!(:topic) { Fabricate(:topic) }
+    fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
+
+    it "creates a vector from a topic and stores it in the database" do
+      text =
+        truncation.prepare_text_from(
+          topic,
+          vector_rep.tokenizer,
+          vector_rep.max_sequence_length - 2,
+        )
+      stub_vector_mapping(text, @expected_embedding)
+
+      vector_rep.generate_topic_representation_from(topic)
+
+      expect(vector_rep.topic_id_from_representation(@expected_embedding)).to eq(topic.id)
+    end
+  end
+
+  describe "#asymmetric_topics_similarity_search" do
+    fab!(:topic) { Fabricate(:topic) }
+    fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
+
+    it "finds IDs of similar topics with a given embedding" do
+      similar_vector = [0.0038494] * vector_rep.dimensions
+      text =
+        truncation.prepare_text_from(
+          topic,
+          vector_rep.tokenizer,
+          vector_rep.max_sequence_length - 2,
+        )
+      stub_vector_mapping(text, @expected_embedding)
+      vector_rep.generate_topic_representation_from(topic)
+
+      expect(
+        vector_rep.asymmetric_topics_similarity_search(similar_vector, limit: 1, offset: 0),
+      ).to contain_exactly(topic.id)
+    end
+  end
+end
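These shared examples deliberately leave vector_rep, truncation, and stub_vector_mapping undefined; each host spec supplies them, which is how one file exercises all three vector representations. A tiny standalone sketch of the same RSpec pattern, with invented names:

    RSpec.shared_examples "returns a greeting" do
      it "greets by name" do
        expect(greeter.hello("Ada")).to eq("hello Ada")
      end
    end

    RSpec.describe "a simple greeter" do
      # The host spec supplies the collaborator, just as each
      # vector representation spec above supplies vector_rep.
      subject(:greeter) { Class.new { def hello(name) = "hello #{name}" }.new }

      it_behaves_like "returns a greeting"
    end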
@@ -16,9 +16,10 @@ describe ::TopicsController do

   context "when a user is logged on" do
     it "includes related topics in payload when configured" do
-      DiscourseAi::Embeddings::SemanticRelated.stubs(:related_topic_ids_for).returns(
-        [topic1.id, topic2.id, topic3.id],
-      )
+      DiscourseAi::Embeddings::SemanticRelated
+        .any_instance
+        .stubs(:related_topic_ids_for)
+        .returns([topic1.id, topic2.id, topic3.id])

       get("#{topic.relative_url}.json")
       expect(response.status).to eq(200)