mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-02-16 16:34:45 +00:00
FEATURE: add semantic search with hyde bot (#210)
In specific scenarios (no special filters or limits) we will also always include at least 5 semantic results with every query. This effectively means that all very wide queries will always return 20 results, regardless of how complex they are. Also: FIX: embedding backfill rake task not working. We renamed internals; this corrects the implementation.
This commit is contained in:
parent
abe96d5533
commit
615eb8b440
@ -15,7 +15,8 @@ module DiscourseAi::AiBot::Commands
|
|||||||
[
|
[
|
||||||
Parameter.new(
|
Parameter.new(
|
||||||
name: "search_query",
|
name: "search_query",
|
||||||
description: "Search query (correct bad spelling, remove connector words!)",
|
description:
|
||||||
|
"Specific keywords to search for, space seperated (correct bad spelling, remove connector words)",
|
||||||
type: "string",
|
type: "string",
|
||||||
),
|
),
|
||||||
Parameter.new(
|
Parameter.new(
|
||||||
@ -93,6 +94,9 @@ module DiscourseAi::AiBot::Commands
|
|||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
MAX_RESULTS = 20
|
||||||
|
MIN_SEMANTIC_RESULTS = 5
|
||||||
|
|
||||||
def process(**search_args)
|
def process(**search_args)
|
||||||
limit = nil
|
limit = nil
|
||||||
|
|
||||||
@ -120,12 +124,35 @@ module DiscourseAi::AiBot::Commands
|
|||||||
)
|
)
|
||||||
|
|
||||||
# let's be frugal with tokens, 50 results is too much and stuff gets cut off
|
# let's be frugal with tokens, 50 results is too much and stuff gets cut off
|
||||||
limit ||= 20
|
limit ||= MAX_RESULTS
|
||||||
limit = 20 if limit > 20
|
limit = MAX_RESULTS if limit > MAX_RESULTS
|
||||||
|
|
||||||
|
should_try_semantic_search = SiteSetting.ai_embeddings_semantic_search_enabled
|
||||||
|
should_try_semantic_search &&= (limit == MAX_RESULTS)
|
||||||
|
should_try_semantic_search &&= (search_args.keys - %i[search_query order]).length == 0
|
||||||
|
should_try_semantic_search &&= (search_args[:search_query].present?)
|
||||||
|
|
||||||
|
limit = limit - MIN_SEMANTIC_RESULTS if should_try_semantic_search
|
||||||
|
|
||||||
posts = results&.posts || []
|
posts = results&.posts || []
|
||||||
posts = posts[0..limit - 1]
|
posts = posts[0..limit - 1]
|
||||||
|
|
||||||
|
if should_try_semantic_search
|
||||||
|
semantic_search = DiscourseAi::Embeddings::SemanticSearch.new(Guardian.new())
|
||||||
|
topic_ids = Set.new(posts.map(&:topic_id))
|
||||||
|
|
||||||
|
semantic_search
|
||||||
|
.search_for_topics(search_args[:search_query])
|
||||||
|
.each do |post|
|
||||||
|
next if topic_ids.include?(post.topic_id)
|
||||||
|
|
||||||
|
topic_ids << post.topic_id
|
||||||
|
posts << post
|
||||||
|
|
||||||
|
break if posts.length >= MAX_RESULTS
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
@last_num_results = posts.length
|
@last_num_results = posts.length
|
||||||
|
|
||||||
if posts.blank?
|
if posts.blank?
|
||||||
|
@ -3,11 +3,9 @@
|
|||||||
desc "Backfill embeddings for all topics"
|
desc "Backfill embeddings for all topics"
|
||||||
task "ai:embeddings:backfill", [:start_topic] => [:environment] do |_, args|
|
task "ai:embeddings:backfill", [:start_topic] => [:environment] do |_, args|
|
||||||
public_categories = Category.where(read_restricted: false).pluck(:id)
|
public_categories = Category.where(read_restricted: false).pluck(:id)
|
||||||
manager = DiscourseAi::Embeddings::Manager.new(Topic.first)
|
|
||||||
|
|
||||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||||
vector_rep =
|
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::Base.find_vector_representation.new(strategy)
|
|
||||||
table_name = vector_rep.table_name
|
table_name = vector_rep.table_name
|
||||||
|
|
||||||
Topic
|
Topic
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
#frozen_string_literal: true
|
#frozen_string_literal: true
|
||||||
|
|
||||||
require_relative "../../../../support/openai_completions_inference_stubs"
|
require_relative "../../../../support/openai_completions_inference_stubs"
|
||||||
|
require_relative "../../../../support/embeddings_generation_stubs"
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::AiBot::Commands::SearchCommand do
|
RSpec.describe DiscourseAi::AiBot::Commands::SearchCommand do
|
||||||
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
||||||
@ -19,6 +20,43 @@ RSpec.describe DiscourseAi::AiBot::Commands::SearchCommand do
|
|||||||
expect(results[:rows]).to eq([])
|
expect(results[:rows]).to eq([])
|
||||||
end
|
end
|
||||||
|
|
||||||
|
describe "semantic search" do
|
||||||
|
let (:query) {
|
||||||
|
"this is an expanded search"
|
||||||
|
}
|
||||||
|
after { DiscourseAi::Embeddings::SemanticSearch.clear_cache_for(query) }
|
||||||
|
|
||||||
|
it "supports semantic search when enabled" do
|
||||||
|
SiteSetting.ai_embeddings_semantic_search_enabled = true
|
||||||
|
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||||
|
|
||||||
|
WebMock.stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
|
||||||
|
status: 200,
|
||||||
|
body: JSON.dump(OpenAiCompletionsInferenceStubs.response(query)),
|
||||||
|
)
|
||||||
|
|
||||||
|
hyde_embedding = [0.049382, 0.9999]
|
||||||
|
EmbeddingsGenerationStubs.discourse_service(
|
||||||
|
SiteSetting.ai_embeddings_model,
|
||||||
|
query,
|
||||||
|
hyde_embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
post1 = Fabricate(:post)
|
||||||
|
search = described_class.new(bot_user: bot_user, post: post1, args: nil)
|
||||||
|
|
||||||
|
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2
|
||||||
|
.any_instance
|
||||||
|
.expects(:asymmetric_topics_similarity_search)
|
||||||
|
.returns([post1.topic_id])
|
||||||
|
|
||||||
|
results = search.process(search_query: "hello world, sam")
|
||||||
|
|
||||||
|
expect(results[:args]).to eq({ search_query: "hello world, sam" })
|
||||||
|
expect(results[:rows].length).to eq(1)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
it "supports subfolder properly" do
|
it "supports subfolder properly" do
|
||||||
Discourse.stubs(:base_path).returns("/subfolder")
|
Discourse.stubs(:base_path).returns("/subfolder")
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user