DEV: Drop `OpenStruct` for the context object in services

While using `OpenStruct` is nice, it’s generally not a very good idea as
it usually leads to performance problems.

The `OpenStruct` source code even says basically to avoid it.

Since the context object is crucial in our services, this patch replaces
`OpenStruct` with a custom implementation instead.
This commit is contained in:
Loïc Guitaut 2024-10-03 18:05:45 +02:00 committed by Loïc Guitaut
parent 974a3bfc41
commit 229773e7a8
20 changed files with 145 additions and 143 deletions

View File

@ -9,9 +9,10 @@ class UpdateSiteSetting
attribute :new_value
attribute :allow_changing_hidden, :boolean, default: false
before_validation { self.setting_name = setting_name&.to_sym }
validates :setting_name, presence: true
end
step :convert_name_to_sym
policy :setting_is_visible
policy :setting_is_configurable
step :cleanup_value
@ -19,28 +20,25 @@ class UpdateSiteSetting
private
def convert_name_to_sym(setting_name:)
context.setting_name = setting_name.to_sym
end
def current_user_is_admin(guardian:)
guardian.is_admin?
end
def setting_is_visible(setting_name:)
context.allow_changing_hidden || !SiteSetting.hidden_settings.include?(setting_name)
def setting_is_visible(contract:)
contract.allow_changing_hidden || !SiteSetting.hidden_settings.include?(contract.setting_name)
end
def setting_is_configurable(setting_name:)
return true if !SiteSetting.plugins[setting_name]
def setting_is_configurable(contract:)
return true if !SiteSetting.plugins[contract.setting_name]
Discourse.plugins_by_name[SiteSetting.plugins[setting_name]].configurable?
Discourse.plugins_by_name[SiteSetting.plugins[contract.setting_name]].configurable?
end
def cleanup_value(setting_name:, new_value:)
def cleanup_value(contract:)
new_value = contract.new_value
new_value = new_value.strip if new_value.is_a?(String)
case SiteSetting.type_supervisor.get_type(setting_name)
case SiteSetting.type_supervisor.get_type(contract.setting_name)
when :integer
new_value = new_value.tr("^-0-9", "").to_i if new_value.is_a?(String)
when :file_size_restriction
@ -50,10 +48,10 @@ class UpdateSiteSetting
when :upload
new_value = Upload.get_from_url(new_value) || ""
end
context.new_value = new_value
context[:new_value] = new_value
end
def save(setting_name:, new_value:, guardian:)
SiteSetting.set_and_log(setting_name, new_value, guardian.user)
def save(contract:, new_value:, guardian:)
SiteSetting.set_and_log(contract.setting_name, new_value, guardian.user)
end
end

View File

@ -17,8 +17,24 @@ module Service
end
# Simple structure to hold the context of the service during its whole lifecycle.
class Context < OpenStruct
include ActiveModel::Serialization
class Context
delegate :slice, to: :store
def initialize(context = {})
@store = context.symbolize_keys
end
def [](key)
store[key.to_sym]
end
def []=(key, value)
store[key.to_sym] = value
end
def to_h
store.dup
end
# @return [Boolean] returns +true+ if the context is set as successful (default)
def success?
@ -48,27 +64,27 @@ module Service
# context.fail("failure": "something went wrong")
# @return [Context]
def fail(context = {})
merge(context)
store.merge!(context.symbolize_keys)
@failure = true
self
end
# Merges the given context into the current one.
# @!visibility private
def merge(other_context = {})
other_context.each { |key, value| self[key.to_sym] = value }
self
end
def inspect_steps
StepsInspector.new(self)
Service::StepsInspector.new(self)
end
private
attr_reader :store
def self.build(context = {})
self === context ? context : new(context)
end
def method_missing(method_name, *args, &block)
return super if args.present?
store[method_name]
end
end
# Internal module to define available steps as DSL
@ -117,7 +133,7 @@ module Service
if method.parameters.any? { _1[0] != :keyreq }
raise "In #{type} '#{name}': default values in step implementations are not allowed. Maybe they could be defined in a contract?"
end
args = context.to_h.slice(*method.parameters.select { _1[0] == :keyreq }.map(&:last))
args = context.slice(*method.parameters.select { _1[0] == :keyreq }.map(&:last))
context[result_key] = Context.build(object: object)
instance.instance_exec(**args, &method)
end
@ -180,7 +196,7 @@ module Service
attributes = class_name.attribute_names.map(&:to_sym)
default_values = {}
default_values = context[default_values_from].slice(*attributes) if default_values_from
contract = class_name.new(default_values.merge(context.to_h.slice(*attributes)))
contract = class_name.new(default_values.merge(context.slice(*attributes)))
context[contract_name] = contract
context[result_key] = Context.build
if contract.invalid?
@ -347,7 +363,6 @@ module Service
# @!visibility private
def initialize(initial_context = {})
@initial_context = initial_context.with_indifferent_access
@context = Context.build(initial_context.merge(__steps__: self.class.steps))
end

View File

@ -1,10 +1,10 @@
# frozen_string_literal: true
# = StepsInspector
# = Service::StepsInspector
#
# This class takes a {Service::Base::Context} object and inspects it.
# It will output a list of steps and what is their known state.
class StepsInspector
class Service::StepsInspector
# @!visibility private
class Step
attr_reader :step, :result, :nesting_level

View File

@ -82,19 +82,18 @@ module Chat
end
if memberships.blank?
context.added_user_ids = []
context[:added_user_ids] = []
return
end
context.added_user_ids =
::Chat::UserChatChannelMembership
.upsert_all(
memberships,
unique_by: %i[user_id chat_channel_id],
returning: Arel.sql("user_id, (xmax = '0') as inserted"),
)
.select { |row| row["inserted"] }
.map { |row| row["user_id"] }
context[:added_user_ids] = ::Chat::UserChatChannelMembership
.upsert_all(
memberships,
unique_by: %i[user_id chat_channel_id],
returning: Arel.sql("user_id, (xmax = '0') as inserted"),
)
.select { |row| row["inserted"] }
.map { |row| row["user_id"] }
::Chat::DirectMessageUser.upsert_all(
context.added_user_ids.map do |id|

View File

@ -47,8 +47,10 @@ module Chat
end
def create_memberships(channel:, contract:)
context.added_user_ids =
::Chat::Action::CreateMembershipsForAutoJoin.call(channel: channel, contract: contract)
context[:added_user_ids] = ::Chat::Action::CreateMembershipsForAutoJoin.call(
channel: channel,
contract: contract,
)
end
def recalculate_user_count(channel:, added_user_ids:)

View File

@ -100,11 +100,8 @@ module Chat
return if memberships_to_remove.empty?
context.merge(
users_removed_map:
Chat::Action::RemoveMemberships.call(
memberships: Chat::UserChatChannelMembership.where(id: memberships_to_remove),
),
context[:users_removed_map] = Chat::Action::RemoveMemberships.call(
memberships: Chat::UserChatChannelMembership.where(id: memberships_to_remove),
)
end

View File

@ -81,11 +81,8 @@ module Chat
return if memberships_to_remove.empty?
context.merge(
users_removed_map:
Chat::Action::RemoveMemberships.call(
memberships: Chat::UserChatChannelMembership.where(id: memberships_to_remove),
),
context[:users_removed_map] = Chat::Action::RemoveMemberships.call(
memberships: Chat::UserChatChannelMembership.where(id: memberships_to_remove),
)
end

View File

@ -59,15 +59,14 @@ module Chat
def find_or_create_thread(channel:, original_message:, contract:)
if original_message.thread_id.present?
return context.thread = ::Chat::Thread.find_by(id: original_message.thread_id)
return context[:thread] = ::Chat::Thread.find_by(id: original_message.thread_id)
end
context.thread =
channel.threads.create(
title: contract.title,
original_message: original_message,
original_message_user: original_message.user,
)
context[:thread] = channel.threads.create(
title: contract.title,
original_message: original_message,
original_message_user: original_message.user,
)
fail!(context.thread.errors.full_messages.join(", ")) if context.thread.invalid?
end
@ -76,7 +75,7 @@ module Chat
end
def fetch_membership(guardian:)
context.membership = context.thread.membership_for(guardian.user)
context[:membership] = context.thread.membership_for(guardian.user)
end
def publish_new_thread(channel:, original_message:)

View File

@ -65,18 +65,18 @@ module Chat
end
def enabled_threads?(channel:)
context.enabled_threads = channel.threading_enabled
context[:enabled_threads] = channel.threading_enabled
end
def can_view_channel(guardian:, channel:)
guardian.can_preview_chat_channel?(channel)
end
def determine_target_message_id(contract:)
def determine_target_message_id(contract:, membership:)
if contract.fetch_from_last_read
context.target_message_id = context.membership&.last_read_message_id
context[:target_message_id] = membership&.last_read_message_id
else
context.target_message_id = contract.target_message_id
context[:target_message_id] = contract.target_message_id
end
end
@ -92,7 +92,7 @@ module Chat
return true
end
context.target_message_id = nil
context[:target_message_id] = nil
true
end
@ -108,9 +108,9 @@ module Chat
target_date: contract.target_date,
)
context.can_load_more_past = messages_data[:can_load_more_past]
context.can_load_more_future = messages_data[:can_load_more_future]
context.target_message_id = messages_data[:target_message_id]
context[:can_load_more_past] = messages_data[:can_load_more_past]
context[:can_load_more_future] = messages_data[:can_load_more_future]
context[:target_message_id] = messages_data[:target_message_id]
messages_data[:target_message] = (
if messages_data[:target_message]&.thread_reply? &&
@ -121,7 +121,7 @@ module Chat
end
)
context.messages = [
context[:messages] = [
messages_data[:messages],
messages_data[:past_messages]&.reverse,
messages_data[:target_message],
@ -130,37 +130,36 @@ module Chat
end
def fetch_tracking(guardian:)
context.tracking = {}
context[:tracking] = {}
return if !context.thread_ids.present?
context.tracking =
::Chat::TrackingStateReportQuery.call(
guardian: guardian,
thread_ids: context.thread_ids,
include_threads: true,
)
context[:tracking] = ::Chat::TrackingStateReportQuery.call(
guardian: guardian,
thread_ids: context.thread_ids,
include_threads: true,
)
end
def fetch_thread_ids(messages:)
context.thread_ids = messages.map(&:thread_id).compact.uniq
context[:thread_ids] = messages.map(&:thread_id).compact.uniq
end
def fetch_thread_participants(messages:)
return if context.thread_ids.empty?
context.thread_participants =
::Chat::ThreadParticipantQuery.call(thread_ids: context.thread_ids)
context[:thread_participants] = ::Chat::ThreadParticipantQuery.call(
thread_ids: context.thread_ids,
)
end
def fetch_thread_memberships(guardian:)
return if context.thread_ids.empty?
context.thread_memberships =
::Chat::UserChatThreadMembership.where(
thread_id: context.thread_ids,
user_id: guardian.user.id,
)
context[:thread_memberships] = ::Chat::UserChatThreadMembership.where(
thread_id: context.thread_ids,
user_id: guardian.user.id,
)
end
def update_membership_last_viewed_at(guardian:)

View File

@ -63,13 +63,13 @@ module Chat
def determine_target_message_id(contract:, membership:, guardian:, thread:)
if contract.fetch_from_last_message
context.target_message_id = thread.last_message_id
context[:target_message_id] = thread.last_message_id
elsif contract.fetch_from_first_message
context.target_message_id = thread.original_message_id
context[:target_message_id] = thread.original_message_id
elsif contract.fetch_from_last_read || !contract.target_message_id
context.target_message_id = membership&.last_read_message_id
context[:target_message_id] = membership&.last_read_message_id
elsif contract.target_message_id
context.target_message_id = contract.target_message_id
context[:target_message_id] = contract.target_message_id
end
end
@ -99,8 +99,8 @@ module Chat
contract.fetch_from_first_message || contract.fetch_from_last_message,
)
context.can_load_more_past = messages_data[:can_load_more_past]
context.can_load_more_future = messages_data[:can_load_more_future]
context[:can_load_more_past] = messages_data[:can_load_more_past]
context[:can_load_more_future] = messages_data[:can_load_more_future]
[
messages_data[:messages],

View File

@ -51,11 +51,11 @@ module Chat
private
def set_limit(contract:)
context.limit = (contract.limit || THREADS_LIMIT).to_i.clamp(1, THREADS_LIMIT)
context[:limit] = (contract.limit || THREADS_LIMIT).to_i.clamp(1, THREADS_LIMIT)
end
def set_offset(contract:)
context.offset = [contract.offset || 0, 0].max
context[:offset] = [contract.offset || 0, 0].max
end
def fetch_channel(contract:)
@ -118,33 +118,30 @@ module Chat
end
def fetch_tracking(guardian:, threads:)
context.tracking =
::Chat::TrackingStateReportQuery.call(
guardian: guardian,
thread_ids: threads.map(&:id),
include_threads: true,
).thread_tracking
context[:tracking] = ::Chat::TrackingStateReportQuery.call(
guardian: guardian,
thread_ids: threads.map(&:id),
include_threads: true,
).thread_tracking
end
def fetch_memberships(guardian:, threads:)
context.memberships =
::Chat::UserChatThreadMembership.where(
thread_id: threads.map(&:id),
user_id: guardian.user.id,
)
context[:memberships] = ::Chat::UserChatThreadMembership.where(
thread_id: threads.map(&:id),
user_id: guardian.user.id,
)
end
def fetch_participants(threads:)
context.participants = ::Chat::ThreadParticipantQuery.call(thread_ids: threads.map(&:id))
context[:participants] = ::Chat::ThreadParticipantQuery.call(thread_ids: threads.map(&:id))
end
def build_load_more_url(contract:)
load_more_params = { offset: context.offset + context.limit }.to_query
context.load_more_url =
::URI::HTTP.build(
path: "/chat/api/channels/#{contract.channel_id}/threads",
query: load_more_params,
).request_uri
context[:load_more_url] = ::URI::HTTP.build(
path: "/chat/api/channels/#{contract.channel_id}/threads",
query: load_more_params,
).request_uri
end
end
end

View File

@ -35,11 +35,11 @@ module Chat
private
def set_limit(contract:)
context.limit = (contract.limit || THREADS_LIMIT).to_i.clamp(1, THREADS_LIMIT)
context[:limit] = (contract.limit || THREADS_LIMIT).to_i.clamp(1, THREADS_LIMIT)
end
def set_offset(contract:)
context.offset = [contract.offset || 0, 0].max
context[:offset] = [contract.offset || 0, 0].max
end
def fetch_threads(guardian:)
@ -112,31 +112,31 @@ module Chat
end
def fetch_tracking(guardian:, threads:)
context.tracking =
::Chat::TrackingStateReportQuery.call(
guardian: guardian,
thread_ids: threads.map(&:id),
include_threads: true,
).thread_tracking
context[:tracking] = ::Chat::TrackingStateReportQuery.call(
guardian: guardian,
thread_ids: threads.map(&:id),
include_threads: true,
).thread_tracking
end
def fetch_memberships(guardian:, threads:)
context.memberships =
::Chat::UserChatThreadMembership.where(
thread_id: threads.map(&:id),
user_id: guardian.user.id,
)
context[:memberships] = ::Chat::UserChatThreadMembership.where(
thread_id: threads.map(&:id),
user_id: guardian.user.id,
)
end
def fetch_participants(threads:)
context.participants = ::Chat::ThreadParticipantQuery.call(thread_ids: threads.map(&:id))
context[:participants] = ::Chat::ThreadParticipantQuery.call(thread_ids: threads.map(&:id))
end
def build_load_more_url(contract:)
load_more_params = { limit: context.limit, offset: context.offset + context.limit }.to_query
context.load_more_url =
::URI::HTTP.build(path: "/chat/api/me/threads", query: load_more_params).request_uri
context[:load_more_url] = ::URI::HTTP.build(
path: "/chat/api/me/threads",
query: load_more_params,
).request_uri
end
end
end

View File

@ -54,7 +54,7 @@ module Chat
notification_level: Chat::NotificationLevels.all[:normal],
) if !membership
membership.update!(thread_title_prompt_seen: true)
context.membership = membership
context[:membership] = membership
end
end
end

View File

@ -34,7 +34,7 @@ module Chat
private
def clean_term(contract:)
context.term = contract.term&.downcase&.strip&.gsub(/^[@#]+/, "")
context[:term] = contract.term&.downcase&.strip&.gsub(/^[@#]+/, "")
end
def fetch_memberships(guardian:)

View File

@ -32,7 +32,7 @@ module Chat
end
def unfollow(channel:, guardian:)
context.membership = channel.remove(guardian.user)
context[:membership] = channel.remove(guardian.user)
end
end
end

View File

@ -66,7 +66,7 @@ module Chat
def update_channel(channel:, contract:)
channel.assign_attributes(contract.attributes)
context.threading_enabled_changed = channel.threading_enabled_changed?
context[:threading_enabled_changed] = channel.threading_enabled_changed?
channel.save!
end

View File

@ -120,12 +120,11 @@ module Chat
prev_message = message.message_before_last_save || message.message_was
return if !should_create_revision(message, prev_message, guardian)
context.revision =
message.revisions.create!(
old_message: prev_message,
new_message: message.message,
user_id: guardian.user.id,
)
context[:revision] = message.revisions.create!(
old_message: prev_message,
new_message: message.message,
user_id: guardian.user.id,
)
end
def should_create_revision(new_message, prev_message, guardian)
@ -151,7 +150,7 @@ module Chat
end
def publish(message:, guardian:, contract:)
edit_timestamp = context.revision&.created_at&.iso8601(6) || Time.zone.now.iso8601(6)
edit_timestamp = context[:revision]&.created_at&.iso8601(6) || Time.zone.now.iso8601(6)
::Chat::Publisher.publish_edit!(message.chat_channel, message)

View File

@ -60,7 +60,7 @@ module Chat
membership.update!(last_read_message_id: thread.last_message_id)
end
membership.update!(notification_level: contract.notification_level)
context.membership = membership
context[:membership] = membership
end
end
end

View File

@ -1,6 +1,6 @@
# frozen_string_literal: true
RSpec.describe StepsInspector do
RSpec.describe Service::StepsInspector do
class DummyService
include Service::Base
@ -239,7 +239,7 @@ RSpec.describe StepsInspector do
end
context "when a reason is provided" do
before { result["result.policy.policy"].reason = "failed" }
before { result["result.policy.policy"][:reason] = "failed" }
it "returns the reason" do
expect(error).to eq "failed"

View File

@ -26,7 +26,7 @@ module ServiceMatchers
private
def error_message_with_inspection(message)
inspector = StepsInspector.new(result)
inspector = Service::StepsInspector.new(result)
"#{message}\n\n#{inspector.inspect}\n\n#{inspector.error}"
end
end
@ -89,7 +89,7 @@ module ServiceMatchers
end
def error_message_with_inspection(message)
inspector = StepsInspector.new(result)
inspector = Service::StepsInspector.new(result)
"#{message}\n\n#{inspector.inspect}\n\n#{inspector.error}"
end
@ -158,7 +158,7 @@ module ServiceMatchers
end
def inspect_steps(result)
inspector = StepsInspector.new(result)
inspector = Service::StepsInspector.new(result)
puts "Steps:"
puts inspector.inspect
puts "\nFirst error:"