From f6ac5cd0a8928eb7d2980273d488ead2b538e23d Mon Sep 17 00:00:00 2001
From: Sam
Date: Fri, 12 Apr 2024 23:32:46 +1000
Subject: [PATCH] FEATURE: allow tuning of RAG generation (#565)

* FEATURE: allow tuning of RAG generation

  - change chunking to be token based vs char based (which is more accurate)
  - allow control over overlap / tokens per chunk and conversation snippets inserted
  - UI to control new settings

* improve ui a bit

* fix various reindex issues

* reduce concurrency

* try ultra low queue ... concurrency 1 is too slow.
---
 ...ugins-show-discourse-ai-ai-personas-new.js |  4 +
 .../admin/ai_personas_controller.rb           |  3 +
 app/jobs/regular/digest_rag_upload.rb         | 49 ++++++++---
 app/jobs/regular/generate_rag_embeddings.rb   |  3 +-
 app/models/ai_persona.rb                      | 66 +++++++++-----
 .../localized_ai_persona_serializer.rb        |  5 +-
 .../discourse/admin/models/ai-persona.js      | 19 +++-
 .../components/ai-persona-editor.gjs          | 73 ++++++++++++++++
 .../modules/ai-bot/common/ai-persona.scss     |  5 ++
 config/locales/client.en.yml                  |  8 ++
 ...0409035951_add_rag_params_to_ai_persona.rb | 11 +++
 lib/ai_bot/entry_point.rb                     | 15 ++--
 lib/ai_bot/personas/persona.rb                | 19 +++-
 lib/tokenizer/basic_tokenizer.rb              |  9 +-
 lib/tokenizer/open_ai_tokenizer.rb            |  8 ++
 spec/fixtures/rag/doc_with_metadata.txt       | 11 +++
 .../fixtures/rag/parsed_doc_with_metadata.txt | 87 +++++++++++++++----
 spec/jobs/regular/digest_rag_upload_spec.rb   |  4 +
 .../modules/ai_bot/personas/persona_spec.rb   | 34 +++++++-
 spec/models/ai_persona_spec.rb                | 26 ++++++
 spec/models/rag_document_fragment_spec.rb     | 14 +++
 .../admin/ai_personas_controller_spec.rb      | 20 +++++
 .../unit/models/ai-persona-test.js            |  3 +
 23 files changed, 435 insertions(+), 61 deletions(-)
 create mode 100644 db/migrate/20240409035951_add_rag_params_to_ai_persona.rb

diff --git a/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-ai-personas-new.js b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-ai-personas-new.js
index 82d73f3a..03d2a493 100644
--- a/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-ai-personas-new.js
+++ b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-ai-personas-new.js
@@ -6,6 +6,10 @@ export default DiscourseRoute.extend({
     const record = this.store.createRecord("ai-persona");
     record.set("allowed_group_ids", [AUTO_GROUPS.trust_level_0.id]);
     record.set("rag_uploads", []);
+    // these match the defaults on the table
+    record.set("rag_chunk_tokens", 374);
+    record.set("rag_chunk_overlap_tokens", 10);
+    record.set("rag_conversation_chunks", 10);
     return record;
   },
diff --git a/app/controllers/discourse_ai/admin/ai_personas_controller.rb b/app/controllers/discourse_ai/admin/ai_personas_controller.rb
index 64b33735..928155d4 100644
--- a/app/controllers/discourse_ai/admin/ai_personas_controller.rb
+++ b/app/controllers/discourse_ai/admin/ai_personas_controller.rb
@@ -121,6 +121,9 @@ module DiscourseAi
           :max_context_posts,
           :vision_enabled,
           :vision_max_pixels,
+          :rag_chunk_tokens,
+          :rag_chunk_overlap_tokens,
+          :rag_conversation_chunks,
           allowed_group_ids: [],
           rag_uploads: [:id],
         )
diff --git a/app/jobs/regular/digest_rag_upload.rb b/app/jobs/regular/digest_rag_upload.rb
index d7c348e8..b66fe86f 100644
--- a/app/jobs/regular/digest_rag_upload.rb
+++ b/app/jobs/regular/digest_rag_upload.rb
@@ -4,13 +4,21 @@ module ::Jobs
   class DigestRagUpload < ::Jobs::Base
     CHUNK_SIZE = 1024
     CHUNK_OVERLAP = 64
-    MAX_FRAGMENTS = 10_000
+    MAX_FRAGMENTS = 100_000
    # TODO(roman): Add a way to automatically recover from errors, resulting in unindexed uploads.
    def execute(args)
      return if (upload = Upload.find_by(id: args[:upload_id])).nil?
      return if (ai_persona = AiPersona.find_by(id: args[:ai_persona_id])).nil?

+      truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
+      vector_rep =
+        DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
+
+      tokenizer = vector_rep.tokenizer
+      chunk_tokens = ai_persona.rag_chunk_tokens
+      overlap_tokens = ai_persona.rag_chunk_overlap_tokens
+
       fragment_ids = RagDocumentFragment.where(ai_persona: ai_persona, upload: upload).pluck(:id)

       # Check if this is the first time we process this upload.
@@ -22,7 +30,12 @@ module ::Jobs
          idx = 0

          ActiveRecord::Base.transaction do
-            chunk_document(document) do |chunk, metadata|
+            chunk_document(
+              file: document,
+              tokenizer: tokenizer,
+              chunk_tokens: chunk_tokens,
+              overlap_tokens: overlap_tokens,
+            ) do |chunk, metadata|
              fragment_ids << RagDocumentFragment.create!(
                ai_persona: ai_persona,
                fragment: chunk,
@@ -53,15 +66,18 @@ module ::Jobs

    private

-    def chunk_document(file)
+    def chunk_document(file:, tokenizer:, chunk_tokens:, overlap_tokens:)
      buffer = +""
      current_metadata = nil
      done = false
      overlap = ""

+      # generally this will be plenty
+      read_size = chunk_tokens * 10
+
      while buffer.present? || !done
-        if buffer.length < CHUNK_SIZE * 2
-          read = file.read(CHUNK_SIZE * 2)
+        if buffer.length < read_size
+          read = file.read(read_size)
          done = true if read.nil?

          read = Encodings.to_utf8(read) if read
@@ -84,7 +100,7 @@ module ::Jobs
            overlap = ""
          end

-          chunk, split_char = first_chunk(to_chunk)
+          chunk, split_char = first_chunk(to_chunk, tokenizer: tokenizer, chunk_tokens: chunk_tokens)
          buffer = buffer[chunk.length..-1]

          processed_chunk = overlap + chunk
@@ -94,15 +110,28 @@ module ::Jobs

          yield processed_chunk, current_metadata

-          overlap = (chunk[-CHUNK_OVERLAP..-1] || chunk) + split_char
+          current_chunk_tokens = tokenizer.encode(chunk)
+          overlap_token_ids = current_chunk_tokens[-overlap_tokens..-1] || current_chunk_tokens
+
+          overlap = ""
+
+          while overlap_token_ids.present?
+            begin
+              overlap = tokenizer.decode(overlap_token_ids) + split_char
+              break if overlap.encoding == Encoding::UTF_8
+            rescue StandardError
+              # it is possible that we truncated mid char
+            end
+            overlap_token_ids.shift
+          end

          # remove first word it is probably truncated
          overlap = overlap.split(" ", 2).last
        end
      end

-    def first_chunk(text, chunk_size: CHUNK_SIZE, splitters: ["\n\n", "\n", ".", ""])
-      return text, " " if text.length <= chunk_size
+    def first_chunk(text, chunk_tokens:, tokenizer:, splitters: ["\n\n", "\n", ".", ""])
+      return text, " " if tokenizer.tokenize(text).length <= chunk_tokens

      splitters = splitters.find_all { |s| text.include?(s) }.compact
@@ -115,7 +144,7 @@ module ::Jobs
        text
          .split(split_char)
          .each do |part|
-            break if (buffer.length + split_char.length + part.length) > chunk_size
+            break if tokenizer.tokenize(buffer + split_char + part).length > chunk_tokens
            buffer << split_char
            buffer << part
          end
diff --git a/app/jobs/regular/generate_rag_embeddings.rb b/app/jobs/regular/generate_rag_embeddings.rb
index 96483a9a..d1d05a19 100644
--- a/app/jobs/regular/generate_rag_embeddings.rb
+++ b/app/jobs/regular/generate_rag_embeddings.rb
@@ -2,7 +2,8 @@ module ::Jobs
  class GenerateRagEmbeddings < ::Jobs::Base
-    sidekiq_options queue: "low"
+    sidekiq_options queue: "ultra_low"
+    # we could also restrict concurrency but this takes so long if it is not concurrent

    def execute(args)
      return if (fragments = RagDocumentFragment.where(id: args[:fragment_ids].to_a)).empty?
diff --git a/app/models/ai_persona.rb b/app/models/ai_persona.rb
index b820b80c..e88cc14f 100644
--- a/app/models/ai_persona.rb
+++ b/app/models/ai_persona.rb
@@ -13,6 +13,10 @@ class AiPersona < ActiveRecord::Base
  # we may want to revisit this in the future
  validates :vision_max_pixels, numericality: { greater_than: 0, maximum: 4_000_000 }

+  validates :rag_chunk_tokens, numericality: { greater_than: 0, maximum: 50_000 }
+  validates :rag_chunk_overlap_tokens, numericality: { greater_than: -1, maximum: 200 }
+  validates :rag_conversation_chunks, numericality: { greater_than: 0, maximum: 1000 }
+
  belongs_to :created_by, class_name: "User"
  belongs_to :user
@@ -25,6 +29,8 @@ class AiPersona < ActiveRecord::Base

  before_destroy :ensure_not_system

+  before_update :regenerate_rag_fragments
+
  class MultisiteHash
    def initialize(id)
      @hash = Hash.new { |h, k| h[k] = {} }
@@ -110,6 +116,7 @@ class AiPersona < ActiveRecord::Base
    max_context_posts = self.max_context_posts
    vision_enabled = self.vision_enabled
    vision_max_pixels = self.vision_max_pixels
+    rag_conversation_chunks = self.rag_conversation_chunks

    persona_class = DiscourseAi::AiBot::Personas::Persona.system_personas_by_id[self.id]
    if persona_class
@@ -149,6 +156,10 @@ class AiPersona < ActiveRecord::Base
        vision_max_pixels
      end

+      persona_class.define_singleton_method :rag_conversation_chunks do
+        rag_conversation_chunks
+      end
+
      return persona_class
    end
@@ -232,6 +243,10 @@ class AiPersona < ActiveRecord::Base
        vision_max_pixels
      end

+      define_singleton_method :rag_conversation_chunks do
+        rag_conversation_chunks
+      end
+
      define_singleton_method :to_s do
        "#"
      end
@@ -314,6 +329,12 @@ class AiPersona < ActiveRecord::Base
    user
  end

+  def regenerate_rag_fragments
+    if rag_chunk_tokens_changed? || rag_chunk_overlap_tokens_changed?
+      RagDocumentFragment.where(ai_persona: self).delete_all
+    end
+  end
+
  private

  def system_persona_unchangeable
@@ -335,26 +356,31 @@ end
#
# Table name: ai_personas
#
-#  id                :bigint           not null, primary key
-#  name              :string(100)      not null
-#  description       :string(2000)     not null
-#  commands          :json             not null
-#  system_prompt     :string(10000000) not null
-#  allowed_group_ids :integer          default([]), not null, is an Array
-#  created_by_id     :integer
-#  enabled           :boolean          default(TRUE), not null
-#  created_at        :datetime         not null
-#  updated_at        :datetime         not null
-#  system            :boolean          default(FALSE), not null
-#  priority          :boolean          default(FALSE), not null
-#  temperature       :float
-#  top_p             :float
-#  user_id           :integer
-#  mentionable       :boolean          default(FALSE), not null
-#  default_llm       :text
-#  max_context_posts :integer
-#  vision_enabled    :boolean          default(FALSE), not null
-#  vision_max_pixels :integer          default(1048576), not null
+#  id                       :bigint           not null, primary key
+#  name                     :string(100)      not null
+#  description              :string(2000)     not null
+#  commands                 :json             not null
+#  system_prompt            :string(10000000) not null
+#  allowed_group_ids        :integer          default([]), not null, is an Array
+#  created_by_id            :integer
+#  enabled                  :boolean          default(TRUE), not null
+#  created_at               :datetime         not null
+#  updated_at               :datetime         not null
+#  system                   :boolean          default(FALSE), not null
+#  priority                 :boolean          default(FALSE), not null
+#  temperature              :float
+#  top_p                    :float
+#  user_id                  :integer
+#  mentionable              :boolean          default(FALSE), not null
+#  default_llm              :text
+#  max_context_posts        :integer
+#  max_post_context_tokens  :integer
+#  max_context_tokens       :integer
+#  vision_enabled           :boolean          default(FALSE), not null
+#  vision_max_pixels        :integer          default(1048576), not null
+#  rag_chunk_tokens         :integer          default(374), not null
+#  rag_chunk_overlap_tokens :integer          default(10), not null
+#  rag_conversation_chunks  :integer          default(10), not null
#
# Indexes
#
diff --git a/app/serializers/localized_ai_persona_serializer.rb b/app/serializers/localized_ai_persona_serializer.rb
index 1ca58089..70a72839 100644
--- a/app/serializers/localized_ai_persona_serializer.rb
+++ b/app/serializers/localized_ai_persona_serializer.rb
@@ -19,7 +19,10 @@ class LocalizedAiPersonaSerializer < ApplicationSerializer
             :user_id,
             :max_context_posts,
             :vision_enabled,
-             :vision_max_pixels
+             :vision_max_pixels,
+             :rag_chunk_tokens,
+             :rag_chunk_overlap_tokens,
+             :rag_conversation_chunks

  has_one :user, serializer: BasicUserSerializer, embed: :object
  has_many :rag_uploads, serializer: UploadSerializer, embed: :object
diff --git a/assets/javascripts/discourse/admin/models/ai-persona.js b/assets/javascripts/discourse/admin/models/ai-persona.js
index bdd88cca..8b731dfd 100644
--- a/assets/javascripts/discourse/admin/models/ai-persona.js
+++ b/assets/javascripts/discourse/admin/models/ai-persona.js
@@ -2,7 +2,7 @@ import { tracked } from "@glimmer/tracking";
 import { ajax } from "discourse/lib/ajax";
 import RestModel from "discourse/models/rest";

-const ATTRIBUTES = [
+const CREATE_ATTRIBUTES = [
   "id",
   "name",
   "description",
@@ -24,6 +24,13 @@ const ATTRIBUTES = [
   "rag_uploads",
 ];

+// rag params are populated on save, only show it when editing
+const ATTRIBUTES = CREATE_ATTRIBUTES.concat([
+  "rag_chunk_tokens",
+  "rag_chunk_overlap_tokens",
+  "rag_conversation_chunks",
+]);
+
 const SYSTEM_ATTRIBUTES = [
   "id",
   "allowed_group_ids",
   "system_prompt",
   "allowed_group_ids",
   "created_by_id",
   "default_llm",
   "mentionable",
   "max_context_posts",
   "vision_enabled",
   "vision_max_pixels",
   "rag_uploads",
+  "rag_chunk_tokens",
+  "rag_chunk_overlap_tokens",
+  "rag_conversation_chunks",
 ];

 class CommandOption {
@@ -122,16 +132,19 @@ export default class AiPersona extends RestModel {
       : this.getProperties(ATTRIBUTES);
     attrs.id = this.id;
     this.populateCommandOptions(attrs);
+
     return attrs;
   }

   createProperties() {
-    let attrs = this.getProperties(ATTRIBUTES);
+    let attrs = this.getProperties(CREATE_ATTRIBUTES);
     this.populateCommandOptions(attrs);
     return attrs;
   }

   workingCopy() {
-    return AiPersona.create(this.createProperties());
+    let attrs = this.getProperties(ATTRIBUTES);
+    this.populateCommandOptions(attrs);
+    return AiPersona.create(attrs);
   }
 }
diff --git a/assets/javascripts/discourse/components/ai-persona-editor.gjs b/assets/javascripts/discourse/components/ai-persona-editor.gjs
index 200040a7..d9978712 100644
--- a/assets/javascripts/discourse/components/ai-persona-editor.gjs
+++ b/assets/javascripts/discourse/components/ai-persona-editor.gjs
@@ -38,6 +38,7 @@ export default class PersonaEditor extends Component {
   @tracked showDelete = false;
   @tracked maxPixelsValue = null;
   @tracked ragIndexingStatuses = null;
+  @tracked showIndexingOptions = false;

   @action
   updateModel() {
@@ -48,6 +49,13 @@ export default class PersonaEditor extends Component {
     );
   }

+  @action
+  toggleIndexingOptions(event) {
+    this.showIndexingOptions = !this.showIndexingOptions;
+    event.preventDefault();
+    event.stopPropagation();
+  }
+
   findClosestPixelValue(pixels) {
     let value = "high";
     this.maxPixelValues.forEach((info) => {
@@ -69,6 +77,12 @@ export default class PersonaEditor extends Component {
     ];
   }

+  get indexingOptionsText() {
+    return this.showIndexingOptions
+      ? I18n.t("discourse_ai.ai_persona.hide_indexing_options")
+      : I18n.t("discourse_ai.ai_persona.show_indexing_options");
+  }
+
   @action
   async updateAllGroups() {
     this.allGroups = await Group.findAll();
@@ -448,7 +462,66 @@ export default class PersonaEditor extends Component {
             @onAdd={{this.addUpload}}
             @onRemove={{this.removeUpload}}
           />
+          <a
+            href="#"
+            class="ai-persona-editor__indexing-options"
+            {{on "click" this.toggleIndexingOptions}}
+          >{{this.indexingOptionsText}}</a>
+          {{#if this.showIndexingOptions}}
+            <div class="control-group ai-persona-editor__rag_chunk_tokens">
+              <label>{{I18n.t "discourse_ai.ai_persona.rag_chunk_tokens"}}</label>
+              <Input @type="number" lang="en" class="ai-persona-editor__rag_chunk_tokens" @value={{this.editingModel.rag_chunk_tokens}} />
+              <DTooltip @icon="question-circle" @content={{I18n.t "discourse_ai.ai_persona.rag_chunk_tokens_help"}} />
+            </div>
+            <div class="control-group ai-persona-editor__rag_chunk_overlap_tokens">
+              <label>{{I18n.t "discourse_ai.ai_persona.rag_chunk_overlap_tokens"}}</label>
+              <Input @type="number" lang="en" class="ai-persona-editor__rag_chunk_overlap_tokens" @value={{this.editingModel.rag_chunk_overlap_tokens}} />
+              <DTooltip @icon="question-circle" @content={{I18n.t "discourse_ai.ai_persona.rag_chunk_overlap_tokens_help"}} />
+            </div>
+            <div class="control-group ai-persona-editor__rag_conversation_chunks">
+              <label>{{I18n.t "discourse_ai.ai_persona.rag_conversation_chunks"}}</label>
+              <Input @type="number" lang="en" class="ai-persona-editor__rag_conversation_chunks" @value={{this.editingModel.rag_conversation_chunks}} />
+              <DTooltip @icon="question-circle" @content={{I18n.t "discourse_ai.ai_persona.rag_conversation_chunks_help"}} />
+            </div>
+          {{/if}}
         {{/if}}
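Note on the chunking change: the patch replaces the old character budgets (CHUNK_SIZE / CHUNK_OVERLAP) with budgets measured in tokens from the embedding model's tokenizer. The standalone sketch below illustrates that idea only; WordTokenizer is a toy stand-in for vector_rep.tokenizer (an assumption made for illustration, not the plugin's tokenizer), and chunk_by_tokens is a simplified cousin of chunk_document/first_chunk that skips the splitter-boundary logic.

class WordTokenizer
  def tokenize(text)
    text.split(/\s+/).reject(&:empty?)
  end
end

def chunk_by_tokens(text, tokenizer:, chunk_tokens:, overlap_tokens:)
  words = tokenizer.tokenize(text)
  chunks = []
  start = 0

  while start < words.length
    slice = words[start, chunk_tokens]
    # carry the tail of the previous chunk so neighbouring chunks share context
    overlap = chunks.empty? ? [] : tokenizer.tokenize(chunks.last).last(overlap_tokens)
    chunks << (overlap + slice).join(" ")
    start += chunk_tokens
  end

  chunks
end

p chunk_by_tokens(
  "one two three four five six seven eight nine ten",
  tokenizer: WordTokenizer.new,
  chunk_tokens: 4,
  overlap_tokens: 1,
)
# => ["one two three four", "four five six seven eight", "eight nine ten"]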
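The decode loop in chunk_document shifts token ids because slicing an encoded sequence can cut a multi-byte character in half. The snippet below reproduces the same failure mode with raw bytes instead of a model tokenizer; valid_encoding? here plays roughly the role of the encoding check in the patch, and is only an analogy, not the plugin's code.

tail = "caffè latte".bytes.last(7) # starts inside the two-byte "è"
overlap = ""

while tail.any?
  candidate = tail.pack("C*").force_encoding(Encoding::UTF_8)
  if candidate.valid_encoding?
    overlap = candidate
    break
  end
  tail.shift # drop leading bytes until we land back on a character boundary
end

puts overlap.inspect # => " latte"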
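The body of db/migrate/20240409035951_add_rag_params_to_ai_persona.rb is not included above, but the updated schema annotation in app/models/ai_persona.rb pins down what it must add. A sketch consistent with those defaults (the Rails migration version is assumed, and this is not necessarily the literal file contents):

# frozen_string_literal: true

class AddRagParamsToAiPersona < ActiveRecord::Migration[7.0]
  def change
    # defaults match what the "new persona" route seeds on the client
    add_column :ai_personas, :rag_chunk_tokens, :integer, null: false, default: 374
    add_column :ai_personas, :rag_chunk_overlap_tokens, :integer, null: false, default: 10
    add_column :ai_personas, :rag_conversation_chunks, :integer, null: false, default: 10
  end
end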