diff --git a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb
index 559d1382..c3e08f4c 100644
--- a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb
+++ b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb
@@ -43,7 +43,7 @@ module DiscourseAi
),
status: 200
end
- rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed => e
+ rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed
render_json_error I18n.t("discourse_ai.ai_helper.errors.completion_request_failed"),
status: 502
end
@@ -63,7 +63,7 @@ module DiscourseAi
),
status: 200
end
- rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed => e
+ rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed
render_json_error I18n.t("discourse_ai.ai_helper.errors.completion_request_failed"),
status: 502
end
@@ -111,7 +111,7 @@ module DiscourseAi
)
render json: { success: true }, status: 200
- rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed => e
+ rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed
render_json_error I18n.t("discourse_ai.ai_helper.errors.completion_request_failed"),
status: 502
end
diff --git a/app/models/completion_prompt.rb b/app/models/completion_prompt.rb
index 12614278..183b7adb 100644
--- a/app/models/completion_prompt.rb
+++ b/app/models/completion_prompt.rb
@@ -67,6 +67,8 @@ end
# created_at :datetime not null
# updated_at :datetime not null
# messages :jsonb
+# temperature :float
+# stop_sequences :string is an Array
#
# Indexes
#
diff --git a/db/fixtures/ai_helper/603_completion_prompts.rb b/db/fixtures/ai_helper/603_completion_prompts.rb
index d08f78e5..d9f40c0c 100644
--- a/db/fixtures/ai_helper/603_completion_prompts.rb
+++ b/db/fixtures/ai_helper/603_completion_prompts.rb
@@ -5,47 +5,55 @@ CompletionPrompt.seed do |cp|
cp.id = -301
cp.name = "translate"
cp.prompt_type = CompletionPrompt.prompt_types[:text]
- cp.messages = { insts: <<~TEXT }
- I want you to act as an English translator, spelling corrector and improver. I will write to you
- in any language and you will detect the language, translate it and answer in the corrected and
- improved version of my text, in English. I want you to replace my simplified A0-level words and
- sentences with more beautiful and elegant, upper level English words and sentences.
- Keep the meaning same, but make them more literary. I want you to only reply the correction,
- the improvements and nothing else, do not write explanations.
- You will find the text between XML tags.
- TEXT
+  cp.stop_sequences = ["\n</output>", "</output>"]
+ cp.temperature = 0.2
+ cp.messages = {
+ insts: <<~TEXT,
+ I want you to act as an English translator, spelling corrector and improver. I will write to you
+ in any language and you will detect the language, translate it and answer in the corrected and
+ improved version of my text, in English. I want you to replace my simplified A0-level words and
+ sentences with more beautiful and elegant, upper level English words and sentences.
+ Keep the meaning same, but make them more literary. I want you to only reply the correction,
+ the improvements and nothing else, do not write explanations.
+      You will find the text between <input></input> XML tags.
+      Include your translation between <output></output> XML tags.
+ TEXT
+ examples: [
+ ["Hello world", ""],
+ ["Bonjour le monde", ""],
+ ],
+ }
end
CompletionPrompt.seed do |cp|
cp.id = -303
cp.name = "proofread"
cp.prompt_type = CompletionPrompt.prompt_types[:diff]
+ cp.temperature = 0
+  cp.stop_sequences = ["\n</output>"]
cp.messages = {
insts: <<~TEXT,
You are a markdown proofreader. You correct egregious typos and phrasing issues but keep the user's original voice.
You do not touch code blocks. I will provide you with text to proofread. If nothing needs fixing, then you will echo the text back.
- Optionally, a user can specify intensity. Intensity 10 is a pedantic English teacher correcting the text.
- Intensity 1 is a minimal proofreader. By default, you operate at intensity 1.
You will find the text between XML tags.
+      You will ALWAYS return the corrected text between <output></output> XML tags.
TEXT
examples: [
[
"![amazing car|100x100, 22%](upload://hapy.png)",
- "![Amazing car|100x100, 22%](upload://hapy.png)",
+ "",
],
[<<~TEXT, "The rain in Spain, stays mainly in the Plane."],
- Intensity 1:
The rain in spain stays mainly in the plane.
TEXT
[
- "The rain in Spain, stays mainly in the Plane.",
- "The rain in Spain, stays mainly in the Plane.",
+ "The rain in Spain, stays mainly in the Plane.",
+ "",
],
[<<~TEXT, <<~TEXT],
- Intensity 1:
Hello,
Sometimes the logo isn't changing automatically when color scheme changes.
@@ -53,13 +61,14 @@ CompletionPrompt.seed do |cp|
![Screen Recording 2023-03-17 at 18.04.22|video](upload://2rcVL0ZMxHPNtPWQbZjwufKpWVU.mov)
TEXT
+
TEXT
[<<~TEXT, <<~TEXT],
- Intensity 1:
Any ideas what is wrong with this peace of cod?
> This quot contains a typo
```ruby
@@ -69,6 +78,7 @@ CompletionPrompt.seed do |cp|
```
TEXT
+
TEXT
],
}
@@ -85,15 +96,19 @@ CompletionPrompt.seed do |cp|
cp.id = -304
cp.name = "markdown_table"
cp.prompt_type = CompletionPrompt.prompt_types[:diff]
+ cp.temperature = 0.5
+  cp.stop_sequences = ["\n</output>"]
cp.messages = {
insts: <<~TEXT,
      You are a markdown table formatter, I will provide you text inside <input></input> XML tags and you will format it into a markdown table
TEXT
examples: [
["sam,joe,jane\nage: 22| 10|11", <<~TEXT],
+
TEXT
[<<~TEXT, <<~TEXT],
@@ -102,22 +117,26 @@ CompletionPrompt.seed do |cp|
fred: height 22
TEXT
+
TEXT
[<<~TEXT, <<~TEXT],
- chrome 22ms (first load 10ms)
- firefox 10ms (first load: 9ms)
+ chrome 22ms (first load 10ms)
+ firefox 10ms (first load: 9ms)
TEXT
+
TEXT
],
}
diff --git a/db/migrate/20240104013944_add_params_to_completion_prompt.rb b/db/migrate/20240104013944_add_params_to_completion_prompt.rb
new file mode 100644
index 00000000..7d179e13
--- /dev/null
+++ b/db/migrate/20240104013944_add_params_to_completion_prompt.rb
@@ -0,0 +1,8 @@
+# frozen_string_literal: true
+
+class AddParamsToCompletionPrompt < ActiveRecord::Migration[7.0]
+ def change
+    add_column :completion_prompts, :temperature, :float
+ add_column :completion_prompts, :stop_sequences, :string, array: true
+ end
+end
diff --git a/lib/ai_helper/assistant.rb b/lib/ai_helper/assistant.rb
index e9cbc96c..038ef4bd 100644
--- a/lib/ai_helper/assistant.rb
+++ b/lib/ai_helper/assistant.rb
@@ -36,20 +36,26 @@ module DiscourseAi
llm = DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_model)
generic_prompt = completion_prompt.messages_with_input(input)
- llm.completion!(generic_prompt, user, &block)
+ llm.generate(
+ generic_prompt,
+ user: user,
+ temperature: completion_prompt.temperature,
+ stop_sequences: completion_prompt.stop_sequences,
+ &block
+ )
end
def generate_and_send_prompt(completion_prompt, input, user)
completion_result = generate_prompt(completion_prompt, input, user)
result = { type: completion_prompt.prompt_type }
- result[:diff] = parse_diff(input, completion_result) if completion_prompt.diff?
-
result[:suggestions] = (
if completion_prompt.list?
parse_list(completion_result).map { |suggestion| sanitize_result(suggestion) }
else
- [sanitize_result(completion_result)]
+ sanitized = sanitize_result(completion_result)
+ result[:diff] = parse_diff(input, sanitized) if completion_prompt.diff?
+ [sanitized]
end
)
@@ -79,25 +85,15 @@ module DiscourseAi
private
- def sanitize_result(result)
- tags_to_remove = %w[
-
-
-
-
-
-
-
-
-
-
-
-
-
- ]
+ SANITIZE_REGEX_STR =
+ %w[term context topic replyTo input output result]
+        .map { |tag| "<#{tag}>\\n?|\\n?</#{tag}>" }
+ .join("|")
- result.dup.tap { |dup_result| tags_to_remove.each { |tag| dup_result.gsub!(tag, "") } }
+ SANITIZE_REGEX = Regexp.new(SANITIZE_REGEX_STR, Regexp::IGNORECASE | Regexp::MULTILINE)
+
+ def sanitize_result(result)
+ result.gsub(SANITIZE_REGEX, "")
end
def publish_update(channel, payload, user)
diff --git a/lib/ai_helper/painter.rb b/lib/ai_helper/painter.rb
index 0b1041dd..ec8358ee 100644
--- a/lib/ai_helper/painter.rb
+++ b/lib/ai_helper/painter.rb
@@ -38,7 +38,10 @@ module DiscourseAi
You'll find the post between XML tags.
TEXT
- DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_model).completion!(prompt, user)
+ DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_model).generate(
+ prompt,
+ user: user,
+ )
end
end
end
diff --git a/lib/automation/llm_triage.rb b/lib/automation/llm_triage.rb
index a1393bc4..ad83dbe6 100644
--- a/lib/automation/llm_triage.rb
+++ b/lib/automation/llm_triage.rb
@@ -31,24 +31,16 @@ module DiscourseAi
result = nil
llm = DiscourseAi::Completions::Llm.proxy(model)
- key =
- if model.include?("claude")
- :max_tokens_to_sample
- else
- :max_tokens
- end
- prompt = {
- insts: filled_system_prompt,
- params: {
- model => {
- key => (llm.tokenizer.tokenize(search_for_text).length * 2 + 10),
- :temperature => 0,
- },
- },
- }
+ prompt = { insts: filled_system_prompt }
- result = llm.completion!(prompt, Discourse.system_user)
+ result =
+ llm.generate(
+ prompt,
+ temperature: 0,
+ max_tokens: llm.tokenizer.tokenize(search_for_text).length * 2 + 10,
+ user: Discourse.system_user,
+ )
if result.strip == search_for_text.strip
user = User.find_by_username(canned_reply_user) if canned_reply_user.present?
diff --git a/lib/automation/report_runner.rb b/lib/automation/report_runner.rb
index 08a4548e..c942c3d5 100644
--- a/lib/automation/report_runner.rb
+++ b/lib/automation/report_runner.rb
@@ -115,18 +115,13 @@ module DiscourseAi
insts: "You are a helpful bot specializing in summarizing activity on Discourse sites",
input: input,
final_insts: "Here is the report I generated for you",
- params: {
- @model => {
- temperature: 0,
- },
- },
}
result = +""
puts if Rails.env.development? && @debug_mode
- @llm.completion!(prompt, Discourse.system_user) do |response|
+ @llm.generate(prompt, temperature: 0, user: Discourse.system_user) do |response|
print response if Rails.env.development? && @debug_mode
result << response
end
diff --git a/lib/completions/dialects/chat_gpt.rb b/lib/completions/dialects/chat_gpt.rb
index 0c9676a8..033f29ab 100644
--- a/lib/completions/dialects/chat_gpt.rb
+++ b/lib/completions/dialects/chat_gpt.rb
@@ -95,7 +95,7 @@ module DiscourseAi
def max_prompt_tokens
# provide a buffer of 120 tokens - our function counting is not
# 100% accurate and getting numbers to align exactly is very hard
- buffer = (opts[:max_tokens_to_sample] || 2500) + 50
+ buffer = (opts[:max_tokens] || 2500) + 50
if tools.present?
# note this is about 100 tokens over, OpenAI have a more optimal representation
diff --git a/lib/completions/dialects/mixtral.rb b/lib/completions/dialects/mixtral.rb
index 75e0f954..464a1ac4 100644
--- a/lib/completions/dialects/mixtral.rb
+++ b/lib/completions/dialects/mixtral.rb
@@ -27,7 +27,7 @@ module DiscourseAi
if prompt[:examples]
prompt[:examples].each do |example_pair|
mixtral_prompt << "[INST] #{example_pair.first} [/INST]\n"
- mixtral_prompt << "#{example_pair.second}\n"
+          mixtral_prompt << "#{example_pair.second}</s>\n"
end
end
diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb
index 3846d7e4..d98990d8 100644
--- a/lib/completions/endpoints/anthropic.rb
+++ b/lib/completions/endpoints/anthropic.rb
@@ -8,8 +8,24 @@ module DiscourseAi
%w[claude-instant-1 claude-2].include?(model_name)
end
+ def normalize_model_params(model_params)
+ model_params = model_params.dup
+
+ # temperature, stop_sequences are already supported
+ #
+ if model_params[:max_tokens]
+ model_params[:max_tokens_to_sample] = model_params.delete(:max_tokens)
+ end
+
+ model_params
+ end
+
def default_options
- { max_tokens_to_sample: 2000, model: model }
+ {
+ model: model,
+ max_tokens_to_sample: 2_000,
+        stop_sequences: ["\n\nHuman:", "</function_calls>"],
+ }
end
def provider_id
diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb
index 98f29634..5d559b75 100644
--- a/lib/completions/endpoints/aws_bedrock.rb
+++ b/lib/completions/endpoints/aws_bedrock.rb
@@ -13,8 +13,24 @@ module DiscourseAi
SiteSetting.ai_bedrock_region.present?
end
+ def normalize_model_params(model_params)
+ model_params = model_params.dup
+
+ # temperature, stop_sequences are already supported
+ #
+ if model_params[:max_tokens]
+ model_params[:max_tokens_to_sample] = model_params.delete(:max_tokens)
+ end
+
+ model_params
+ end
+
def default_options
-      { max_tokens_to_sample: 2_000, stop_sequences: ["\n\nHuman:", "</function_calls>"] }
+ {
+ model: model,
+ max_tokens_to_sample: 2_000,
+        stop_sequences: ["\n\nHuman:", "</function_calls>"],
+ }
end
def provider_id
diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index da433c18..a9768e47 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -32,6 +32,8 @@ module DiscourseAi
end
def perform_completion!(dialect, user, model_params = {})
+ model_params = normalize_model_params(model_params)
+
@streaming_mode = block_given?
prompt = dialect.translate
@@ -199,6 +201,11 @@ module DiscourseAi
attr_reader :model
+ # should normalize temperature, max_tokens, stop_words to endpoint specific values
+ def normalize_model_params(model_params)
+ raise NotImplementedError
+ end
+
def model_uri
raise NotImplementedError
end
@@ -262,7 +269,7 @@ module DiscourseAi
function_buffer.at("tool_id").inner_html = tool_name
end
- read_parameters =
+ _read_parameters =
read_function
.at("parameters")
.elements
diff --git a/lib/completions/endpoints/canned_response.rb b/lib/completions/endpoints/canned_response.rb
index 56d2b913..ab04961e 100644
--- a/lib/completions/endpoints/canned_response.rb
+++ b/lib/completions/endpoints/canned_response.rb
@@ -16,6 +16,11 @@ module DiscourseAi
@prompt = nil
end
+ def normalize_model_params(model_params)
+ # max_tokens, temperature, stop_sequences are already supported
+ model_params
+ end
+
attr_reader :responses, :completions, :prompt
def perform_completion!(prompt, _user, _model_params)
diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb
index 231309b2..9a1e3711 100644
--- a/lib/completions/endpoints/gemini.rb
+++ b/lib/completions/endpoints/gemini.rb
@@ -9,7 +9,23 @@ module DiscourseAi
end
def default_options
- {}
+ { generationConfig: {} }
+ end
+
+ def normalize_model_params(model_params)
+ model_params = model_params.dup
+
+ if model_params[:stop_sequences]
+ model_params[:stopSequences] = model_params.delete(:stop_sequences)
+ end
+
+      if model_params[:max_tokens]
+ model_params[:maxOutputTokens] = model_params.delete(:max_tokens)
+ end
+
+ # temperature already supported
+
+ model_params
end
def provider_id
@@ -27,9 +43,11 @@ module DiscourseAi
def prepare_payload(prompt, model_params, dialect)
default_options
- .merge(model_params)
.merge(contents: prompt)
- .tap { |payload| payload[:tools] = dialect.tools if dialect.tools.present? }
+ .tap do |payload|
+ payload[:tools] = dialect.tools if dialect.tools.present?
+ payload[:generationConfig].merge!(model_params) if model_params.present?
+ end
end
def prepare_request(payload)
diff --git a/lib/completions/endpoints/hugging_face.rb b/lib/completions/endpoints/hugging_face.rb
index 22fd39f5..4a0f2875 100644
--- a/lib/completions/endpoints/hugging_face.rb
+++ b/lib/completions/endpoints/hugging_face.rb
@@ -19,6 +19,20 @@ module DiscourseAi
{ parameters: { repetition_penalty: 1.1, temperature: 0.7, return_full_text: false } }
end
+ def normalize_model_params(model_params)
+ model_params = model_params.dup
+
+ if model_params[:stop_sequences]
+ model_params[:stop] = model_params.delete(:stop_sequences)
+ end
+
+ if model_params[:max_tokens]
+ model_params[:max_new_tokens] = model_params.delete(:max_tokens)
+ end
+
+ model_params
+ end
+
def provider_id
AiApiAuditLog::Provider::HuggingFaceTextGeneration
end
diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb
index 2a1d29cb..8e760083 100644
--- a/lib/completions/endpoints/open_ai.rb
+++ b/lib/completions/endpoints/open_ai.rb
@@ -15,6 +15,17 @@ module DiscourseAi
].include?(model_name)
end
+ def normalize_model_params(model_params)
+ model_params = model_params.dup
+
+ # max_tokens, temperature are already supported
+ if model_params[:stop_sequences]
+ model_params[:stop] = model_params.delete(:stop_sequences)
+ end
+
+ model_params
+ end
+
def default_options
{ model: model == "gpt-4-turbo" ? "gpt-4-1106-preview" : model }
end
diff --git a/lib/completions/endpoints/vllm.rb b/lib/completions/endpoints/vllm.rb
index 48db69ed..71385e94 100644
--- a/lib/completions/endpoints/vllm.rb
+++ b/lib/completions/endpoints/vllm.rb
@@ -10,6 +10,17 @@ module DiscourseAi
)
end
+ def normalize_model_params(model_params)
+ model_params = model_params.dup
+
+ # max_tokens, temperature are already supported
+ if model_params[:stop_sequences]
+ model_params[:stop] = model_params.delete(:stop_sequences)
+ end
+
+ model_params
+ end
+
def default_options
{ max_tokens: 2000, model: model }
end
@@ -39,7 +50,6 @@ module DiscourseAi
def prepare_request(payload)
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
-
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end
diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb
index 16e1c96b..462f0b89 100644
--- a/lib/completions/llm.rb
+++ b/lib/completions/llm.rb
@@ -98,11 +98,24 @@ module DiscourseAi
#
#
#
- def completion!(generic_prompt, user, &partial_read_blk)
- model_params = generic_prompt.dig(:params, model_name) || {}
+ def generate(
+ generic_prompt,
+ temperature: nil,
+ max_tokens: nil,
+ stop_sequences: nil,
+ user:,
+ &partial_read_blk
+ )
+ model_params = {
+ temperature: temperature,
+ max_tokens: max_tokens,
+ stop_sequences: stop_sequences,
+ }
+
+ model_params.merge!(generic_prompt.dig(:params, model_name) || {})
+ model_params.keys.each { |key| model_params.delete(key) if model_params[key].nil? }
dialect = dialect_klass.new(generic_prompt, model_name, opts: model_params)
-
gateway.perform_completion!(dialect, user, model_params, &partial_read_blk)
end
diff --git a/lib/embeddings/semantic_search.rb b/lib/embeddings/semantic_search.rb
index 36f64051..9db8b825 100644
--- a/lib/embeddings/semantic_search.rb
+++ b/lib/embeddings/semantic_search.rb
@@ -112,7 +112,7 @@ module DiscourseAi
llm_response =
DiscourseAi::Completions::Llm.proxy(
SiteSetting.ai_embeddings_semantic_search_hyde_model,
- ).completion!(prompt, @guardian.user)
+ ).generate(prompt, user: @guardian.user)
Nokogiri::HTML5.fragment(llm_response).at("ai")&.text&.presence || llm_response
end
diff --git a/lib/summarization/strategies/fold_content.rb b/lib/summarization/strategies/fold_content.rb
index dedc4209..ce10afb0 100644
--- a/lib/summarization/strategies/fold_content.rb
+++ b/lib/summarization/strategies/fold_content.rb
@@ -99,7 +99,7 @@ module DiscourseAi
def summarize_single(llm, text, user, opts, &on_partial_blk)
prompt = summarization_prompt(text, opts)
- llm.completion!(prompt, user, &on_partial_blk)
+ llm.generate(prompt, user: user, &on_partial_blk)
end
def summarize_in_chunks(llm, chunks, user, opts)
@@ -107,7 +107,7 @@ module DiscourseAi
prompt = summarization_prompt(chunk[:summary], opts)
prompt[:post_insts] = "Don't use more than 400 words for the summary."
- chunk[:summary] = llm.completion!(prompt, user)
+ chunk[:summary] = llm.generate(prompt, user: user)
chunk
end
end
@@ -117,7 +117,7 @@ module DiscourseAi
prompt[:insts] = <<~TEXT
You are a summarization bot that effectively concatenates disjoint summaries, creating a cohesive narrative.
The narrative you create is in the form of one or multiple paragraphs.
- Your reply MUST BE a single concatenated summary using the summaries I'll provide to you.
+ Your reply MUST BE a single concatenated summary using the summaries I'll provide to you.
I'm NOT interested in anything other than the concatenated summary, don't include additional text or comments.
You understand and generate Discourse forum Markdown.
You format the response, including links, using Markdown.
@@ -131,7 +131,7 @@ module DiscourseAi
TEXT
- llm.completion!(prompt, user, &on_partial_blk)
+ llm.generate(prompt, user: user, &on_partial_blk)
end
def summarization_prompt(input, opts)
diff --git a/spec/lib/completions/dialects/mixtral_spec.rb b/spec/lib/completions/dialects/mixtral_spec.rb
index e45ad950..4f1a5247 100644
--- a/spec/lib/completions/dialects/mixtral_spec.rb
+++ b/spec/lib/completions/dialects/mixtral_spec.rb
@@ -74,7 +74,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Mixtral do
#{prompt[:post_insts]}
[/INST] Ok
[INST] #{prompt[:examples][0][0]} [/INST]
- #{prompt[:examples][0][1]}
+      #{prompt[:examples][0][1]}</s>
[INST] #{prompt[:input]} [/INST]
TEXT
@@ -102,7 +102,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Mixtral do
Here are the tools available:
-
+
#{dialect.tools}
#{prompt[:post_insts]}
diff --git a/spec/lib/completions/endpoints/open_ai_spec.rb b/spec/lib/completions/endpoints/open_ai_spec.rb
index 7caf1a1f..00695830 100644
--- a/spec/lib/completions/endpoints/open_ai_spec.rb
+++ b/spec/lib/completions/endpoints/open_ai_spec.rb
@@ -183,7 +183,7 @@ data: [D|ONE]
partials = []
llm = DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo")
- llm.completion!({ insts: "test" }, Discourse.system_user) { |partial| partials << partial }
+ llm.generate({ insts: "test" }, user: Discourse.system_user) { |partial| partials << partial }
expect(partials.join).to eq("test,test2,test3,test4")
end
@@ -212,7 +212,7 @@ data: [D|ONE]
partials = []
llm = DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo")
- llm.completion!({ insts: "test" }, Discourse.system_user) { |partial| partials << partial }
+ llm.generate({ insts: "test" }, user: Discourse.system_user) { |partial| partials << partial }
expect(partials.join).to eq("test,test1,test2,test3,test4")
end
diff --git a/spec/lib/completions/llm_spec.rb b/spec/lib/completions/llm_spec.rb
index 3a67ff22..41df7fc8 100644
--- a/spec/lib/completions/llm_spec.rb
+++ b/spec/lib/completions/llm_spec.rb
@@ -21,7 +21,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
end
end
- describe "#completion!" do
+ describe "#generate" do
let(:prompt) do
{
insts: <<~TEXT,
@@ -52,7 +52,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
context "when getting the full response" do
it "processes the prompt and return the response" do
- llm_response = llm.completion!(prompt, user)
+ llm_response = llm.generate(prompt, user: user)
expect(llm_response).to eq(canned_response.responses[0])
end
@@ -62,7 +62,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
it "processes the prompt and call the given block with the partial response" do
llm_response = +""
- llm.completion!(prompt, user) { |partial, cancel_fn| llm_response << partial }
+ llm.generate(prompt, user: user) { |partial, cancel_fn| llm_response << partial }
expect(llm_response).to eq(canned_response.responses[0])
end
diff --git a/spec/requests/ai_helper/assistant_controller_spec.rb b/spec/requests/ai_helper/assistant_controller_spec.rb
index 319844a1..5e034caa 100644
--- a/spec/requests/ai_helper/assistant_controller_spec.rb
+++ b/spec/requests/ai_helper/assistant_controller_spec.rb
@@ -59,7 +59,7 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
it "returns a generic error when the completion call fails" do
DiscourseAi::Completions::Llm
.any_instance
- .expects(:completion!)
+ .expects(:generate)
.raises(DiscourseAi::Completions::Endpoints::Base::CompletionFailed)
post "/discourse-ai/ai-helper/suggest", params: { mode: mode, text: text_to_proofread }