diff --git a/app/jobs/regular/post_sentiment_analysis.rb b/app/jobs/regular/post_sentiment_analysis.rb index cdbb668d..bcd1e1d5 100644 --- a/app/jobs/regular/post_sentiment_analysis.rb +++ b/app/jobs/regular/post_sentiment_analysis.rb @@ -9,9 +9,7 @@ module ::Jobs post = Post.find_by(id: post_id, post_type: Post.types[:regular]) return if post&.raw.blank? - DiscourseAi::PostClassificator.new( - DiscourseAi::Sentiment::SentimentClassification.new, - ).classify!(post) + DiscourseAi::Sentiment::PostClassification.new.classify!(post) end end end diff --git a/lib/inference/hugging_face_text_embeddings.rb b/lib/inference/hugging_face_text_embeddings.rb index 743a2b57..3881cfc3 100644 --- a/lib/inference/hugging_face_text_embeddings.rb +++ b/lib/inference/hugging_face_text_embeddings.rb @@ -64,8 +64,8 @@ module ::DiscourseAi JSON.parse(response.body, symbolize_names: true) end - def classify(content, model_config) - headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" } + def classify(content, model_config, base_url = Discourse.base_url) + headers = { "Referer" => base_url, "Content-Type" => "application/json" } headers["X-API-KEY"] = model_config.api_key headers["Authorization"] = "Bearer #{model_config.api_key}" diff --git a/lib/sentiment/post_classification.rb b/lib/sentiment/post_classification.rb new file mode 100644 index 00000000..eea10ce9 --- /dev/null +++ b/lib/sentiment/post_classification.rb @@ -0,0 +1,111 @@ +# frozen_string_literal: true + +module DiscourseAi + module Sentiment + class PostClassification + def bulk_classify!(relation) + http_pool_size = 100 + pool = + Concurrent::CachedThreadPool.new( + min_threads: 0, + max_threads: http_pool_size, + idletime: 30, + ) + + available_classifiers = classifiers + base_url = Discourse.base_url + + promised_classifications = + relation + .map do |record| + text = prepare_text(record) + next if text.blank? + + Concurrent::Promises + .fulfilled_future({ target: record, text: text }, pool) + .then_on(pool) do |w_text| + results = Concurrent::Hash.new + + promised_target_results = + available_classifiers.map do |c| + Concurrent::Promises.future_on(pool) do + results[c.model_name] = request_with(w_text[:text], c, base_url) + end + end + + Concurrent::Promises + .zip(*promised_target_results) + .then_on(pool) { |_| w_text.merge(classification: results) } + end + .flat(1) + end + .compact + + Concurrent::Promises + .zip(*promised_classifications) + .value! + .each { |r| store_classification(r[:target], r[:classification]) } + + pool.shutdown + pool.wait_for_termination + end + + def classify!(target) + return if target.blank? + + to_classify = prepare_text(target) + return if to_classify.blank? + + results = + classifiers.reduce({}) do |memo, model| + memo[model.model_name] = request_with(to_classify, model) + memo + end + + store_classification(target, results) + end + + private + + def prepare_text(target) + content = + if target.post_number == 1 + "#{target.topic.title}\n#{target.raw}" + else + target.raw + end + + Tokenizer::BertTokenizer.truncate(content, 512) + end + + def classifiers + DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values + end + + def request_with(content, config, base_url = Discourse.base_url) + DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, config, base_url) + end + + def store_classification(target, classification) + attrs = + classification.map do |model_name, classifications| + { + model_used: model_name, + target_id: target.id, + target_type: target.class.sti_name, + classification_type: :sentiment, + classification: classifications, + updated_at: DateTime.now, + created_at: DateTime.now, + } + end + + ClassificationResult.upsert_all( + attrs, + unique_by: %i[target_id target_type model_used], + update_only: %i[classification], + ) + end + end + end +end diff --git a/lib/tasks/modules/sentiment/backfill.rake b/lib/tasks/modules/sentiment/backfill.rake index 30d43e43..a975434e 100644 --- a/lib/tasks/modules/sentiment/backfill.rake +++ b/lib/tasks/modules/sentiment/backfill.rake @@ -14,11 +14,8 @@ task "ai:sentiment:backfill", [:start_post] => [:environment] do |_, args| .where("category_id IN (?)", public_categories) .where(posts: { deleted_at: nil }) .where(topics: { deleted_at: nil }) - .order("posts.id ASC") - .find_each do |post| + .find_in_batches do |batch| print "." - DiscourseAi::PostClassificator.new( - DiscourseAi::Sentiment::SentimentClassification.new, - ).classify!(post) + DiscourseAi::Sentiment::PostClassification.new.bulk_classify!(batch) end end diff --git a/spec/lib/modules/sentiment/post_classification_spec.rb b/spec/lib/modules/sentiment/post_classification_spec.rb new file mode 100644 index 00000000..99ce44de --- /dev/null +++ b/spec/lib/modules/sentiment/post_classification_spec.rb @@ -0,0 +1,47 @@ +# frozen_string_literal: true + +require_relative "../../../support/sentiment_inference_stubs" + +RSpec.describe DiscourseAi::Sentiment::PostClassification do + fab!(:post_1) { Fabricate(:post, post_number: 2) } + + before do + SiteSetting.ai_sentiment_enabled = true + SiteSetting.ai_sentiment_model_configs = + "[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]" + end + + describe "#classify!" do + it "does nothing if the post content is blank" do + post_1.update_columns(raw: "") + + subject.classify!(post_1) + + expect(ClassificationResult.where(target: post_1).count).to be_zero + end + + it "successfully classifies the post" do + expected_analysis = DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values.length + SentimentInferenceStubs.stub_classification(post_1) + + subject.classify!(post_1) + + expect(ClassificationResult.where(target: post_1).count).to eq(expected_analysis) + end + end + + describe "#classify_bulk!" do + fab!(:post_2) { Fabricate(:post, post_number: 2) } + + it "classifies all given posts" do + expected_analysis = DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values.length + SentimentInferenceStubs.stub_classification(post_1) + SentimentInferenceStubs.stub_classification(post_2) + + subject.bulk_classify!(Post.where(id: [post_1.id, post_2.id])) + + expect(ClassificationResult.where(target: post_1).count).to eq(expected_analysis) + expect(ClassificationResult.where(target: post_2).count).to eq(expected_analysis) + end + end +end