From 03fc94684b27da3612c61a106a03eb4159ad29fc Mon Sep 17 00:00:00 2001
From: Sam
Date: Thu, 4 Jan 2024 23:53:47 +1100
Subject: [PATCH] FIX: AI helper not working correctly with mixtral (#399)

* FIX: AI helper not working correctly with mixtral

This PR introduces a new function on the generic LLM called #generate.
It replaces the implementation of completion!.

#generate introduces a new way to pass temperature, max_tokens and
stop_sequences. LLM implementers then need to implement
#normalize_model_params to ensure the generic names match the
LLM-specific endpoint.

This also adds temperature and stop_sequences to completion_prompts,
which allows for much more robust completion prompts.

* port everything over to #generate

* Fix translation

  - On Anthropic this no longer throws a random "This is your translation:"
  - On Mixtral this actually works

* fix markdown table generation as well
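For reviewers, a minimal sketch of the new call shape, mirroring the call
sites ported below (prompt contents abbreviated; exact params come from
each caller):

```ruby
llm = DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_model)

# Callers pass the generic names; each endpoint maps them in its
# #normalize_model_params, e.g. max_tokens becomes max_tokens_to_sample
# on Anthropic, and stop_sequences becomes stop on OpenAI/vLLM.
result =
  llm.generate(
    { insts: "You are a helpful bot" },
    temperature: 0.2,
    max_tokens: 100,
    stop_sequences: ["</output>"],
    user: Discourse.system_user,
  )
```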
---
 .../ai_helper/assistant_controller.rb         |  6 +-
 app/models/completion_prompt.rb               |  2 +
 .../ai_helper/603_completion_prompts.rb       | 57 ++++++++++++------
 ...4013944_add_params_to_completion_prompt.rb |  8 +++
 lib/ai_helper/assistant.rb                    | 40 ++++++-------
 lib/ai_helper/painter.rb                      |  5 +-
 lib/automation/llm_triage.rb                  | 24 +++-----
 lib/automation/report_runner.rb               |  7 +--
 lib/completions/dialects/chat_gpt.rb          |  2 +-
 lib/completions/dialects/mixtral.rb           |  2 +-
 lib/completions/endpoints/anthropic.rb        | 18 +++++-
 lib/completions/endpoints/aws_bedrock.rb      | 18 +++++-
 lib/completions/endpoints/base.rb             |  9 ++-
 lib/completions/endpoints/canned_response.rb  |  5 ++
 lib/completions/endpoints/gemini.rb           | 24 +++++++-
 lib/completions/endpoints/hugging_face.rb     | 14 +++++
 lib/completions/endpoints/open_ai.rb          | 11 ++++
 lib/completions/endpoints/vllm.rb             | 12 +++-
 lib/completions/llm.rb                        | 19 ++++++-
 lib/embeddings/semantic_search.rb             |  2 +-
 lib/summarization/strategies/fold_content.rb  |  8 +--
 spec/lib/completions/dialects/mixtral_spec.rb |  4 +-
 .../lib/completions/endpoints/open_ai_spec.rb |  4 +-
 spec/lib/completions/llm_spec.rb              |  6 +-
 .../ai_helper/assistant_controller_spec.rb    |  2 +-
 25 files changed, 217 insertions(+), 92 deletions(-)
 create mode 100644 db/migrate/20240104013944_add_params_to_completion_prompt.rb

diff --git a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb
index 559d1382..c3e08f4c 100644
--- a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb
+++ b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb
@@ -43,7 +43,7 @@ module DiscourseAi
           ),
           status: 200
       end
-    rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed => e
+    rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed
       render_json_error I18n.t("discourse_ai.ai_helper.errors.completion_request_failed"),
                         status: 502
     end
@@ -63,7 +63,7 @@ module DiscourseAi
           ),
           status: 200
       end
-    rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed => e
+    rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed
       render_json_error I18n.t("discourse_ai.ai_helper.errors.completion_request_failed"),
                         status: 502
     end
@@ -111,7 +111,7 @@ module DiscourseAi
         )

         render json: { success: true }, status: 200
-    rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed => e
+    rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed
       render_json_error I18n.t("discourse_ai.ai_helper.errors.completion_request_failed"),
                         status: 502
     end
diff --git a/app/models/completion_prompt.rb b/app/models/completion_prompt.rb
index 12614278..183b7adb 100644
--- a/app/models/completion_prompt.rb
+++ b/app/models/completion_prompt.rb
@@ -67,6 +67,8 @@ end
 #  created_at     :datetime         not null
 #  updated_at     :datetime         not null
 #  messages       :jsonb
+#  temperature    :float
+#  stop_sequences :string           is an Array
 #
 # Indexes
 #
diff --git a/db/fixtures/ai_helper/603_completion_prompts.rb b/db/fixtures/ai_helper/603_completion_prompts.rb
index d08f78e5..d9f40c0c 100644
--- a/db/fixtures/ai_helper/603_completion_prompts.rb
+++ b/db/fixtures/ai_helper/603_completion_prompts.rb
@@ -5,47 +5,55 @@ CompletionPrompt.seed do |cp|
   cp.id = -301
   cp.name = "translate"
   cp.prompt_type = CompletionPrompt.prompt_types[:text]
-  cp.messages = { insts: <<~TEXT }
-    I want you to act as an English translator, spelling corrector and improver. I will write to you
-    in any language and you will detect the language, translate it and answer in the corrected and
-    improved version of my text, in English. I want you to replace my simplified A0-level words and
-    sentences with more beautiful and elegant, upper level English words and sentences.
-    Keep the meaning same, but make them more literary. I want you to only reply the correction,
-    the improvements and nothing else, do not write explanations.
-    You will find the text between <input></input> XML tags.
-  TEXT
+  cp.stop_sequences = ["\n</output>", "</output>"]
+  cp.temperature = 0.2
+  cp.messages = {
+    insts: <<~TEXT,
+      I want you to act as an English translator, spelling corrector and improver. I will write to you
+      in any language and you will detect the language, translate it and answer in the corrected and
+      improved version of my text, in English. I want you to replace my simplified A0-level words and
+      sentences with more beautiful and elegant, upper level English words and sentences.
+      Keep the meaning same, but make them more literary. I want you to only reply the correction,
+      the improvements and nothing else, do not write explanations.
+      You will find the text between <input></input> XML tags.
+      Include your translation between <output></output> XML tags.
+    TEXT
+    examples: [
+      ["<input>Hello world</input>", "<output>Hello world</output>"],
+      ["<input>Bonjour le monde</input>", "<output>Hello world</output>"],
+    ],
+  }
 end

 CompletionPrompt.seed do |cp|
   cp.id = -303
   cp.name = "proofread"
   cp.prompt_type = CompletionPrompt.prompt_types[:diff]
+  cp.temperature = 0
+  cp.stop_sequences = ["\n</output>"]
   cp.messages = {
     insts: <<~TEXT,
       You are a markdown proofreader. You correct egregious typos and phrasing issues but keep the user's original voice.
       You do not touch code blocks. I will provide you with text to proofread. If nothing needs fixing, then you will echo the text back.
-      Optionally, a user can specify intensity. Intensity 10 is a pedantic English teacher correcting the text.
-      Intensity 1 is a minimal proofreader. By default, you operate at intensity 1.
       You will find the text between <input></input> XML tags.
+      You will ALWAYS return the corrected text between <output></output> XML tags.
     TEXT
     examples: [
       [
         "<input>![amazing car|100x100, 22%](upload://hapy.png)</input>",
-        "![Amazing car|100x100, 22%](upload://hapy.png)",
+        "<output>![Amazing car|100x100, 22%](upload://hapy.png)</output>",
       ],
       [<<~TEXT, "The rain in Spain, stays mainly in the Plane."],
-        Intensity 1:
         The rain in spain stays mainly in the plane.
       TEXT
       [
-        "The rain in Spain, stays mainly in the Plane.",
-        "The rain in Spain, stays mainly in the Plane.",
+        "<input>The rain in Spain, stays mainly in the Plane.</input>",
+        "<output>The rain in Spain, stays mainly in the Plane.</output>",
       ],
       [<<~TEXT, <<~TEXT],
-        Intensity 1:
         Hello,
         Sometimes the logo isn't changing automatically when color scheme changes.
@@ -53,13 +61,14 @@ CompletionPrompt.seed do |cp|
         ![Screen Recording 2023-03-17 at 18.04.22|video](upload://2rcVL0ZMxHPNtPWQbZjwufKpWVU.mov)
       TEXT
+        <output>
         Hello,
         Sometimes the logo does not change automatically when the color scheme changes.
         ![Screen Recording 2023-03-17 at 18.04.22|video](upload://2rcVL0ZMxHPNtPWQbZjwufKpWVU.mov)
+        </output>
       TEXT
       [<<~TEXT, <<~TEXT],
-        Intensity 1:
         Any ideas what is wrong with this peace of cod?
         > This quot contains a typo
         ```ruby
@@ -69,6 +78,7 @@ CompletionPrompt.seed do |cp|
         ```
       TEXT
+        <output>
         Any ideas what is wrong with this piece of code?
         > This quot contains a typo
         ```ruby
@@ -76,6 +86,7 @@ CompletionPrompt.seed do |cp|
         testing.a_typo = 11
         bad = "bad"
         ```
+        </output>
       TEXT
     ],
   }
@@ -85,15 +96,19 @@ CompletionPrompt.seed do |cp|
   cp.id = -304
   cp.name = "markdown_table"
   cp.prompt_type = CompletionPrompt.prompt_types[:diff]
+  cp.temperature = 0.5
+  cp.stop_sequences = ["\n</output>"]
   cp.messages = {
     insts: <<~TEXT,
       You are a markdown table formatter, I will provide you text inside <input></input> XML tags and you will format it into a markdown table
     TEXT
     examples: [
       ["<input>sam,joe,jane\nage: 22| 10|11</input>", <<~TEXT],
+        <output>
         | | sam | joe | jane |
         |---|---|---|---|
         | age | 22 | 10 | 11 |
+        </output>
       TEXT
       [<<~TEXT, <<~TEXT],
@@ -102,22 +117,26 @@ CompletionPrompt.seed do |cp|
         sam: speed 100, age 22
         jane: age 10
         fred: height 22
       TEXT
+        <output>
         | | speed | age | height |
         |---|---|---|---|
         | sam | 100 | 22 | - |
         | jane | - | 10 | - |
         | fred | - | - | 22 |
+        </output>
       TEXT
       [<<~TEXT, <<~TEXT],
-        chrome 22ms (first load 10ms)
-        firefox 10ms (first load: 9ms)
+        chrome 22ms (first load 10ms)
+        firefox 10ms (first load: 9ms)
       TEXT
+        <output>
         | Browser | Load Time (ms) | First Load Time (ms) |
         |---|---|---|
         | Chrome | 22 | 10 |
         | Firefox | 10 | 9 |
+        </output>
       TEXT
     ],
   }
diff --git a/db/migrate/20240104013944_add_params_to_completion_prompt.rb b/db/migrate/20240104013944_add_params_to_completion_prompt.rb
new file mode 100644
index 00000000..7d179e13
--- /dev/null
+++ b/db/migrate/20240104013944_add_params_to_completion_prompt.rb
@@ -0,0 +1,8 @@
+# frozen_string_literal: true
+
+class AddParamsToCompletionPrompt < ActiveRecord::Migration[7.0]
+  def change
+    add_column :completion_prompts, :temperature, :float
+    add_column :completion_prompts, :stop_sequences, :string, array: true
+  end
+end
diff --git a/lib/ai_helper/assistant.rb b/lib/ai_helper/assistant.rb
index e9cbc96c..038ef4bd 100644
--- a/lib/ai_helper/assistant.rb
+++ b/lib/ai_helper/assistant.rb
@@ -36,20 +36,26 @@ module DiscourseAi
         llm = DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_model)
         generic_prompt = completion_prompt.messages_with_input(input)

-        llm.completion!(generic_prompt, user, &block)
+        llm.generate(
+          generic_prompt,
+          user: user,
+          temperature: completion_prompt.temperature,
+          stop_sequences: completion_prompt.stop_sequences,
+          &block
+        )
       end

       def generate_and_send_prompt(completion_prompt, input, user)
         completion_result = generate_prompt(completion_prompt, input, user)
         result = { type: completion_prompt.prompt_type }

-        result[:diff] = parse_diff(input, completion_result) if completion_prompt.diff?
-
         result[:suggestions] = (
           if completion_prompt.list?
             parse_list(completion_result).map { |suggestion| sanitize_result(suggestion) }
           else
-            [sanitize_result(completion_result)]
+            sanitized = sanitize_result(completion_result)
+            result[:diff] = parse_diff(input, sanitized) if completion_prompt.diff?
+            [sanitized]
           end
         )
@@ -79,25 +85,15 @@ module DiscourseAi

       private

-      def sanitize_result(result)
-        tags_to_remove = %w[
-          <term>
-          </term>
-          <context>
-          </context>
-          <topic>
-          </topic>
-          <replyTo>
-          </replyTo>
-          <input>
-          </input>
-          <output>
-          </output>
-          <result>
-          </result>
-        ]
+      SANITIZE_REGEX_STR =
+        %w[term context topic replyTo input output result]
+          .map { |tag| "<#{tag}>\\n?|</#{tag}>\\n?" }
+          .join("|")

-        result.dup.tap { |dup_result| tags_to_remove.each { |tag| dup_result.gsub!(tag, "") } }
+      SANITIZE_REGEX = Regexp.new(SANITIZE_REGEX_STR, Regexp::IGNORECASE | Regexp::MULTILINE)
+
+      def sanitize_result(result)
+        result.gsub(SANITIZE_REGEX, "")
       end

       def publish_update(channel, payload, user)
diff --git a/lib/ai_helper/painter.rb b/lib/ai_helper/painter.rb
index 0b1041dd..ec8358ee 100644
--- a/lib/ai_helper/painter.rb
+++ b/lib/ai_helper/painter.rb
@@ -38,7 +38,10 @@ module DiscourseAi
           You'll find the post between <input></input> XML tags.
         TEXT

-        DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_model).completion!(prompt, user)
+        DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_model).generate(
+          prompt,
+          user: user,
+        )
       end
     end
   end
diff --git a/lib/automation/llm_triage.rb b/lib/automation/llm_triage.rb
index a1393bc4..ad83dbe6 100644
--- a/lib/automation/llm_triage.rb
+++ b/lib/automation/llm_triage.rb
@@ -31,24 +31,16 @@ module DiscourseAi
       result = nil

       llm = DiscourseAi::Completions::Llm.proxy(model)

-      key =
-        if model.include?("claude")
-          :max_tokens_to_sample
-        else
-          :max_tokens
-        end
-
-      prompt = {
-        insts: filled_system_prompt,
-        params: {
-          model => {
-            key => (llm.tokenizer.tokenize(search_for_text).length * 2 + 10),
-            :temperature => 0,
-          },
-        },
-      }
+      prompt = { insts: filled_system_prompt }

-      result = llm.completion!(prompt, Discourse.system_user)
+      result =
+        llm.generate(
+          prompt,
+          temperature: 0,
+          max_tokens: llm.tokenizer.tokenize(search_for_text).length * 2 + 10,
+          user: Discourse.system_user,
+        )

       if result.strip == search_for_text.strip
         user = User.find_by_username(canned_reply_user) if canned_reply_user.present?
diff --git a/lib/automation/report_runner.rb b/lib/automation/report_runner.rb
index 08a4548e..c942c3d5 100644
--- a/lib/automation/report_runner.rb
+++ b/lib/automation/report_runner.rb
@@ -115,18 +115,13 @@ module DiscourseAi
             insts: "You are a helpful bot specializing in summarizing activity on Discourse sites",
             input: input,
             final_insts: "Here is the report I generated for you",
-            params: {
-              @model => {
-                temperature: 0,
-              },
-            },
           }

         result = +""

         puts if Rails.env.development? && @debug_mode

-        @llm.completion!(prompt, Discourse.system_user) do |response|
+        @llm.generate(prompt, temperature: 0, user: Discourse.system_user) do |response|
           print response if Rails.env.development? && @debug_mode
           result << response
         end
diff --git a/lib/completions/dialects/chat_gpt.rb b/lib/completions/dialects/chat_gpt.rb
index 0c9676a8..033f29ab 100644
--- a/lib/completions/dialects/chat_gpt.rb
+++ b/lib/completions/dialects/chat_gpt.rb
@@ -95,7 +95,7 @@ module DiscourseAi
         def max_prompt_tokens
           # provide a buffer of 120 tokens - our function counting is not
           # 100% accurate and getting numbers to align exactly is very hard
-          buffer = (opts[:max_tokens_to_sample] || 2500) + 50
+          buffer = (opts[:max_tokens] || 2500) + 50

           if tools.present?
            # note this is about 100 tokens over, OpenAI have a more optimal representation
diff --git a/lib/completions/dialects/mixtral.rb b/lib/completions/dialects/mixtral.rb
index 75e0f954..464a1ac4 100644
--- a/lib/completions/dialects/mixtral.rb
+++ b/lib/completions/dialects/mixtral.rb
@@ -27,7 +27,7 @@ module DiscourseAi
           if prompt[:examples]
             prompt[:examples].each do |example_pair|
               mixtral_prompt << "[INST] #{example_pair.first} [/INST]\n"
-              mixtral_prompt << "#{example_pair.second}\n"
+              mixtral_prompt << "#{example_pair.second}</s>\n"
             end
           end
diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb
index 3846d7e4..d98990d8 100644
--- a/lib/completions/endpoints/anthropic.rb
+++ b/lib/completions/endpoints/anthropic.rb
@@ -8,8 +8,24 @@ module DiscourseAi
           %w[claude-instant-1 claude-2].include?(model_name)
         end

+        def normalize_model_params(model_params)
+          model_params = model_params.dup
+
+          # temperature, stop_sequences are already supported
+          #
+          if model_params[:max_tokens]
+            model_params[:max_tokens_to_sample] = model_params.delete(:max_tokens)
+          end
+
+          model_params
+        end
+
         def default_options
-          { max_tokens_to_sample: 2000, model: model }
+          {
+            model: model,
+            max_tokens_to_sample: 2_000,
+            stop_sequences: ["\n\nHuman:", "</function_calls>"],
+          }
         end

         def provider_id
diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb
index 98f29634..5d559b75 100644
--- a/lib/completions/endpoints/aws_bedrock.rb
+++ b/lib/completions/endpoints/aws_bedrock.rb
@@ -13,8 +13,24 @@ module DiscourseAi
             SiteSetting.ai_bedrock_region.present?
         end

+        def normalize_model_params(model_params)
+          model_params = model_params.dup
+
+          # temperature, stop_sequences are already supported
+          #
+          if model_params[:max_tokens]
+            model_params[:max_tokens_to_sample] = model_params.delete(:max_tokens)
+          end
+
+          model_params
+        end
+
         def default_options
-          { max_tokens_to_sample: 2_000, stop_sequences: ["\n\nHuman:", "</function_calls>"] }
+          {
+            model: model,
+            max_tokens_to_sample: 2_000,
+            stop_sequences: ["\n\nHuman:", "</function_calls>"],
+          }
         end

         def provider_id
diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index da433c18..a9768e47 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -32,6 +32,8 @@ module DiscourseAi
         end

         def perform_completion!(dialect, user, model_params = {})
+          model_params = normalize_model_params(model_params)
+
           @streaming_mode = block_given?

          prompt = dialect.translate
@@ -199,6 +201,11 @@ module DiscourseAi

        attr_reader :model

+        # should normalize temperature, max_tokens, stop_sequences to endpoint specific values
+        def normalize_model_params(model_params)
+          raise NotImplementedError
+        end
+
         def model_uri
           raise NotImplementedError
         end
@@ -262,7 +269,7 @@ module DiscourseAi
             function_buffer.at("tool_id").inner_html = tool_name
           end

-          read_parameters =
+          _read_parameters =
             read_function
               .at("parameters")
               .elements
diff --git a/lib/completions/endpoints/canned_response.rb b/lib/completions/endpoints/canned_response.rb
index 56d2b913..ab04961e 100644
--- a/lib/completions/endpoints/canned_response.rb
+++ b/lib/completions/endpoints/canned_response.rb
@@ -16,6 +16,11 @@ module DiscourseAi
           @prompt = nil
         end

+        def normalize_model_params(model_params)
+          # max_tokens, temperature, stop_sequences are already supported
+          model_params
+        end
+
         attr_reader :responses, :completions, :prompt

         def perform_completion!(prompt, _user, _model_params)
diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb
index 231309b2..9a1e3711 100644
--- a/lib/completions/endpoints/gemini.rb
+++ b/lib/completions/endpoints/gemini.rb
@@ -9,7 +9,23 @@ module DiscourseAi
         end

         def default_options
-          {}
+          { generationConfig: {} }
+        end
+
+        def normalize_model_params(model_params)
+          model_params = model_params.dup
+
+          if model_params[:stop_sequences]
+            model_params[:stopSequences] = model_params.delete(:stop_sequences)
+          end
+
+          if model_params[:max_tokens]
+            model_params[:maxOutputTokens] = model_params.delete(:max_tokens)
+          end
+
+          # temperature already supported
+
+          model_params
         end

         def provider_id
@@ -27,9 +43,11 @@ module DiscourseAi

         def prepare_payload(prompt, model_params, dialect)
           default_options
-            .merge(model_params)
             .merge(contents: prompt)
-            .tap { |payload| payload[:tools] = dialect.tools if dialect.tools.present? }
+            .tap do |payload|
+              payload[:tools] = dialect.tools if dialect.tools.present?
+              payload[:generationConfig].merge!(model_params) if model_params.present?
+            end
         end

         def prepare_request(payload)
diff --git a/lib/completions/endpoints/hugging_face.rb b/lib/completions/endpoints/hugging_face.rb
index 22fd39f5..4a0f2875 100644
--- a/lib/completions/endpoints/hugging_face.rb
+++ b/lib/completions/endpoints/hugging_face.rb
@@ -19,6 +19,20 @@ module DiscourseAi
           { parameters: { repetition_penalty: 1.1, temperature: 0.7, return_full_text: false } }
         end

+        def normalize_model_params(model_params)
+          model_params = model_params.dup
+
+          if model_params[:stop_sequences]
+            model_params[:stop] = model_params.delete(:stop_sequences)
+          end
+
+          if model_params[:max_tokens]
+            model_params[:max_new_tokens] = model_params.delete(:max_tokens)
+          end
+
+          model_params
+        end
+
         def provider_id
           AiApiAuditLog::Provider::HuggingFaceTextGeneration
         end
diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb
index 2a1d29cb..8e760083 100644
--- a/lib/completions/endpoints/open_ai.rb
+++ b/lib/completions/endpoints/open_ai.rb
@@ -15,6 +15,17 @@ module DiscourseAi
           ].include?(model_name)
         end

+        def normalize_model_params(model_params)
+          model_params = model_params.dup
+
+          # max_tokens, temperature are already supported
+          if model_params[:stop_sequences]
+            model_params[:stop] = model_params.delete(:stop_sequences)
+          end
+
+          model_params
+        end
+
         def default_options
"gpt-4-1106-preview" : model } end diff --git a/lib/completions/endpoints/vllm.rb b/lib/completions/endpoints/vllm.rb index 48db69ed..71385e94 100644 --- a/lib/completions/endpoints/vllm.rb +++ b/lib/completions/endpoints/vllm.rb @@ -10,6 +10,17 @@ module DiscourseAi ) end + def normalize_model_params(model_params) + model_params = model_params.dup + + # max_tokens, temperature are already supported + if model_params[:stop_sequences] + model_params[:stop] = model_params.delete(:stop_sequences) + end + + model_params + end + def default_options { max_tokens: 2000, model: model } end @@ -39,7 +50,6 @@ module DiscourseAi def prepare_request(payload) headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" } - Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } end diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb index 16e1c96b..462f0b89 100644 --- a/lib/completions/llm.rb +++ b/lib/completions/llm.rb @@ -98,11 +98,24 @@ module DiscourseAi # # # - def completion!(generic_prompt, user, &partial_read_blk) - model_params = generic_prompt.dig(:params, model_name) || {} + def generate( + generic_prompt, + temperature: nil, + max_tokens: nil, + stop_sequences: nil, + user:, + &partial_read_blk + ) + model_params = { + temperature: temperature, + max_tokens: max_tokens, + stop_sequences: stop_sequences, + } + + model_params.merge!(generic_prompt.dig(:params, model_name) || {}) + model_params.keys.each { |key| model_params.delete(key) if model_params[key].nil? } dialect = dialect_klass.new(generic_prompt, model_name, opts: model_params) - gateway.perform_completion!(dialect, user, model_params, &partial_read_blk) end diff --git a/lib/embeddings/semantic_search.rb b/lib/embeddings/semantic_search.rb index 36f64051..9db8b825 100644 --- a/lib/embeddings/semantic_search.rb +++ b/lib/embeddings/semantic_search.rb @@ -112,7 +112,7 @@ module DiscourseAi llm_response = DiscourseAi::Completions::Llm.proxy( SiteSetting.ai_embeddings_semantic_search_hyde_model, - ).completion!(prompt, @guardian.user) + ).generate(prompt, user: @guardian.user) Nokogiri::HTML5.fragment(llm_response).at("ai")&.text&.presence || llm_response end diff --git a/lib/summarization/strategies/fold_content.rb b/lib/summarization/strategies/fold_content.rb index dedc4209..ce10afb0 100644 --- a/lib/summarization/strategies/fold_content.rb +++ b/lib/summarization/strategies/fold_content.rb @@ -99,7 +99,7 @@ module DiscourseAi def summarize_single(llm, text, user, opts, &on_partial_blk) prompt = summarization_prompt(text, opts) - llm.completion!(prompt, user, &on_partial_blk) + llm.generate(prompt, user: user, &on_partial_blk) end def summarize_in_chunks(llm, chunks, user, opts) @@ -107,7 +107,7 @@ module DiscourseAi prompt = summarization_prompt(chunk[:summary], opts) prompt[:post_insts] = "Don't use more than 400 words for the summary." - chunk[:summary] = llm.completion!(prompt, user) + chunk[:summary] = llm.generate(prompt, user: user) chunk end end @@ -117,7 +117,7 @@ module DiscourseAi prompt[:insts] = <<~TEXT You are a summarization bot that effectively concatenates disjoint summaries, creating a cohesive narrative. The narrative you create is in the form of one or multiple paragraphs. - Your reply MUST BE a single concatenated summary using the summaries I'll provide to you. + Your reply MUST BE a single concatenated summary using the summaries I'll provide to you. I'm NOT interested in anything other than the concatenated summary, don't include additional text or comments. 
            You understand and generate Discourse forum Markdown.
             You format the response, including links, using Markdown.
@@ -131,7 +131,7 @@ module DiscourseAi

           TEXT

-          llm.completion!(prompt, user, &on_partial_blk)
+          llm.generate(prompt, user: user, &on_partial_blk)
         end

         def summarization_prompt(input, opts)
diff --git a/spec/lib/completions/dialects/mixtral_spec.rb b/spec/lib/completions/dialects/mixtral_spec.rb
index e45ad950..4f1a5247 100644
--- a/spec/lib/completions/dialects/mixtral_spec.rb
+++ b/spec/lib/completions/dialects/mixtral_spec.rb
@@ -74,7 +74,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Mixtral do
       #{prompt[:post_insts]}
       [/INST] Ok
       [INST] #{prompt[:examples][0][0]} [/INST]
-      #{prompt[:examples][0][1]}
+      #{prompt[:examples][0][1]}</s>
       [INST] #{prompt[:input]} [/INST]
       TEXT

@@ -102,7 +102,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Mixtral do

       Here are the tools available:

-
+
       #{dialect.tools}

       #{prompt[:post_insts]}
diff --git a/spec/lib/completions/endpoints/open_ai_spec.rb b/spec/lib/completions/endpoints/open_ai_spec.rb
index 7caf1a1f..00695830 100644
--- a/spec/lib/completions/endpoints/open_ai_spec.rb
+++ b/spec/lib/completions/endpoints/open_ai_spec.rb
@@ -183,7 +183,7 @@ data: [D|ONE]

         partials = []
         llm = DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo")
-        llm.completion!({ insts: "test" }, Discourse.system_user) { |partial| partials << partial }
+        llm.generate({ insts: "test" }, user: Discourse.system_user) { |partial| partials << partial }

         expect(partials.join).to eq("test,test2,test3,test4")
       end
@@ -212,7 +212,7 @@ data: [D|ONE]

         partials = []
         llm = DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo")
-        llm.completion!({ insts: "test" }, Discourse.system_user) { |partial| partials << partial }
+        llm.generate({ insts: "test" }, user: Discourse.system_user) { |partial| partials << partial }

         expect(partials.join).to eq("test,test1,test2,test3,test4")
       end
diff --git a/spec/lib/completions/llm_spec.rb b/spec/lib/completions/llm_spec.rb
index 3a67ff22..41df7fc8 100644
--- a/spec/lib/completions/llm_spec.rb
+++ b/spec/lib/completions/llm_spec.rb
@@ -21,7 +21,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
     end
   end

-  describe "#completion!" do
+  describe "#generate" do
     let(:prompt) do
       {
         insts: <<~TEXT,
@@ -52,7 +52,7 @@ RSpec.describe DiscourseAi::Completions::Llm do

     context "when getting the full response" do
       it "processes the prompt and return the response" do
-        llm_response = llm.completion!(prompt, user)
+        llm_response = llm.generate(prompt, user: user)

         expect(llm_response).to eq(canned_response.responses[0])
       end
@@ -62,7 +62,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
       it "processes the prompt and call the given block with the partial response" do
         llm_response = +""

-        llm.completion!(prompt, user) { |partial, cancel_fn| llm_response << partial }
+        llm.generate(prompt, user: user) { |partial, cancel_fn| llm_response << partial }

         expect(llm_response).to eq(canned_response.responses[0])
       end
diff --git a/spec/requests/ai_helper/assistant_controller_spec.rb b/spec/requests/ai_helper/assistant_controller_spec.rb
index 319844a1..5e034caa 100644
--- a/spec/requests/ai_helper/assistant_controller_spec.rb
+++ b/spec/requests/ai_helper/assistant_controller_spec.rb
@@ -59,7 +59,7 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
     it "returns a generic error when the completion call fails" do
       DiscourseAi::Completions::Llm
         .any_instance
-        .expects(:completion!)
+ .expects(:generate) .raises(DiscourseAi::Completions::Endpoints::Base::CompletionFailed) post "/discourse-ai/ai-helper/suggest", params: { mode: mode, text: text_to_proofread }