FEATURE: Backfill posts sentiment. (#982)
* FEATURE: Backfill posts sentiment. It adds a scheduled job to backfill posts' sentiment, similar to our existing rake task, but with two settings to control the batch size and posts' max-age. * Make sure model_name order is consistent.
This commit is contained in:
parent
7c65dd171f
commit
ce6a2eca21
|
@ -0,0 +1,30 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module Jobs
|
||||
class SentimentBackfill < ::Jobs::Scheduled
|
||||
every 5.minutes
|
||||
cluster_concurrency 1
|
||||
|
||||
def execute(_args)
|
||||
return if !SiteSetting.ai_sentiment_enabled
|
||||
|
||||
base_budget = SiteSetting.ai_sentiment_backfill_maximum_posts_per_hour
|
||||
return if base_budget.zero?
|
||||
# Split budget in 12 intervals, but make sure is at least one.
|
||||
#
|
||||
# This is not exact as we don't have a way of tracking how many
|
||||
# posts we classified in the current hour, but it's a good enough approximation.
|
||||
limit_per_job = [base_budget, 12].max / 12
|
||||
|
||||
classificator = DiscourseAi::Sentiment::PostClassification.new
|
||||
return if !classificator.has_classifiers?
|
||||
|
||||
posts =
|
||||
DiscourseAi::Sentiment::PostClassification.backfill_query(
|
||||
max_age_days: SiteSetting.ai_sentiment_backfill_post_max_age_days,
|
||||
).limit(limit_per_job)
|
||||
|
||||
classificator.bulk_classify!(posts)
|
||||
end
|
||||
end
|
||||
end
|
|
@ -17,6 +17,15 @@ discourse_ai:
|
|||
ai_sentiment_model_configs:
|
||||
default: ""
|
||||
json_schema: DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema
|
||||
ai_sentiment_backfill_maximum_posts_per_hour:
|
||||
default: 250
|
||||
min: 0
|
||||
max: 10000
|
||||
hidden: true
|
||||
ai_sentiment_backfill_post_max_age_days:
|
||||
default: 60
|
||||
hidden: true
|
||||
|
||||
|
||||
ai_openai_dall_e_3_url: "https://api.openai.com/v1/images/generations"
|
||||
ai_openai_embeddings_url: "https://api.openai.com/v1/embeddings"
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module PostExtensions
|
||||
extend ActiveSupport::Concern
|
||||
|
||||
prepended do
|
||||
has_many :classification_results, as: :target
|
||||
|
||||
has_many :sentiment_classifications,
|
||||
-> { where(classification_type: "sentiment") },
|
||||
class_name: "ClassificationResult",
|
||||
as: :target
|
||||
end
|
||||
end
|
||||
end
|
|
@ -3,6 +3,47 @@
|
|||
module DiscourseAi
|
||||
module Sentiment
|
||||
class PostClassification
|
||||
def self.backfill_query(from_post_id: nil, max_age_days: nil)
|
||||
available_classifier_names =
|
||||
DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema
|
||||
.values
|
||||
.map { |mc| mc.model_name.downcase }
|
||||
.sort
|
||||
|
||||
base_query =
|
||||
Post
|
||||
.includes(:sentiment_classifications)
|
||||
.joins("INNER JOIN topics ON topics.id = posts.topic_id")
|
||||
.where(post_type: Post.types[:regular])
|
||||
.where.not(topics: { archetype: Archetype.private_message })
|
||||
.where(posts: { deleted_at: nil })
|
||||
.where(topics: { deleted_at: nil })
|
||||
.joins(<<~SQL)
|
||||
LEFT JOIN classification_results crs
|
||||
ON crs.target_id = posts.id
|
||||
AND crs.target_type = 'Post'
|
||||
AND crs.classification_type = 'sentiment'
|
||||
SQL
|
||||
.group("posts.id")
|
||||
.having(<<~SQL, available_classifier_names)
|
||||
COUNT(crs.model_used) = 0
|
||||
OR array_agg(
|
||||
DISTINCT LOWER(crs.model_used) ORDER BY LOWER(crs.model_used)
|
||||
)::text[] IS DISTINCT FROM array[?]
|
||||
SQL
|
||||
|
||||
base_query = base_query.where("posts.id >= ?", from_post_id.to_i) if from_post_id.present?
|
||||
|
||||
if max_age_days.present?
|
||||
base_query =
|
||||
base_query.where(
|
||||
"posts.created_at > current_date - INTERVAL '#{max_age_days.to_i} DAY'",
|
||||
)
|
||||
end
|
||||
|
||||
base_query
|
||||
end
|
||||
|
||||
def bulk_classify!(relation)
|
||||
http_pool_size = 100
|
||||
pool =
|
||||
|
@ -13,6 +54,7 @@ module DiscourseAi
|
|||
)
|
||||
|
||||
available_classifiers = classifiers
|
||||
return if available_classifiers.blank?
|
||||
base_url = Discourse.base_url
|
||||
|
||||
promised_classifications =
|
||||
|
@ -25,9 +67,13 @@ module DiscourseAi
|
|||
.fulfilled_future({ target: record, text: text }, pool)
|
||||
.then_on(pool) do |w_text|
|
||||
results = Concurrent::Hash.new
|
||||
already_classified = w_text[:target].sentiment_classifications.map(&:model_used)
|
||||
|
||||
classifiers_for_target =
|
||||
available_classifiers.reject { |ac| already_classified.include?(ac.model_name) }
|
||||
|
||||
promised_target_results =
|
||||
available_classifiers.map do |c|
|
||||
classifiers_for_target.map do |c|
|
||||
Concurrent::Promises.future_on(pool) do
|
||||
results[c.model_name] = request_with(w_text[:text], c, base_url)
|
||||
end
|
||||
|
@ -52,12 +98,17 @@ module DiscourseAi
|
|||
|
||||
def classify!(target)
|
||||
return if target.blank?
|
||||
return if classifiers.blank?
|
||||
|
||||
to_classify = prepare_text(target)
|
||||
return if to_classify.blank?
|
||||
|
||||
already_classified = target.sentiment_classifications.map(&:model_used)
|
||||
classifiers_for_target =
|
||||
classifiers.reject { |ac| already_classified.include?(ac.model_name) }
|
||||
|
||||
results =
|
||||
classifiers.reduce({}) do |memo, model|
|
||||
classifiers_for_target.reduce({}) do |memo, model|
|
||||
memo[model.model_name] = request_with(to_classify, model)
|
||||
memo
|
||||
end
|
||||
|
@ -69,6 +120,10 @@ module DiscourseAi
|
|||
DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values
|
||||
end
|
||||
|
||||
def has_classifiers?
|
||||
classifiers.present?
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def prepare_text(target)
|
||||
|
|
|
@ -27,6 +27,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def self.values
|
||||
return {} if SiteSetting.ai_sentiment_model_configs.blank?
|
||||
JSON.parse(SiteSetting.ai_sentiment_model_configs, object_class: OpenStruct)
|
||||
end
|
||||
end
|
||||
|
|
|
@ -2,18 +2,8 @@
|
|||
|
||||
desc "Backfill sentiment for all posts"
|
||||
task "ai:sentiment:backfill", [:start_post] => [:environment] do |_, args|
|
||||
public_categories = Category.where(read_restricted: false).pluck(:id)
|
||||
|
||||
Post
|
||||
.joins("INNER JOIN topics ON topics.id = posts.topic_id")
|
||||
.joins(
|
||||
"LEFT JOIN classification_results ON classification_results.target_id = posts.id AND classification_results.target_type = 'Post'",
|
||||
)
|
||||
.where("classification_results.target_id IS NULL")
|
||||
.where("posts.id >= ?", args[:start_post].to_i || 0)
|
||||
.where("category_id IN (?)", public_categories)
|
||||
.where(posts: { deleted_at: nil })
|
||||
.where(topics: { deleted_at: nil })
|
||||
DiscourseAi::Sentiment::PostClassification
|
||||
.backfill_query(from_post_id: args[:start_post].to_i)
|
||||
.find_in_batches do |batch|
|
||||
print "."
|
||||
DiscourseAi::Sentiment::PostClassification.new.bulk_classify!(batch)
|
||||
|
|
|
@ -89,6 +89,7 @@ after_initialize do
|
|||
reloadable_patch do |plugin|
|
||||
Guardian.prepend DiscourseAi::GuardianExtensions
|
||||
Topic.prepend DiscourseAi::TopicExtensions
|
||||
Post.prepend DiscourseAi::PostExtensions
|
||||
end
|
||||
|
||||
register_modifier(:post_should_secure_uploads?) do |_, _, topic|
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require_relative "../../support/sentiment_inference_stubs"
|
||||
|
||||
RSpec.describe Jobs::SentimentBackfill do
|
||||
describe "#execute" do
|
||||
fab!(:post)
|
||||
|
||||
before do
|
||||
SiteSetting.ai_sentiment_enabled = true
|
||||
SiteSetting.ai_sentiment_backfill_maximum_posts_per_hour = 100
|
||||
SiteSetting.ai_sentiment_model_configs =
|
||||
"[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
|
||||
end
|
||||
|
||||
let(:expected_analysis) { DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values.length }
|
||||
|
||||
it "backfills when settings are correct" do
|
||||
SentimentInferenceStubs.stub_classification(post)
|
||||
subject.execute({})
|
||||
|
||||
expect(ClassificationResult.where(target: post).count).to eq(expected_analysis)
|
||||
end
|
||||
|
||||
it "does nothing when batch size is zero" do
|
||||
SiteSetting.ai_sentiment_backfill_maximum_posts_per_hour = 0
|
||||
|
||||
subject.execute({})
|
||||
|
||||
expect(ClassificationResult.count).to be_zero
|
||||
end
|
||||
|
||||
it "does nothing when sentiment is disabled" do
|
||||
SiteSetting.ai_sentiment_enabled = false
|
||||
|
||||
subject.execute({})
|
||||
|
||||
expect(ClassificationResult.count).to be_zero
|
||||
end
|
||||
|
||||
it "respects the ai_sentiment_backfill_post_max_age_days setting" do
|
||||
SentimentInferenceStubs.stub_classification(post)
|
||||
SiteSetting.ai_sentiment_backfill_post_max_age_days = 80
|
||||
post_2 = Fabricate(:post, created_at: 81.days.ago)
|
||||
|
||||
subject.execute({})
|
||||
|
||||
expect(ClassificationResult.where(target: post).count).to eq(expected_analysis)
|
||||
expect(ClassificationResult.where(target: post_2).count).to be_zero
|
||||
end
|
||||
end
|
||||
end
|
|
@ -3,8 +3,6 @@
|
|||
require_relative "../../../support/sentiment_inference_stubs"
|
||||
|
||||
RSpec.describe DiscourseAi::Sentiment::PostClassification do
|
||||
fab!(:post_1) { Fabricate(:post, post_number: 2) }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_sentiment_enabled = true
|
||||
SiteSetting.ai_sentiment_model_configs =
|
||||
|
@ -22,6 +20,8 @@ RSpec.describe DiscourseAi::Sentiment::PostClassification do
|
|||
end
|
||||
|
||||
describe "#classify!" do
|
||||
fab!(:post_1) { Fabricate(:post, post_number: 2) }
|
||||
|
||||
it "does nothing if the post content is blank" do
|
||||
post_1.update_columns(raw: "")
|
||||
|
||||
|
@ -45,9 +45,36 @@ RSpec.describe DiscourseAi::Sentiment::PostClassification do
|
|||
subject.classify!(post_1)
|
||||
check_classification_for(post_1)
|
||||
end
|
||||
|
||||
it "does nothing if there are no classification model" do
|
||||
SiteSetting.ai_sentiment_model_configs = ""
|
||||
|
||||
subject.classify!(post_1)
|
||||
|
||||
expect(ClassificationResult.where(target: post_1).count).to be_zero
|
||||
end
|
||||
|
||||
it "don't reclassify everything when a model config changes" do
|
||||
SentimentInferenceStubs.stub_classification(post_1)
|
||||
|
||||
subject.classify!(post_1)
|
||||
first_classified_at = 2.days.ago
|
||||
ClassificationResult.update_all(created_at: first_classified_at)
|
||||
|
||||
current_models = JSON.parse(SiteSetting.ai_sentiment_model_configs)
|
||||
current_models << { model_name: "new", endpoint: "https://test.com", api_key: "123" }
|
||||
SiteSetting.ai_sentiment_model_configs = current_models.to_json
|
||||
|
||||
SentimentInferenceStubs.stub_classification(post_1)
|
||||
subject.classify!(post_1.reload)
|
||||
|
||||
new_classifications = ClassificationResult.where("created_at > ?", first_classified_at).count
|
||||
expect(new_classifications).to eq(1)
|
||||
end
|
||||
end
|
||||
|
||||
describe "#classify_bulk!" do
|
||||
fab!(:post_1) { Fabricate(:post, post_number: 2) }
|
||||
fab!(:post_2) { Fabricate(:post, post_number: 2) }
|
||||
|
||||
it "classifies all given posts" do
|
||||
|
@ -70,5 +97,116 @@ RSpec.describe DiscourseAi::Sentiment::PostClassification do
|
|||
check_classification_for(post_1)
|
||||
check_classification_for(post_2)
|
||||
end
|
||||
|
||||
it "does nothing if there are no classification model" do
|
||||
SiteSetting.ai_sentiment_model_configs = ""
|
||||
|
||||
subject.bulk_classify!(Post.where(id: [post_1.id, post_2.id]))
|
||||
|
||||
expect(ClassificationResult.where(target: post_1).count).to be_zero
|
||||
expect(ClassificationResult.where(target: post_2).count).to be_zero
|
||||
end
|
||||
|
||||
it "don't reclassify everything when a model config changes" do
|
||||
SentimentInferenceStubs.stub_classification(post_1)
|
||||
|
||||
subject.bulk_classify!(Post.where(id: [post_1.id]))
|
||||
first_classified_at = 2.days.ago
|
||||
ClassificationResult.update_all(created_at: first_classified_at)
|
||||
|
||||
current_models = JSON.parse(SiteSetting.ai_sentiment_model_configs)
|
||||
current_models << { model_name: "new", endpoint: "https://test.com", api_key: "123" }
|
||||
SiteSetting.ai_sentiment_model_configs = current_models.to_json
|
||||
|
||||
SentimentInferenceStubs.stub_classification(post_1)
|
||||
subject.bulk_classify!(Post.where(id: [post_1.id]))
|
||||
|
||||
new_classifications = ClassificationResult.where("created_at > ?", first_classified_at).count
|
||||
expect(new_classifications).to eq(1)
|
||||
end
|
||||
end
|
||||
|
||||
describe ".backfill_query" do
|
||||
it "excludes posts in personal messages" do
|
||||
Fabricate(:private_message_post)
|
||||
|
||||
posts = described_class.backfill_query
|
||||
|
||||
expect(posts).to be_empty
|
||||
end
|
||||
|
||||
it "includes regular posts only" do
|
||||
Fabricate(:small_action)
|
||||
|
||||
posts = described_class.backfill_query
|
||||
|
||||
expect(posts).to be_empty
|
||||
end
|
||||
|
||||
it "excludes posts from deleted topics" do
|
||||
topic = Fabricate(:topic, deleted_at: 1.hour.ago)
|
||||
Fabricate(:post, topic: topic)
|
||||
|
||||
posts = described_class.backfill_query
|
||||
|
||||
expect(posts).to be_empty
|
||||
end
|
||||
|
||||
it "includes topics if at least one configured model is missing" do
|
||||
classified_post = Fabricate(:post)
|
||||
current_models = JSON.parse(SiteSetting.ai_sentiment_model_configs)
|
||||
current_models.each do |cm|
|
||||
Fabricate(:classification_result, target: classified_post, model_used: cm["model_name"])
|
||||
end
|
||||
|
||||
posts = described_class.backfill_query
|
||||
expect(posts).not_to include(classified_post)
|
||||
|
||||
current_models << { model_name: "new", endpoint: "htttps://test.com", api_key: "123" }
|
||||
SiteSetting.ai_sentiment_model_configs = current_models.to_json
|
||||
|
||||
posts = described_class.backfill_query
|
||||
expect(posts).to contain_exactly(classified_post)
|
||||
end
|
||||
|
||||
it "excludes deleted posts" do
|
||||
Fabricate(:post, deleted_at: 1.hour.ago)
|
||||
|
||||
posts = described_class.backfill_query
|
||||
|
||||
expect(posts).to be_empty
|
||||
end
|
||||
|
||||
context "with max_age_days" do
|
||||
fab!(:age_post) { Fabricate(:post, created_at: 3.days.ago) }
|
||||
|
||||
it "includes a post when is younger" do
|
||||
posts = described_class.backfill_query(max_age_days: 4)
|
||||
|
||||
expect(posts).to contain_exactly(age_post)
|
||||
end
|
||||
|
||||
it "excludes posts when it's older" do
|
||||
posts = described_class.backfill_query(max_age_days: 2)
|
||||
|
||||
expect(posts).to be_empty
|
||||
end
|
||||
end
|
||||
|
||||
context "with from_post_id" do
|
||||
fab!(:post)
|
||||
|
||||
it "includes post if ID is higher" do
|
||||
posts = described_class.backfill_query(from_post_id: post.id - 1)
|
||||
|
||||
expect(posts).to contain_exactly(post)
|
||||
end
|
||||
|
||||
it "excludes post if ID is lower" do
|
||||
posts = described_class.backfill_query(from_post_id: post.id + 1)
|
||||
|
||||
expect(posts).to be_empty
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -51,13 +51,19 @@ class SentimentInferenceStubs
|
|||
{ score: 0.0023369535, label: "surprise" },
|
||||
{ score: 0.001663634, label: "joy" },
|
||||
]
|
||||
else
|
||||
[
|
||||
{ score: 0.1, label: "label 1" },
|
||||
{ score: 0.2, label: "label 2" },
|
||||
{ score: 0.3, label: "label 3" },
|
||||
]
|
||||
end
|
||||
end
|
||||
|
||||
def stub_classification(post)
|
||||
content = post.post_number == 1 ? "#{post.topic.title}\n#{post.raw}" : post.raw
|
||||
|
||||
DiscourseAi::Sentiment::PostClassification.new.classifiers.each do |model_config|
|
||||
DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values.each do |model_config|
|
||||
WebMock
|
||||
.stub_request(:post, model_config.endpoint)
|
||||
.with(body: JSON.dump(inputs: content, truncate: true))
|
||||
|
|
Loading…
Reference in New Issue