From 1320eed9b29fc6d1b99ee7ee1689d7ed9512d833 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 4 Jul 2024 10:48:18 +1000 Subject: [PATCH] FEATURE: move summary to use llm_model (#699) This allows summarization to use the new LLM models and migrates away from API-key-based model selection. Claude 3.5, etc. all work now. --------- Co-authored-by: Roman Rizzi --- Gemfile.lock | 43 +++-- .../summarization/chat_summary_controller.rb | 7 +- .../summarization/summary_controller.rb | 10 +- app/jobs/regular/stream_topic_ai_summary.rb | 6 +- .../ai_topic_summary_serializer.rb | 2 +- .../discourse_ai/topic_summarization.rb | 6 +- .../initializers/ai-chat-summarization.js | 5 +- config/locales/server.en.yml | 8 +- config/settings.yml | 27 +-- ...0703135444_llm_models_for_summarization.rb | 60 ++++++ lib/completions/llm.rb | 8 +- lib/configuration/llm_enumerator.rb | 23 +++ lib/configuration/llm_validator.rb | 1 + lib/configuration/summarization_enumerator.rb | 20 -- lib/configuration/summarization_validator.rb | 23 --- lib/guardian_extensions.rb | 21 +++ lib/summarization.rb | 12 ++ lib/summarization/entry_point.rb | 13 +- lib/summarization/models/anthropic.rb | 26 --- lib/summarization/models/base.rb | 171 ------------------ lib/summarization/models/custom_llm.rb | 41 ----- lib/summarization/models/fake.rb | 25 --- lib/summarization/models/gemini.rb | 25 --- lib/summarization/models/mixtral.rb | 25 --- lib/summarization/models/open_ai.rb | 25 --- lib/summarization/strategies/fold_content.rb | 65 ++++--- .../regular/stream_topic_ai_summary_spec.rb | 8 +- ...se_spec.rb => guardian_extensions_spec.rb} | 38 ++-- .../strategies/fold_content_spec.rb | 18 +- .../chat_summary_controller_spec.rb | 3 +- .../summarization/summary_controller_spec.rb | 5 +- .../discourse_ai/topic_summarization_spec.rb | 31 ++-- .../summarization/chat_summarization_spec.rb | 3 +- 33 files changed, 282 insertions(+), 522 deletions(-) create mode 100644 db/post_migrate/20240703135444_llm_models_for_summarization.rb delete mode 100644 lib/configuration/summarization_enumerator.rb delete mode 100644 lib/configuration/summarization_validator.rb create mode 100644 lib/summarization.rb delete mode 100644 lib/summarization/models/anthropic.rb delete mode 100644 lib/summarization/models/base.rb delete mode 100644 lib/summarization/models/custom_llm.rb delete mode 100644 lib/summarization/models/fake.rb delete mode 100644 lib/summarization/models/gemini.rb delete mode 100644 lib/summarization/models/mixtral.rb delete mode 100644 lib/summarization/models/open_ai.rb rename spec/lib/{modules/summarization/base_spec.rb => guardian_extensions_spec.rb} (71%) diff --git a/Gemfile.lock b/Gemfile.lock index 27fcd637..f8f1f5ab 100644 --- a/Gemfile.lock +++ b/Gemfile.lock @@ -1,7 +1,7 @@ GEM remote: https://rubygems.org/ specs: - activesupport (7.1.3.3) + activesupport (7.1.3.4) base64 bigdecimal concurrent-ruby (~> 1.0, >= 1.0.2) @@ -14,27 +14,27 @@ GEM ast (2.4.2) base64 (0.2.0) bigdecimal (3.1.8) - concurrent-ruby (1.2.3) + concurrent-ruby (1.3.3) connection_pool (2.4.1) drb (2.2.1) i18n (1.14.5) concurrent-ruby (~> 1.0) json (2.7.2) language_server-protocol (3.17.0.3) - minitest (5.23.1) + minitest (5.24.1) mutex_m (0.2.0) - parallel (1.24.0) - parser (3.3.1.0) + parallel (1.25.1) + parser (3.3.3.0) ast (~> 2.4.1) racc prettier_print (1.2.1) racc (1.8.0) - rack (3.0.11) + rack (3.1.6) rainbow (3.1.1) regexp_parser (2.9.2) - rexml (3.2.8) - strscan (>= 3.0.9) - rubocop (1.64.0) + rexml (3.3.1) + strscan + rubocop (1.64.1) json (~> 2.3) language_server-protocol (>= 
3.17.0) parallel (~> 1.10) @@ -47,29 +47,28 @@ GEM unicode-display_width (>= 2.4.0, < 3.0) rubocop-ast (1.31.3) parser (>= 3.3.1.0) - rubocop-capybara (2.20.0) + rubocop-capybara (2.21.0) rubocop (~> 1.41) - rubocop-discourse (3.8.0) + rubocop-discourse (3.8.1) activesupport (>= 6.1) rubocop (>= 1.59.0) rubocop-capybara (>= 2.0.0) rubocop-factory_bot (>= 2.0.0) rubocop-rails (>= 2.25.0) - rubocop-rspec (>= 2.25.0) - rubocop-factory_bot (2.25.1) - rubocop (~> 1.41) - rubocop-rails (2.25.0) + rubocop-rspec (>= 3.0.1) + rubocop-rspec_rails (>= 2.30.0) + rubocop-factory_bot (2.26.1) + rubocop (~> 1.61) + rubocop-rails (2.25.1) activesupport (>= 4.2.0) rack (>= 1.1) rubocop (>= 1.33.0, < 2.0) rubocop-ast (>= 1.31.1, < 2.0) - rubocop-rspec (2.29.2) - rubocop (~> 1.40) - rubocop-capybara (~> 2.17) - rubocop-factory_bot (~> 2.22) - rubocop-rspec_rails (~> 2.28) - rubocop-rspec_rails (2.28.3) - rubocop (~> 1.40) + rubocop-rspec (3.0.2) + rubocop (~> 1.61) + rubocop-rspec_rails (2.30.0) + rubocop (~> 1.61) + rubocop-rspec (~> 3, >= 3.0.1) ruby-progressbar (1.13.0) strscan (3.1.0) syntax_tree (6.2.0) diff --git a/app/controllers/discourse_ai/summarization/chat_summary_controller.rb b/app/controllers/discourse_ai/summarization/chat_summary_controller.rb index 78d51be4..96d95d8c 100644 --- a/app/controllers/discourse_ai/summarization/chat_summary_controller.rb +++ b/app/controllers/discourse_ai/summarization/chat_summary_controller.rb @@ -15,11 +15,10 @@ module DiscourseAi channel = ::Chat::Channel.find(params[:channel_id]) guardian.ensure_can_join_chat_channel!(channel) - strategy = DiscourseAi::Summarization::Models::Base.selected_strategy + strategy = DiscourseAi::Summarization.default_strategy raise Discourse::NotFound.new unless strategy - unless DiscourseAi::Summarization::Models::Base.can_request_summary_for?(current_user) - raise Discourse::InvalidAccess - end + + guardian.ensure_can_request_summary! RateLimiter.new(current_user, "channel_summary", 6, 5.minutes).performed! diff --git a/app/controllers/discourse_ai/summarization/summary_controller.rb b/app/controllers/discourse_ai/summarization/summary_controller.rb index 48c8b7e0..14f57c62 100644 --- a/app/controllers/discourse_ai/summarization/summary_controller.rb +++ b/app/controllers/discourse_ai/summarization/summary_controller.rb @@ -8,12 +8,8 @@ module DiscourseAi def show topic = Topic.find(params[:topic_id]) guardian.ensure_can_see!(topic) - strategy = DiscourseAi::Summarization::Models::Base.selected_strategy - if strategy.nil? || - !DiscourseAi::Summarization::Models::Base.can_see_summary?(topic, current_user) - raise Discourse::NotFound - end + raise Discourse::NotFound if !guardian.can_see_summary?(topic) RateLimiter.new(current_user, "summary", 6, 5.minutes).performed! 
if current_user @@ -30,9 +26,7 @@ module DiscourseAi render json: success_json else hijack do - summary = - DiscourseAi::TopicSummarization.new(strategy).summarize(topic, current_user, opts) - + summary = DiscourseAi::TopicSummarization.summarize(topic, current_user, opts) render_serialized(summary, AiTopicSummarySerializer) end end diff --git a/app/jobs/regular/stream_topic_ai_summary.rb b/app/jobs/regular/stream_topic_ai_summary.rb index 1dce0925..6641918c 100644 --- a/app/jobs/regular/stream_topic_ai_summary.rb +++ b/app/jobs/regular/stream_topic_ai_summary.rb @@ -8,10 +8,8 @@ module Jobs return unless topic = Topic.find_by(id: args[:topic_id]) return unless user = User.find_by(id: args[:user_id]) - strategy = DiscourseAi::Summarization::Models::Base.selected_strategy - if strategy.nil? || !DiscourseAi::Summarization::Models::Base.can_see_summary?(topic, user) - return - end + strategy = DiscourseAi::Summarization.default_strategy + return if strategy.nil? || !Guardian.new(user).can_see_summary?(topic) guardian = Guardian.new(user) return unless guardian.can_see?(topic) diff --git a/app/serializers/ai_topic_summary_serializer.rb b/app/serializers/ai_topic_summary_serializer.rb index 04846193..39d3abff 100644 --- a/app/serializers/ai_topic_summary_serializer.rb +++ b/app/serializers/ai_topic_summary_serializer.rb @@ -4,7 +4,7 @@ class AiTopicSummarySerializer < ApplicationSerializer attributes :summarized_text, :algorithm, :outdated, :can_regenerate, :new_posts_since_summary def can_regenerate - DiscourseAi::Summarization::Models::Base.can_request_summary_for?(scope.current_user) + scope.can_request_summary? end def new_posts_since_summary diff --git a/app/services/discourse_ai/topic_summarization.rb b/app/services/discourse_ai/topic_summarization.rb index 47051694..d5ea3be2 100644 --- a/app/services/discourse_ai/topic_summarization.rb +++ b/app/services/discourse_ai/topic_summarization.rb @@ -2,6 +2,10 @@ module DiscourseAi class TopicSummarization + def self.summarize(topic, user, opts = {}, &on_partial_blk) + new(DiscourseAi::Summarization.default_strategy).summarize(topic, user, opts, &on_partial_blk) + end + def initialize(strategy) @strategy = strategy end @@ -15,7 +19,7 @@ module DiscourseAi targets_data = summary_targets(topic).pluck(:post_number, :raw, :username) current_topic_sha = build_sha(targets_data.map(&:first)) - can_summarize = DiscourseAi::Summarization::Models::Base.can_request_summary_for?(user) + can_summarize = Guardian.new(user).can_request_summary? 
if use_cached?(existing_summary, can_summarize, current_topic_sha, !!opts[:skip_age_check]) # It's important that we signal a cached summary is outdated diff --git a/assets/javascripts/initializers/ai-chat-summarization.js b/assets/javascripts/initializers/ai-chat-summarization.js index 1b17ec5c..57966c59 100644 --- a/assets/javascripts/initializers/ai-chat-summarization.js +++ b/assets/javascripts/initializers/ai-chat-summarization.js @@ -6,10 +6,7 @@ export default apiInitializer("1.34.0", (api) => { const currentUser = api.getCurrentUser(); const chatService = api.container.lookup("service:chat"); const modal = api.container.lookup("service:modal"); - const canSummarize = - siteSettings.ai_summarization_strategy && - currentUser && - currentUser.can_summarize; + const canSummarize = currentUser && currentUser.can_summarize; if ( !siteSettings.chat_enabled || diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml index 6b5d58a1..97001fce 100644 --- a/config/locales/server.en.yml +++ b/config/locales/server.en.yml @@ -92,11 +92,9 @@ en: ai_embeddings_semantic_related_include_closed_topics: "Include closed topics in semantic search results" ai_embeddings_semantic_search_hyde_model: "Model used to expand keywords to get better results during a semantic search" ai_embeddings_per_post_enabled: Generate embeddings for each post - - ai_summarization_discourse_service_api_endpoint: "URL where the Discourse summarization API is running." - ai_summarization_discourse_service_api_key: "API key for the Discourse summarization API." - ai_summarization_strategy: "Additional ways to summarize content registered by plugins" - ai_custom_summarization_allowed_groups: "Groups allowed to summarize contents using the `summarization_strategy`." + ai_summarization_enabled: "Enable the topic summarization module." + ai_summarization_model: "Model to use for summarization." + ai_custom_summarization_allowed_groups: "Groups allowed to create new summaries." ai_bot_enabled: "Enable the AI Bot module." ai_bot_enable_chat_warning: "Display a warning when PM chat is initiated. Can be overridden by editing the translation string: discourse_ai.ai_bot.pm_warning" diff --git a/config/settings.yml b/config/settings.yml index aac0514d..f890ad8f 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -319,23 +319,24 @@ discourse_ai: default: false client: true hidden: true - - ai_summarization_discourse_service_api_endpoint: "" - ai_summarization_discourse_service_api_endpoint_srv: + ai_summarization_enabled: + default: false + validator: "DiscourseAi::Configuration::LlmDependencyValidator" + ai_summarization_model: default: "" - hidden: true - ai_summarization_discourse_service_api_key: - default: "" - secret: true - ai_summarization_strategy: - client: true - default: "" - enum: "DiscourseAi::Configuration::SummarizationEnumerator" - validator: "DiscourseAi::Configuration::SummarizationValidator" + allow_any: false + type: enum + enum: "DiscourseAi::Configuration::LlmEnumerator" + validator: "DiscourseAi::Configuration::LlmValidator" ai_custom_summarization_allowed_groups: type: group_list list_type: compact default: "3|13" # 3: @staff, 13: @trust_level_3 + ai_summarization_strategy: # TODO(roman): Deprecated. 
Remove by Sept 2024 + type: enum + default: "" + hidden: true + choices: "DiscourseAi::Configuration::LlmEnumerator.old_summarization_options + ['']" ai_bot_enabled: default: false @@ -359,7 +360,7 @@ discourse_ai: default: "1|2" # 1: admins, 2: moderators allow_any: false refresh: true - ai_bot_enabled_chat_bots: # TODO(roman): Remove setting. Deprecated + ai_bot_enabled_chat_bots: # TODO(roman): Deprecated. Remove by Sept 2024 type: list default: "gpt-3.5-turbo" hidden: true diff --git a/db/post_migrate/20240703135444_llm_models_for_summarization.rb b/db/post_migrate/20240703135444_llm_models_for_summarization.rb new file mode 100644 index 00000000..5046139e --- /dev/null +++ b/db/post_migrate/20240703135444_llm_models_for_summarization.rb @@ -0,0 +1,60 @@ +# frozen_string_literal: true + +class LlmModelsForSummarization < ActiveRecord::Migration[7.0] + def up + setting_value = + DB + .query_single( + "SELECT value FROM site_settings WHERE name = :llm_setting", + llm_setting: "ai_summarization_strategy", + ) + .first + .to_s + + return if setting_value.empty? + + gpt_models = %w[gpt-4 gpt-4-32k gpt-4-turbo gpt-4o gpt-3.5-turbo gpt-3.5-turbo-16k] + gemini_models = %w[gemini-pro gemini-1.5-pro gemini-1.5-flash] + claude_models = %w[claude-2 claude-instant-1 claude-3-haiku claude-3-sonnet claude-3-opus] + oss_models = %w[mistralai/Mixtral-8x7B-Instruct-v0.1] + + providers = [] + prov_priority = "" + + if gpt_models.include?(setting_value) + providers = %w[azure open_ai] + prov_priority = "azure" + elsif gemini_models.include?(setting_value) + providers = %w[google] + prov_priority = "google" + elsif claude_models.include?(setting_value) + providers = %w[aws_bedrock anthropic] + prov_priority = "aws_bedrock" + elsif oss_models.include?(setting_value) + providers = %w[hugging_face vllm] + prov_priority = "vllm" + end + + insert_llm_model(setting_value, providers, prov_priority) if providers.present? + end + + def insert_llm_model(old_value, providers, priority) + matching_models = DB.query(<<~SQL, model_name: old_value, providers: providers) + SELECT * FROM llm_models WHERE name = :model_name AND provider IN (:providers) + SQL + + return if matching_models.empty? 
+ + priority_model = matching_models.find { |m| m.provider == priority } || matching_models.first + new_value = "custom:#{priority_model.id}" + + DB.exec(<<~SQL, new_value: new_value) + INSERT INTO site_settings(name, data_type, value, created_at, updated_at) + VALUES ('ai_summarization_model', 1, :new_value, NOW(), NOW()) + SQL + end + + def down + raise ActiveRecord::IrreversibleMigration + end +end diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb index 741ae653..c2660090 100644 --- a/lib/completions/llm.rb +++ b/lib/completions/llm.rb @@ -200,7 +200,9 @@ module DiscourseAi raise "Invalid call LLM call, expected #{@canned_llm} but got #{model_name}" end - return new(dialect_klass, nil, model_name, gateway: @canned_response) + return( + new(dialect_klass, nil, model_name, gateway: @canned_response, llm_model: llm_model) + ) end gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name) @@ -293,11 +295,11 @@ module DiscourseAi dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).tokenizer end - attr_reader :model_name + attr_reader :model_name, :llm_model private - attr_reader :dialect_klass, :gateway_klass, :llm_model + attr_reader :dialect_klass, :gateway_klass end end end diff --git a/lib/configuration/llm_enumerator.rb b/lib/configuration/llm_enumerator.rb index 8b308510..e8ceb603 100644 --- a/lib/configuration/llm_enumerator.rb +++ b/lib/configuration/llm_enumerator.rb @@ -20,6 +20,28 @@ module DiscourseAi values end + # TODO(roman): Deprecated. Remove by Sept 2024 + def self.old_summarization_options + %w[ + gpt-4 + gpt-4-32k + gpt-4-turbo + gpt-4o + gpt-3.5-turbo + gpt-3.5-turbo-16k + gemini-pro + gemini-1.5-pro + gemini-1.5-flash + claude-2 + claude-instant-1 + claude-3-haiku + claude-3-sonnet + claude-3-opus + mistralai/Mixtral-8x7B-Instruct-v0.1 + ] + end + + # TODO(roman): Deprecated. 
Remove by Sept 2024 def self.available_ai_bots %w[ gpt-3.5-turbo diff --git a/lib/configuration/llm_validator.rb b/lib/configuration/llm_validator.rb index 4e8c771a..a0c76210 100644 --- a/lib/configuration/llm_validator.rb +++ b/lib/configuration/llm_validator.rb @@ -61,6 +61,7 @@ module DiscourseAi { ai_embeddings_semantic_search_enabled: :ai_embeddings_semantic_search_hyde_model, composer_ai_helper_enabled: :ai_helper_model, + ai_summarization_enabled: :ai_summarization_model, } end end diff --git a/lib/configuration/summarization_enumerator.rb b/lib/configuration/summarization_enumerator.rb deleted file mode 100644 index 915991dd..00000000 --- a/lib/configuration/summarization_enumerator.rb +++ /dev/null @@ -1,20 +0,0 @@ -# frozen_string_literal: true - -require "enum_site_setting" - -module DiscourseAi - module Configuration - class SummarizationEnumerator < ::EnumSiteSetting - def self.valid_value?(val) - true - end - - def self.values - @values ||= - DiscourseAi::Summarization::Models::Base.available_strategies.map do |strategy| - { name: strategy.display_name, value: strategy.model } - end - end - end - end -end diff --git a/lib/configuration/summarization_validator.rb b/lib/configuration/summarization_validator.rb deleted file mode 100644 index 87edbea9..00000000 --- a/lib/configuration/summarization_validator.rb +++ /dev/null @@ -1,23 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Configuration - class SummarizationValidator - def initialize(opts = {}) - @opts = opts - end - - def valid_value?(val) - strategy = DiscourseAi::Summarization::Models::Base.find_strategy(val) - - return true unless strategy - - strategy.correctly_configured?.tap { |is_valid| @strategy = strategy unless is_valid } - end - - def error_message - @strategy.configuration_hint - end - end - end -end diff --git a/lib/guardian_extensions.rb b/lib/guardian_extensions.rb index a13bf917..30ada985 100644 --- a/lib/guardian_extensions.rb +++ b/lib/guardian_extensions.rb @@ -2,6 +2,27 @@ module DiscourseAi module GuardianExtensions + def can_see_summary?(target) + return false if !SiteSetting.ai_summarization_enabled + + # TODO we want a switch to allow summaries for all topics + return false if target.class == Topic && target.private_message? + + has_cached_summary = AiSummary.exists?(target: target) + return has_cached_summary if user.nil? + + has_cached_summary || can_request_summary? + end + + def can_request_summary? + return false if anonymous? + + user_group_ids = user.group_ids + SiteSetting.ai_custom_summarization_allowed_groups_map.any? do |group_id| + user_group_ids.include?(group_id) + end + end + def can_debug_ai_bot_conversation?(target) return false if anonymous? diff --git a/lib/summarization.rb b/lib/summarization.rb new file mode 100644 index 00000000..e8b037df --- /dev/null +++ b/lib/summarization.rb @@ -0,0 +1,12 @@ +# frozen_string_literal: true +module DiscourseAi + module Summarization + def self.default_strategy + if SiteSetting.ai_summarization_model.present? && SiteSetting.ai_summarization_enabled + DiscourseAi::Summarization::Strategies::FoldContent.new(SiteSetting.ai_summarization_model) + else + nil + end + end + end +end diff --git a/lib/summarization/entry_point.rb b/lib/summarization/entry_point.rb index 3dcbd769..3066be47 100644 --- a/lib/summarization/entry_point.rb +++ b/lib/summarization/entry_point.rb @@ -2,18 +2,27 @@ module DiscourseAi module Summarization + def self.default_strategy + if SiteSetting.ai_summarization_model.present? 
&& SiteSetting.ai_summarization_enabled + DiscourseAi::Summarization::Strategies::FoldContent.new(SiteSetting.ai_summarization_model) + else + nil + end + end + class EntryPoint def inject_into(plugin) plugin.add_to_serializer(:current_user, :can_summarize) do + return false if !SiteSetting.ai_summarization_enabled scope.user.in_any_groups?(SiteSetting.ai_custom_summarization_allowed_groups_map) end plugin.add_to_serializer(:topic_view, :summarizable) do - DiscourseAi::Summarization::Models::Base.can_see_summary?(object.topic, scope.user) + scope.can_see_summary?(object.topic) end plugin.add_to_serializer(:web_hook_topic_view, :summarizable) do - DiscourseAi::Summarization::Models::Base.can_see_summary?(object.topic, scope.user) + scope.can_see_summary?(object.topic) end end end diff --git a/lib/summarization/models/anthropic.rb b/lib/summarization/models/anthropic.rb deleted file mode 100644 index 138a7697..00000000 --- a/lib/summarization/models/anthropic.rb +++ /dev/null @@ -1,26 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Summarization - module Models - class Anthropic < Base - def display_name - "Anthropic's #{model}" - end - - def correctly_configured? - SiteSetting.ai_anthropic_api_key.present? || - DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2") - end - - def configuration_hint - I18n.t( - "discourse_ai.summarization.configuration_hint", - count: 1, - setting: "ai_anthropic_api_key", - ) - end - end - end - end -end diff --git a/lib/summarization/models/base.rb b/lib/summarization/models/base.rb deleted file mode 100644 index 91d0e38f..00000000 --- a/lib/summarization/models/base.rb +++ /dev/null @@ -1,171 +0,0 @@ -# frozen_string_literal: true - -# Base class that defines the interface that every summarization -# strategy must implement. -# Above each method, you'll find an explanation of what -# it does and what it should return. 
- -module DiscourseAi - module Summarization - module Models - class Base - class << self - def available_strategies - foldable_models = [ - Models::OpenAi.new("open_ai:gpt-4", max_tokens: 8192), - Models::OpenAi.new("open_ai:gpt-4-32k", max_tokens: 32_768), - Models::OpenAi.new("open_ai:gpt-4-turbo", max_tokens: 100_000), - Models::OpenAi.new("open_ai:gpt-4o", max_tokens: 100_000), - Models::OpenAi.new("open_ai:gpt-3.5-turbo", max_tokens: 4096), - Models::OpenAi.new("open_ai:gpt-3.5-turbo-16k", max_tokens: 16_384), - Models::Gemini.new("google:gemini-pro", max_tokens: 32_768), - Models::Gemini.new("google:gemini-1.5-pro", max_tokens: 800_000), - Models::Gemini.new("google:gemini-1.5-flash", max_tokens: 800_000), - ] - - claude_prov = "anthropic" - if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2") - claude_prov = "aws_bedrock" - end - - foldable_models << Models::Anthropic.new("#{claude_prov}:claude-2", max_tokens: 200_000) - foldable_models << Models::Anthropic.new( - "#{claude_prov}:claude-instant-1", - max_tokens: 100_000, - ) - foldable_models << Models::Anthropic.new( - "#{claude_prov}:claude-3-haiku", - max_tokens: 200_000, - ) - foldable_models << Models::Anthropic.new( - "#{claude_prov}:claude-3-sonnet", - max_tokens: 200_000, - ) - - foldable_models << Models::Anthropic.new( - "#{claude_prov}:claude-3-opus", - max_tokens: 200_000, - ) - - mixtral_prov = "hugging_face" - if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?( - "mistralai/Mixtral-8x7B-Instruct-v0.1", - ) - mixtral_prov = "vllm" - end - - foldable_models << Models::Mixtral.new( - "#{mixtral_prov}:mistralai/Mixtral-8x7B-Instruct-v0.1", - max_tokens: 32_000, - ) - - unless Rails.env.production? - foldable_models << Models::Fake.new("fake:fake", max_tokens: 8192) - end - - folded_models = foldable_models.map { |model| Strategies::FoldContent.new(model) } - - folded_models - end - - def find_strategy(strategy_model) - available_strategies.detect { |s| s.model == strategy_model } - end - - def selected_strategy - return if SiteSetting.ai_summarization_strategy.blank? - - find_strategy(SiteSetting.ai_summarization_strategy) - end - - def can_see_summary?(target, user) - return false if SiteSetting.ai_summarization_strategy.blank? - return false if target.class == Topic && target.private_message? - - has_cached_summary = AiSummary.exists?(target: target) - return has_cached_summary if user.nil? - - has_cached_summary || can_request_summary_for?(user) - end - - def can_request_summary_for?(user) - return false unless user - - user_group_ids = user.group_ids - - SiteSetting.ai_custom_summarization_allowed_groups_map.any? do |group_id| - user_group_ids.include?(group_id) - end - end - end - - def initialize(model_name, max_tokens:) - @model_name = model_name - @max_tokens = max_tokens - end - - # Some strategies could require other conditions to work correctly, - # like site settings. - # This method gets called when admins attempt to select it, - # checking if we met those conditions. - def correctly_configured? - raise NotImplemented - end - - # Strategy name to display to admins in the available strategies dropdown. - def display_name - raise NotImplemented - end - - # If we don't meet the conditions to enable this strategy, - # we'll display this hint as an error to admins. - def configuration_hint - raise NotImplemented - end - - # The idea behind this method is "give me a collection of texts, - # and I'll handle the summarization to the best of my capabilities.". 
- # It's important to emphasize the "collection of texts" part, which implies - # it's not tied to any model and expects the "content" to be a hash instead. - # - # @param content { Hash } - Includes the content to summarize, plus additional - # context to help the strategy produce a better result. Keys present in the content hash: - # - resource_path (optional): Helps the strategy build links to the content in the summary (e.g. "/t/-/:topic_id/POST_NUMBER") - # - content_title (optional): Provides guidance about what the content is about. - # - contents (required): Array of hashes with content to summarize (e.g. [{ poster: "asd", id: 1, text: "This is a text" }]) - # All keys are required. - # @param &on_partial_blk { Block - Optional } - If the strategy supports it, the passed block - # will get called with partial summarized text as its generated. - # - # @param current_user { User } - User requesting the summary. - # - # @returns { Hash } - The summarized content. Example: - # { - # summary: "This is the final summary", - # } - def summarize(content, current_user) - raise NotImplemented - end - - def available_tokens - max_tokens - reserved_tokens - end - - # Returns the string we'll store in the selected strategy site setting. - def model - model_name.split(":").last - end - - attr_reader :model_name, :max_tokens - - protected - - def reserved_tokens - # Reserve tokens for the response and the base prompt - # ~500 words - 700 - end - end - end - end -end diff --git a/lib/summarization/models/custom_llm.rb b/lib/summarization/models/custom_llm.rb deleted file mode 100644 index 67798326..00000000 --- a/lib/summarization/models/custom_llm.rb +++ /dev/null @@ -1,41 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Summarization - module Models - class CustomLlm < Base - def display_name - custom_llm.display_name - end - - def correctly_configured? - if Rails.env.development? - SiteSetting.ai_ollama_endpoint.present? - else - SiteSetting.ai_hugging_face_api_url.present? || - SiteSetting.ai_vllm_endpoint_srv.present? || SiteSetting.ai_vllm_endpoint.present? - end - end - - def configuration_hint - I18n.t( - "discourse_ai.summarization.configuration_hint", - count: 1, - setting: "ai_hugging_face_api_url", - ) - end - - def model - model_name - end - - private - - def custom_llm - id = model.split(":").last - @llm ||= LlmModel.find_by(id: id) - end - end - end - end -end diff --git a/lib/summarization/models/fake.rb b/lib/summarization/models/fake.rb deleted file mode 100644 index 7398b649..00000000 --- a/lib/summarization/models/fake.rb +++ /dev/null @@ -1,25 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Summarization - module Models - class Fake < Base - def display_name - "fake" - end - - def correctly_configured? - true - end - - def configuration_hint - "" - end - - def model - "fake" - end - end - end - end -end diff --git a/lib/summarization/models/gemini.rb b/lib/summarization/models/gemini.rb deleted file mode 100644 index 2f0550ac..00000000 --- a/lib/summarization/models/gemini.rb +++ /dev/null @@ -1,25 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Summarization - module Models - class Gemini < Base - def display_name - "Google Gemini #{model}" - end - - def correctly_configured? - SiteSetting.ai_gemini_api_key.present? 
- end - - def configuration_hint - I18n.t( - "discourse_ai.summarization.configuration_hint", - count: 1, - setting: "ai_gemini_api_key", - ) - end - end - end - end -end diff --git a/lib/summarization/models/mixtral.rb b/lib/summarization/models/mixtral.rb deleted file mode 100644 index 3c3c7dab..00000000 --- a/lib/summarization/models/mixtral.rb +++ /dev/null @@ -1,25 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Summarization - module Models - class Mixtral < Base - def display_name - "MistralAI's #{model}" - end - - def correctly_configured? - SiteSetting.ai_hugging_face_api_url.present? || SiteSetting.ai_vllm_endpoint_srv.present? - end - - def configuration_hint - I18n.t( - "discourse_ai.summarization.configuration_hint", - count: 1, - settings: %w[ai_hugging_face_api_url ai_vllm_endpoint_srv], - ) - end - end - end - end -end diff --git a/lib/summarization/models/open_ai.rb b/lib/summarization/models/open_ai.rb deleted file mode 100644 index 121d71f5..00000000 --- a/lib/summarization/models/open_ai.rb +++ /dev/null @@ -1,25 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Summarization - module Models - class OpenAi < Base - def display_name - "Open AI's #{model}" - end - - def correctly_configured? - SiteSetting.ai_openai_api_key.present? - end - - def configuration_hint - I18n.t( - "discourse_ai.summarization.configuration_hint", - count: 1, - setting: "ai_openai_api_key", - ) - end - end - end - end -end diff --git a/lib/summarization/strategies/fold_content.rb b/lib/summarization/strategies/fold_content.rb index 36f76276..788e09f6 100644 --- a/lib/summarization/strategies/fold_content.rb +++ b/lib/summarization/strategies/fold_content.rb @@ -3,27 +3,19 @@ module DiscourseAi module Summarization module Strategies - class FoldContent < DiscourseAi::Summarization::Models::Base + class FoldContent def initialize(completion_model) - @completion_model = completion_model + @llm = DiscourseAi::Completions::Llm.proxy(completion_model) + raise "Invalid model provided for summarization strategy" if @llm.llm_model.nil? end - attr_reader :completion_model - - delegate :correctly_configured?, - :display_name, - :configuration_hint, - :model, - to: :completion_model + attr_reader :llm def summarize(content, user, &on_partial_blk) opts = content.except(:contents) - llm = DiscourseAi::Completions::Llm.proxy(completion_model.model_name) - initial_chunks = rebalance_chunks( - llm.tokenizer, content[:contents].map { |c| { ids: [c[:id]], summary: format_content_item(c) } }, ) @@ -31,28 +23,35 @@ module DiscourseAi if initial_chunks.length == 1 { summary: - summarize_single(llm, initial_chunks.first[:summary], user, opts, &on_partial_blk), + summarize_single(initial_chunks.first[:summary], user, opts, &on_partial_blk), chunks: [], } else - summarize_chunks(llm, initial_chunks, user, opts, &on_partial_blk) + summarize_chunks(initial_chunks, user, opts, &on_partial_blk) end end + def display_name + llm_model&.name || "unknown model" + end + private - def summarize_chunks(llm, chunks, user, opts, &on_partial_blk) - # Safely assume we always have more than one chunk. - summarized_chunks = summarize_in_chunks(llm, chunks, user, opts) - total_summaries_size = - llm.tokenizer.size(summarized_chunks.map { |s| s[:summary].to_s }.join) + def llm_model + llm.llm_model + end - if total_summaries_size < completion_model.available_tokens + def summarize_chunks(chunks, user, opts, &on_partial_blk) + # Safely assume we always have more than one chunk. 
+ summarized_chunks = summarize_in_chunks(chunks, user, opts) + total_summaries_size = + llm_model.tokenizer_class.size(summarized_chunks.map { |s| s[:summary].to_s }.join) + + if total_summaries_size < available_tokens # Chunks are small enough, we can concatenate them. { summary: concatenate_summaries( - llm, summarized_chunks.map { |s| s[:summary] }, user, &on_partial_blk @@ -61,9 +60,9 @@ module DiscourseAi } else # We have summarized chunks but we can't concatenate them yet. Split them into smaller summaries and summarize again. - rebalanced_chunks = rebalance_chunks(llm.tokenizer, summarized_chunks) + rebalanced_chunks = rebalance_chunks(summarized_chunks) - summarize_chunks(llm, rebalanced_chunks, user, opts, &on_partial_blk) + summarize_chunks(rebalanced_chunks, user, opts, &on_partial_blk) end end @@ -71,15 +70,15 @@ module DiscourseAi "(#{item[:id]} #{item[:poster]} said: #{item[:text]} " end - def rebalance_chunks(tokenizer, chunks) + def rebalance_chunks(chunks) section = { ids: [], summary: "" } chunks = chunks.reduce([]) do |sections, chunk| - if tokenizer.can_expand_tokens?( + if llm_model.tokenizer_class.can_expand_tokens?( section[:summary], chunk[:summary], - completion_model.available_tokens, + available_tokens, ) section[:summary] += chunk[:summary] section[:ids] = section[:ids].concat(chunk[:ids]) @@ -96,13 +95,13 @@ module DiscourseAi chunks end - def summarize_single(llm, text, user, opts, &on_partial_blk) + def summarize_single(text, user, opts, &on_partial_blk) prompt = summarization_prompt(text, opts) llm.generate(prompt, user: user, feature_name: "summarize", &on_partial_blk) end - def summarize_in_chunks(llm, chunks, user, opts) + def summarize_in_chunks(chunks, user, opts) chunks.map do |chunk| prompt = summarization_prompt(chunk[:summary], opts) @@ -116,7 +115,7 @@ module DiscourseAi end end - def concatenate_summaries(llm, summaries, user, &on_partial_blk) + def concatenate_summaries(summaries, user, &on_partial_blk) prompt = DiscourseAi::Completions::Prompt.new(<<~TEXT.strip) You are a summarization bot that effectively concatenates disjoint summaries, creating a cohesive narrative. The narrative you create is in the form of one or multiple paragraphs. 
@@ -185,6 +184,14 @@ module DiscourseAi prompt end + + def available_tokens + # Reserve tokens for the response and the base prompt + # ~500 words + reserved_tokens = 700 + + llm_model.max_prompt_tokens - reserved_tokens + end end end end diff --git a/spec/jobs/regular/stream_topic_ai_summary_spec.rb b/spec/jobs/regular/stream_topic_ai_summary_spec.rb index 5eb0c49f..87848560 100644 --- a/spec/jobs/regular/stream_topic_ai_summary_spec.rb +++ b/spec/jobs/regular/stream_topic_ai_summary_spec.rb @@ -9,9 +9,11 @@ RSpec.describe Jobs::StreamTopicAiSummary do fab!(:post_2) { Fabricate(:post, topic: topic, post_number: 2) } fab!(:user) { Fabricate(:leader) } - before { Group.find(Group::AUTO_GROUPS[:trust_level_3]).add(user) } - - before { SiteSetting.ai_summarization_strategy = "fake" } + before do + Group.find(Group::AUTO_GROUPS[:trust_level_3]).add(user) + assign_fake_provider_to(:ai_summarization_model) + SiteSetting.ai_summarization_enabled = true + end def with_responses(responses) DiscourseAi::Completions::Llm.with_prepared_responses(responses) { yield } diff --git a/spec/lib/modules/summarization/base_spec.rb b/spec/lib/guardian_extensions_spec.rb similarity index 71% rename from spec/lib/modules/summarization/base_spec.rb rename to spec/lib/guardian_extensions_spec.rb index 5af07692..5dd0c651 100644 --- a/spec/lib/modules/summarization/base_spec.rb +++ b/spec/lib/guardian_extensions_spec.rb @@ -1,24 +1,26 @@ # frozen_string_literal: true -describe DiscourseAi::Summarization::Models::Base do +describe DiscourseAi::GuardianExtensions do fab!(:user) fab!(:group) fab!(:topic) before do group.add(user) - - SiteSetting.ai_summarization_strategy = "fake" + assign_fake_provider_to(:ai_summarization_model) + SiteSetting.ai_summarization_enabled = true end describe "#can_see_summary?" 
do + let(:guardian) { Guardian.new(user) } + context "when the user cannot generate a summary" do before { SiteSetting.ai_custom_summarization_allowed_groups = "" } it "returns false" do SiteSetting.ai_custom_summarization_allowed_groups = "" - expect(described_class.can_see_summary?(topic, user)).to eq(false) + expect(guardian.can_see_summary?(topic)).to eq(false) end it "returns true if there is a cached summary" do @@ -29,7 +31,7 @@ describe DiscourseAi::Summarization::Models::Base do algorithm: "test", ) - expect(described_class.can_see_summary?(topic, user)).to eq(true) + expect(guardian.can_see_summary?(topic)).to eq(true) end end @@ -37,13 +39,24 @@ describe DiscourseAi::Summarization::Models::Base do before { SiteSetting.ai_custom_summarization_allowed_groups = group.id } it "returns true if the user group is present in the ai_custom_summarization_allowed_groups_map setting" do - expect(described_class.can_see_summary?(topic, user)).to eq(true) + expect(guardian.can_see_summary?(topic)).to eq(true) + end + end + + context "when the topic is a PM" do + before { SiteSetting.ai_custom_summarization_allowed_groups = group.id } + let(:pm) { Fabricate(:private_message_topic) } + + it "returns false" do + expect(guardian.can_see_summary?(pm)).to eq(false) end end context "when there is no user" do + let(:guardian) { Guardian.new } + it "returns false for anons" do - expect(described_class.can_see_summary?(topic, nil)).to eq(false) + expect(guardian.can_see_summary?(topic)).to eq(false) end it "returns true for anons when there is a cached summary" do @@ -54,16 +67,7 @@ describe DiscourseAi::Summarization::Models::Base do algorithm: "test", ) - expect(described_class.can_see_summary?(topic, nil)).to eq(true) - end - end - - context "when the topic is a PM" do - before { SiteSetting.ai_custom_summarization_allowed_groups = group.id } - let(:pm) { Fabricate(:private_message_topic) } - - it "returns false" do - expect(described_class.can_see_summary?(pm, user)).to eq(false) + expect(guardian.can_see_summary?(topic)).to eq(true) end end end diff --git a/spec/lib/modules/summarization/strategies/fold_content_spec.rb b/spec/lib/modules/summarization/strategies/fold_content_spec.rb index 0333dd45..16ebc892 100644 --- a/spec/lib/modules/summarization/strategies/fold_content_spec.rb +++ b/spec/lib/modules/summarization/strategies/fold_content_spec.rb @@ -2,19 +2,21 @@ RSpec.describe DiscourseAi::Summarization::Strategies::FoldContent do describe "#summarize" do - subject(:strategy) { described_class.new(model) } + let!(:llm_model) { assign_fake_provider_to(:ai_summarization_model) } + + before do + SiteSetting.ai_summarization_enabled = true - let(:summarize_text) { "This is a text" } - let(:model_tokens) do # Make sure each content fits in a single chunk. # 700 is the number of tokens reserved for the prompt. 
- 700 + DiscourseAi::Tokenizer::OpenAiTokenizer.size("(1 asd said: This is a text ") + 3 - end - - let(:model) do - DiscourseAi::Summarization::Models::OpenAi.new("fake:fake", max_tokens: model_tokens) + model_tokens = + 700 + DiscourseAi::Tokenizer::OpenAiTokenizer.size("(1 asd said: This is a text ") + 3 + + llm_model.update!(max_prompt_tokens: model_tokens) end + let(:strategy) { DiscourseAi::Summarization.default_strategy } + let(:summarize_text) { "This is a text" } let(:content) { { contents: [{ poster: "asd", id: 1, text: summarize_text }] } } let(:single_summary) { "this is a single summary" } diff --git a/spec/requests/summarization/chat_summary_controller_spec.rb b/spec/requests/summarization/chat_summary_controller_spec.rb index cbc2b7bd..f23c447e 100644 --- a/spec/requests/summarization/chat_summary_controller_spec.rb +++ b/spec/requests/summarization/chat_summary_controller_spec.rb @@ -7,7 +7,8 @@ RSpec.describe DiscourseAi::Summarization::ChatSummaryController do before do group.add(current_user) - SiteSetting.ai_summarization_strategy = "fake" + assign_fake_provider_to(:ai_summarization_model) + SiteSetting.ai_summarization_enabled = true SiteSetting.ai_custom_summarization_allowed_groups = group.id SiteSetting.chat_enabled = true diff --git a/spec/requests/summarization/summary_controller_spec.rb b/spec/requests/summarization/summary_controller_spec.rb index 1d9006c3..558eb93e 100644 --- a/spec/requests/summarization/summary_controller_spec.rb +++ b/spec/requests/summarization/summary_controller_spec.rb @@ -6,7 +6,10 @@ RSpec.describe DiscourseAi::Summarization::SummaryController do fab!(:post_1) { Fabricate(:post, topic: topic, post_number: 1) } fab!(:post_2) { Fabricate(:post, topic: topic, post_number: 2) } - before { SiteSetting.ai_summarization_strategy = "fake" } + before do + assign_fake_provider_to(:ai_summarization_model) + SiteSetting.ai_summarization_enabled = true + end context "for anons" do it "returns a 404 if there is no cached summary" do diff --git a/spec/services/discourse_ai/topic_summarization_spec.rb b/spec/services/discourse_ai/topic_summarization_spec.rb index 5ff20741..b6980d61 100644 --- a/spec/services/discourse_ai/topic_summarization_spec.rb +++ b/spec/services/discourse_ai/topic_summarization_spec.rb @@ -6,14 +6,15 @@ describe DiscourseAi::TopicSummarization do fab!(:post_1) { Fabricate(:post, topic: topic, post_number: 1) } fab!(:post_2) { Fabricate(:post, topic: topic, post_number: 2) } - let(:model) do - DiscourseAi::Summarization::Strategies::FoldContent.new( - DiscourseAi::Summarization::Models::Fake.new("fake:fake", max_tokens: 8192), - ) + before do + assign_fake_provider_to(:ai_summarization_model) + SiteSetting.ai_summarization_enabled = true end + let(:strategy) { DiscourseAi::Summarization.default_strategy } + shared_examples "includes only public-visible topics" do - subject { described_class.new(model) } + subject { DiscourseAi::TopicSummarization.new(strategy) } it "only includes visible posts" do topic.first_post.update!(hidden: true) @@ -55,7 +56,7 @@ describe DiscourseAi::TopicSummarization do end describe "#summarize" do - subject(:summarization) { described_class.new(model) } + subject(:summarization) { described_class.new(strategy) } def assert_summary_is_cached(topic, summary_response) cached_summary = AiSummary.find_by(target: topic) @@ -72,9 +73,7 @@ describe DiscourseAi::TopicSummarization do it "caches the summary" do DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do section = 
summarization.summarize(topic, user) - expect(section.summarized_text).to eq(summary) - assert_summary_is_cached(topic, summary) end end @@ -83,11 +82,10 @@ summarization.summarize(topic, user) cached_summary_text = "This is a cached summary" - cached_summary = - AiSummary.find_by(target: topic).update!( - summarized_text: cached_summary_text, - updated_at: 24.hours.ago, - ) + AiSummary.find_by(target: topic).update!( + summarized_text: cached_summary_text, + updated_at: 24.hours.ago, + ) section = summarization.summarize(topic, user) expect(section.summarized_text).to eq(cached_summary_text) @@ -129,8 +127,13 @@ end before do + # a bit tricky, but fold_content now caches an instance of LLM + # once it is cached, with_prepared_responses will not work as expected + # since it is glued to the old llm instance + # so we create the cached summary totally independently DiscourseAi::Completions::Llm.with_prepared_responses([cached_text]) do - summarization.summarize(topic, user) + strategy = DiscourseAi::Summarization.default_strategy + described_class.new(strategy).summarize(topic, user) end cached_summary.update!(summarized_text: cached_text, created_at: 24.hours.ago) diff --git a/spec/system/summarization/chat_summarization_spec.rb b/spec/system/summarization/chat_summarization_spec.rb index 38ee34d9..6a54647c 100644 --- a/spec/system/summarization/chat_summarization_spec.rb +++ b/spec/system/summarization/chat_summarization_spec.rb @@ -11,7 +11,8 @@ RSpec.describe "Summarize a channel since your last visit", type: :system do before do group.add(current_user) - SiteSetting.ai_summarization_strategy = "fake" + assign_fake_provider_to(:ai_summarization_model) + SiteSetting.ai_summarization_enabled = true SiteSetting.ai_custom_summarization_allowed_groups = group.id.to_s SiteSetting.chat_enabled = true
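Reviewer note, not part of the patch: a minimal usage sketch of the API surface this change settles on. It assumes a site with ai_summarization_enabled switched on and ai_summarization_model pointing at a configured LlmModel; the topic id and username are placeholders.

  # Rails console sketch -- class and method names come from this patch; the data is hypothetical.
  topic = Topic.find(123)
  guardian = Guardian.new(User.find_by(username: "sam"))

  # Permission checks now live on Guardian (lib/guardian_extensions.rb):
  guardian.can_request_summary?    # group check against ai_custom_summarization_allowed_groups
  guardian.can_see_summary?(topic) # true when a cached summary exists or the user may request one

  # The strategy resolves its LLM from the new site settings (lib/summarization.rb):
  strategy = DiscourseAi::Summarization.default_strategy # FoldContent instance, or nil when disabled

  # One-shot entry point used by the summary controller and the streaming job:
  summary = DiscourseAi::TopicSummarization.summarize(topic, guardian.user)
  summary.summarized_text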
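Also for review convenience, a plain-Ruby rehearsal of the provider-priority pick the post_migrate migration performs when converting the old ai_summarization_strategy value into an ai_summarization_model reference. The rows below are hypothetical stand-ins for llm_models records; the logic mirrors insert_llm_model above.

  # Two hypothetical llm_models rows named "claude-3-opus", one per provider:
  matching_models = [
    { id: 41, provider: "anthropic" },
    { id: 42, provider: "aws_bedrock" },
  ]
  priority = "aws_bedrock" # Claude model names prefer the Bedrock record, per the migration
  chosen = matching_models.find { |m| m[:provider] == priority } || matching_models.first
  new_value = "custom:#{chosen[:id]}"
  # => "custom:42", the value written to the ai_summarization_model site setting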