FIX: AI helper not working correctly with mixtral (#399)

* FIX: AI helper not working correctly with mixtral

This PR introduces a new #generate method on the generic LLM interface.

It replaces the existing completion! implementation.

#generate introduces a uniform way to pass temperature, max_tokens and stop_sequences.
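As a rough illustration of how those keyword arguments become a params hash: per-model overrides from the prompt win, and nil entries are stripped so endpoints only see parameters that were actually set. Note build_model_params is an illustrative standalone helper, not actual PR code; the real logic lives inside Llm#generate.

```ruby
# Illustrative sketch of #generate's param assembly (not the PR's exact code).
# Overrides from the prompt's per-model params take precedence, and nil
# entries are removed before the hash is handed to the endpoint.
def build_model_params(temperature: nil, max_tokens: nil, stop_sequences: nil, overrides: {})
  params = { temperature: temperature, max_tokens: max_tokens, stop_sequences: stop_sequences }
  params.merge!(overrides)
  params.keys.each { |key| params.delete(key) if params[key].nil? }
  params
end
```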

LLM implementers then need to implement #normalize_model_params to map these
generic names to the parameters their specific endpoint expects.
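For example, Anthropic-style endpoints rename max_tokens. A minimal standalone sketch of that mapping, mirroring the normalize_model_params this PR adds to the Anthropic endpoint (extracted here out of its class for illustration):

```ruby
# Sketch of the Anthropic-style mapping: the generic :max_tokens becomes the
# provider-specific :max_tokens_to_sample; :temperature and :stop_sequences
# are already understood by the endpoint and pass through unchanged.
def normalize_model_params(model_params)
  model_params = model_params.dup

  if model_params[:max_tokens]
    model_params[:max_tokens_to_sample] = model_params.delete(:max_tokens)
  end

  model_params
end
```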

This also adds temperature and stop_sequences to completion_prompts, which
allows for much more robust completion prompts.

* port everything over to #generate

* Fix translation

- On Anthropic, responses no longer include stray "This is your translation:" preambles
- On Mixtral, translation now actually works

* fix markdown table generation as well
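The Mixtral fixes above come down to terminating each assistant example turn with the </s> end-of-sequence token in the dialect. A standalone sketch of the example formatting (format_examples is an illustrative name, not from the PR):

```ruby
# Illustrative sketch of the Mixtral dialect's example formatting after this
# fix: each user turn is wrapped in [INST]...[/INST] and each assistant turn
# is closed with </s> so the model treats it as a finished response.
def format_examples(examples)
  examples.reduce(+"") do |acc, (user_turn, assistant_turn)|
    acc << "[INST] #{user_turn} [/INST]\n"
    acc << "#{assistant_turn}</s>\n"
  end
end
```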
Sam 2024-01-04 23:53:47 +11:00, committed by GitHub
commit 03fc94684b (parent 0483e0bb88)
25 changed files with 217 additions and 92 deletions

View File

@@ -43,7 +43,7 @@ module DiscourseAi
             ),
           status: 200
         end
-      rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed => e
+      rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed
         render_json_error I18n.t("discourse_ai.ai_helper.errors.completion_request_failed"),
                           status: 502
       end
@@ -63,7 +63,7 @@ module DiscourseAi
             ),
           status: 200
         end
-      rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed => e
+      rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed
         render_json_error I18n.t("discourse_ai.ai_helper.errors.completion_request_failed"),
                           status: 502
       end
@@ -111,7 +111,7 @@ module DiscourseAi
         )
         render json: { success: true }, status: 200
-      rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed => e
+      rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed
         render_json_error I18n.t("discourse_ai.ai_helper.errors.completion_request_failed"),
                           status: 502
       end

View File

@@ -67,6 +67,8 @@ end
 #  created_at     :datetime         not null
 #  updated_at     :datetime         not null
 #  messages       :jsonb
+#  temperature    :integer
+#  stop_sequences :string           is an Array
 #
 # Indexes
 #

View File

@@ -5,47 +5,55 @@ CompletionPrompt.seed do |cp|
   cp.id = -301
   cp.name = "translate"
   cp.prompt_type = CompletionPrompt.prompt_types[:text]
-  cp.messages = { insts: <<~TEXT }
-    I want you to act as an English translator, spelling corrector and improver. I will write to you
-    in any language and you will detect the language, translate it and answer in the corrected and
-    improved version of my text, in English. I want you to replace my simplified A0-level words and
-    sentences with more beautiful and elegant, upper level English words and sentences.
-    Keep the meaning same, but make them more literary. I want you to only reply the correction,
-    the improvements and nothing else, do not write explanations.
-    You will find the text between <input></input> XML tags.
-  TEXT
+  cp.stop_sequences = ["\n</output>", "</output>"]
+  cp.temperature = 0.2
+  cp.messages = {
+    insts: <<~TEXT,
+      I want you to act as an English translator, spelling corrector and improver. I will write to you
+      in any language and you will detect the language, translate it and answer in the corrected and
+      improved version of my text, in English. I want you to replace my simplified A0-level words and
+      sentences with more beautiful and elegant, upper level English words and sentences.
+      Keep the meaning same, but make them more literary. I want you to only reply the correction,
+      the improvements and nothing else, do not write explanations.
+      You will find the text between <input></input> XML tags.
+      Include your translation between <output></output> XML tags.
+    TEXT
+    examples: [
+      ["<input>Hello world</input>", "<output>Hello world</output>"],
+      ["<input>Bonjour le monde</input>", "<output>Hello world</output>"],
+    ],
+  }
 end

 CompletionPrompt.seed do |cp|
   cp.id = -303
   cp.name = "proofread"
   cp.prompt_type = CompletionPrompt.prompt_types[:diff]
+  cp.temperature = 0
+  cp.stop_sequences = ["\n</output>"]
   cp.messages = {
     insts: <<~TEXT,
       You are a markdown proofreader. You correct egregious typos and phrasing issues but keep the user's original voice.
       You do not touch code blocks. I will provide you with text to proofread. If nothing needs fixing, then you will echo the text back.
-      Optionally, a user can specify intensity. Intensity 10 is a pedantic English teacher correcting the text.
-      Intensity 1 is a minimal proofreader. By default, you operate at intensity 1.
       You will find the text between <input></input> XML tags.
+      You will ALWAYS return the corrected text between <output></output> XML tags.
     TEXT
     examples: [
       [
         "<input>![amazing car|100x100, 22%](upload://hapy.png)</input>",
-        "![Amazing car|100x100, 22%](upload://hapy.png)",
+        "<output>![Amazing car|100x100, 22%](upload://hapy.png)</output>",
       ],
       [<<~TEXT, "The rain in Spain, stays mainly in the Plane."],
         <input>
-        Intensity 1:
         The rain in spain stays mainly in the plane.
         </input>
       TEXT
       [
-        "The rain in Spain, stays mainly in the Plane.",
-        "The rain in Spain, stays mainly in the Plane.",
+        "<input>The rain in Spain, stays mainly in the Plane.</input>",
+        "<output>The rain in Spain, stays mainly in the Plane.</output>",
       ],
       [<<~TEXT, <<~TEXT],
         <input>
-        Intensity 1:
         Hello,
         Sometimes the logo isn't changing automatically when color scheme changes.
@@ -53,13 +61,14 @@ CompletionPrompt.seed do |cp|
         ![Screen Recording 2023-03-17 at 18.04.22|video](upload://2rcVL0ZMxHPNtPWQbZjwufKpWVU.mov)
         </input>
       TEXT
+        <output>
         Hello,
         Sometimes the logo does not change automatically when the color scheme changes.
         ![Screen Recording 2023-03-17 at 18.04.22|video](upload://2rcVL0ZMxHPNtPWQbZjwufKpWVU.mov)
+        </output>
       TEXT
       [<<~TEXT, <<~TEXT],
         <input>
-        Intensity 1:
         Any ideas what is wrong with this peace of cod?
         > This quot contains a typo
         ```ruby
@@ -69,6 +78,7 @@ CompletionPrompt.seed do |cp|
         ```
         </input>
       TEXT
+        <output>
         Any ideas what is wrong with this piece of code?
         > This quot contains a typo
         ```ruby
@@ -76,6 +86,7 @@ CompletionPrompt.seed do |cp|
         testing.a_typo = 11
         bad = "bad"
         ```
+        </output>
       TEXT
     ],
   }
@@ -85,15 +96,19 @@ CompletionPrompt.seed do |cp|
   cp.id = -304
   cp.name = "markdown_table"
   cp.prompt_type = CompletionPrompt.prompt_types[:diff]
+  cp.temperature = 0.5
+  cp.stop_sequences = ["\n</output>"]
   cp.messages = {
     insts: <<~TEXT,
       You are a markdown table formatter, I will provide you text inside <input></input> XML tags and you will format it into a markdown table
     TEXT
     examples: [
       ["<input>sam,joe,jane\nage: 22| 10|11</input>", <<~TEXT],
+        <output>
         | | sam | joe | jane |
         |---|---|---|---|
         | age | 22 | 10 | 11 |
+        </output>
       TEXT
       [<<~TEXT, <<~TEXT],
         <input>
@@ -102,22 +117,26 @@ CompletionPrompt.seed do |cp|
         fred: height 22
         </input>
       TEXT
+        <output>
         | | speed | age | height |
         |---|---|---|---|
         | sam | 100 | 22 | - |
         | jane | - | 10 | - |
         | fred | - | - | 22 |
+        </output>
       TEXT
       [<<~TEXT, <<~TEXT],
         <input>
         chrome 22ms (first load 10ms)
         firefox 10ms (first load: 9ms)
         </input>
       TEXT
+        <output>
         | Browser | Load Time (ms) | First Load Time (ms) |
         |---|---|---|
         | Chrome | 22 | 10 |
         | Firefox | 10 | 9 |
+        </output>
       TEXT
     ],
   }

View File

@@ -0,0 +1,8 @@
+# frozen_string_literal: true
+
+class AddParamsToCompletionPrompt < ActiveRecord::Migration[7.0]
+  def change
+    add_column :completion_prompts, :temperature, :integer
+    add_column :completion_prompts, :stop_sequences, :string, array: true
+  end
+end

View File

@@ -36,20 +36,26 @@ module DiscourseAi
         llm = DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_model)
         generic_prompt = completion_prompt.messages_with_input(input)

-        llm.completion!(generic_prompt, user, &block)
+        llm.generate(
+          generic_prompt,
+          user: user,
+          temperature: completion_prompt.temperature,
+          stop_sequences: completion_prompt.stop_sequences,
+          &block
+        )
       end

       def generate_and_send_prompt(completion_prompt, input, user)
         completion_result = generate_prompt(completion_prompt, input, user)
         result = { type: completion_prompt.prompt_type }

-        result[:diff] = parse_diff(input, completion_result) if completion_prompt.diff?
-
         result[:suggestions] = (
           if completion_prompt.list?
             parse_list(completion_result).map { |suggestion| sanitize_result(suggestion) }
           else
-            [sanitize_result(completion_result)]
+            sanitized = sanitize_result(completion_result)
+            result[:diff] = parse_diff(input, sanitized) if completion_prompt.diff?
+            [sanitized]
           end
         )
@@ -79,25 +85,15 @@ module DiscourseAi
       private

-      def sanitize_result(result)
-        tags_to_remove = %w[
-          <term>
-          </term>
-          <context>
-          </context>
-          <topic>
-          </topic>
-          <replyTo>
-          </replyTo>
-          <input>
-          </input>
-          <output>
-          </output>
-          <result>
-          </result>
-        ]
-
-        result.dup.tap { |dup_result| tags_to_remove.each { |tag| dup_result.gsub!(tag, "") } }
-      end
+      SANITIZE_REGEX_STR =
+        %w[term context topic replyTo input output result]
+          .map { |tag| "<#{tag}>\\n?|\\n?</#{tag}>" }
+          .join("|")
+
+      SANITIZE_REGEX = Regexp.new(SANITIZE_REGEX_STR, Regexp::IGNORECASE | Regexp::MULTILINE)
+
+      def sanitize_result(result)
+        result.gsub(SANITIZE_REGEX, "")
+      end

       def publish_update(channel, payload, user)

View File

@@ -38,7 +38,10 @@ module DiscourseAi
           You'll find the post between <input></input> XML tags.
         TEXT

-        DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_model).completion!(prompt, user)
+        DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_model).generate(
+          prompt,
+          user: user,
+        )
       end
     end
   end

View File

@@ -31,24 +31,16 @@ module DiscourseAi
       result = nil
       llm = DiscourseAi::Completions::Llm.proxy(model)

-      key =
-        if model.include?("claude")
-          :max_tokens_to_sample
-        else
-          :max_tokens
-        end
-
-      prompt = {
-        insts: filled_system_prompt,
-        params: {
-          model => {
-            key => (llm.tokenizer.tokenize(search_for_text).length * 2 + 10),
-            :temperature => 0,
-          },
-        },
-      }
+      prompt = { insts: filled_system_prompt }

-      result = llm.completion!(prompt, Discourse.system_user)
+      result =
+        llm.generate(
+          prompt,
+          temperature: 0,
+          max_tokens: llm.tokenizer.tokenize(search_for_text).length * 2 + 10,
+          user: Discourse.system_user,
+        )

       if result.strip == search_for_text.strip
         user = User.find_by_username(canned_reply_user) if canned_reply_user.present?

View File

@@ -115,18 +115,13 @@ module DiscourseAi
         insts: "You are a helpful bot specializing in summarizing activity on Discourse sites",
         input: input,
         final_insts: "Here is the report I generated for you",
-        params: {
-          @model => {
-            temperature: 0,
-          },
-        },
       }

       result = +""

       puts if Rails.env.development? && @debug_mode

-      @llm.completion!(prompt, Discourse.system_user) do |response|
+      @llm.generate(prompt, temperature: 0, user: Discourse.system_user) do |response|
         print response if Rails.env.development? && @debug_mode
         result << response
       end

View File

@@ -95,7 +95,7 @@ module DiscourseAi
       def max_prompt_tokens
         # provide a buffer of 120 tokens - our function counting is not
         # 100% accurate and getting numbers to align exactly is very hard
-        buffer = (opts[:max_tokens_to_sample] || 2500) + 50
+        buffer = (opts[:max_tokens] || 2500) + 50

         if tools.present?
           # note this is about 100 tokens over, OpenAI have a more optimal representation

View File

@@ -27,7 +27,7 @@ module DiscourseAi
         if prompt[:examples]
           prompt[:examples].each do |example_pair|
             mixtral_prompt << "[INST] #{example_pair.first} [/INST]\n"
-            mixtral_prompt << "#{example_pair.second}\n"
+            mixtral_prompt << "#{example_pair.second}</s>\n"
           end
         end

View File

@@ -8,8 +8,24 @@ module DiscourseAi
           %w[claude-instant-1 claude-2].include?(model_name)
         end

+        def normalize_model_params(model_params)
+          model_params = model_params.dup
+
+          # temperature, stop_sequences are already supported
+          if model_params[:max_tokens]
+            model_params[:max_tokens_to_sample] = model_params.delete(:max_tokens)
+          end
+
+          model_params
+        end
+
         def default_options
-          { max_tokens_to_sample: 2000, model: model }
+          {
+            model: model,
+            max_tokens_to_sample: 2_000,
+            stop_sequences: ["\n\nHuman:", "</function_calls>"],
+          }
         end

         def provider_id

View File

@@ -13,8 +13,24 @@ module DiscourseAi
           SiteSetting.ai_bedrock_region.present?
         end

+        def normalize_model_params(model_params)
+          model_params = model_params.dup
+
+          # temperature, stop_sequences are already supported
+          if model_params[:max_tokens]
+            model_params[:max_tokens_to_sample] = model_params.delete(:max_tokens)
+          end
+
+          model_params
+        end
+
         def default_options
-          { max_tokens_to_sample: 2_000, stop_sequences: ["\n\nHuman:", "</function_calls>"] }
+          {
+            model: model,
+            max_tokens_to_sample: 2_000,
+            stop_sequences: ["\n\nHuman:", "</function_calls>"],
+          }
         end

         def provider_id

View File

@@ -32,6 +32,8 @@ module DiscourseAi
         end

         def perform_completion!(dialect, user, model_params = {})
+          model_params = normalize_model_params(model_params)
+
           @streaming_mode = block_given?

           prompt = dialect.translate
@@ -199,6 +201,11 @@ module DiscourseAi
         attr_reader :model

+        # should normalize temperature, max_tokens, stop_words to endpoint specific values
+        def normalize_model_params(model_params)
+          raise NotImplementedError
+        end
+
         def model_uri
           raise NotImplementedError
         end
@@ -262,7 +269,7 @@ module DiscourseAi
             function_buffer.at("tool_id").inner_html = tool_name
           end

-          read_parameters =
+          _read_parameters =
             read_function
               .at("parameters")
               .elements

View File

@@ -16,6 +16,11 @@ module DiscourseAi
         @prompt = nil
       end

+      def normalize_model_params(model_params)
+        # max_tokens, temperature, stop_sequences are already supported
+        model_params
+      end
+
      attr_reader :responses, :completions, :prompt

      def perform_completion!(prompt, _user, _model_params)

View File

@@ -9,7 +9,23 @@ module DiscourseAi
         end

         def default_options
-          {}
+          { generationConfig: {} }
+        end
+
+        def normalize_model_params(model_params)
+          model_params = model_params.dup
+
+          if model_params[:stop_sequences]
+            model_params[:stopSequences] = model_params.delete(:stop_sequences)
+          end
+
+          if model_params[:max_tokens]
+            model_params[:maxOutputTokens] = model_params.delete(:max_tokens)
+          end
+
+          # temperature already supported
+          model_params
         end

         def provider_id
@@ -27,9 +43,11 @@ module DiscourseAi
         def prepare_payload(prompt, model_params, dialect)
           default_options
-            .merge(model_params)
             .merge(contents: prompt)
-            .tap { |payload| payload[:tools] = dialect.tools if dialect.tools.present? }
+            .tap do |payload|
+              payload[:tools] = dialect.tools if dialect.tools.present?
+              payload[:generationConfig].merge!(model_params) if model_params.present?
+            end
         end

         def prepare_request(payload)

View File

@@ -19,6 +19,20 @@ module DiscourseAi
           { parameters: { repetition_penalty: 1.1, temperature: 0.7, return_full_text: false } }
         end

+        def normalize_model_params(model_params)
+          model_params = model_params.dup
+
+          if model_params[:stop_sequences]
+            model_params[:stop] = model_params.delete(:stop_sequences)
+          end
+
+          if model_params[:max_tokens]
+            model_params[:max_new_tokens] = model_params.delete(:max_tokens)
+          end
+
+          model_params
+        end
+
         def provider_id
           AiApiAuditLog::Provider::HuggingFaceTextGeneration
         end

View File

@@ -15,6 +15,17 @@ module DiscourseAi
           ].include?(model_name)
         end

+        def normalize_model_params(model_params)
+          model_params = model_params.dup
+
+          # max_tokens, temperature are already supported
+          if model_params[:stop_sequences]
+            model_params[:stop] = model_params.delete(:stop_sequences)
+          end
+
+          model_params
+        end
+
         def default_options
           { model: model == "gpt-4-turbo" ? "gpt-4-1106-preview" : model }
         end

View File

@@ -10,6 +10,17 @@ module DiscourseAi
           )
         end

+        def normalize_model_params(model_params)
+          model_params = model_params.dup
+
+          # max_tokens, temperature are already supported
+          if model_params[:stop_sequences]
+            model_params[:stop] = model_params.delete(:stop_sequences)
+          end
+
+          model_params
+        end
+
         def default_options
           { max_tokens: 2000, model: model }
         end
@@ -39,7 +50,6 @@ module DiscourseAi
         def prepare_request(payload)
           headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }

           Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
-
         end

View File

@@ -98,11 +98,24 @@ module DiscourseAi
       #   </invoke>
       # </function_calls>
      #
-      def completion!(generic_prompt, user, &partial_read_blk)
-        model_params = generic_prompt.dig(:params, model_name) || {}
+      def generate(
+        generic_prompt,
+        temperature: nil,
+        max_tokens: nil,
+        stop_sequences: nil,
+        user:,
+        &partial_read_blk
+      )
+        model_params = {
+          temperature: temperature,
+          max_tokens: max_tokens,
+          stop_sequences: stop_sequences,
+        }
+
+        model_params.merge!(generic_prompt.dig(:params, model_name) || {})
+        model_params.keys.each { |key| model_params.delete(key) if model_params[key].nil? }

         dialect = dialect_klass.new(generic_prompt, model_name, opts: model_params)
         gateway.perform_completion!(dialect, user, model_params, &partial_read_blk)
       end

View File

@@ -112,7 +112,7 @@ module DiscourseAi
         llm_response =
           DiscourseAi::Completions::Llm.proxy(
             SiteSetting.ai_embeddings_semantic_search_hyde_model,
-          ).completion!(prompt, @guardian.user)
+          ).generate(prompt, user: @guardian.user)

         Nokogiri::HTML5.fragment(llm_response).at("ai")&.text&.presence || llm_response
       end

View File

@@ -99,7 +99,7 @@ module DiscourseAi
       def summarize_single(llm, text, user, opts, &on_partial_blk)
         prompt = summarization_prompt(text, opts)

-        llm.completion!(prompt, user, &on_partial_blk)
+        llm.generate(prompt, user: user, &on_partial_blk)
       end

       def summarize_in_chunks(llm, chunks, user, opts)
@@ -107,7 +107,7 @@ module DiscourseAi
           prompt = summarization_prompt(chunk[:summary], opts)
           prompt[:post_insts] = "Don't use more than 400 words for the summary."

-          chunk[:summary] = llm.completion!(prompt, user)
+          chunk[:summary] = llm.generate(prompt, user: user)
           chunk
         end
       end
@@ -131,7 +131,7 @@ module DiscourseAi
         </input>
       TEXT

-      llm.completion!(prompt, user, &on_partial_blk)
+      llm.generate(prompt, user: user, &on_partial_blk)
     end

     def summarization_prompt(input, opts)

View File

@@ -74,7 +74,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Mixtral do
        #{prompt[:post_insts]}
        [/INST] Ok </s>
        [INST] #{prompt[:examples][0][0]} [/INST]
-       #{prompt[:examples][0][1]}
+       #{prompt[:examples][0][1]}</s>
        [INST] #{prompt[:input]} [/INST]
      TEXT

View File

@@ -183,7 +183,7 @@ data: [D|ONE]
       partials = []
       llm = DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo")
-      llm.completion!({ insts: "test" }, Discourse.system_user) { |partial| partials << partial }
+      llm.generate({ insts: "test" }, user: Discourse.system_user) { |partial| partials << partial }

       expect(partials.join).to eq("test,test2,test3,test4")
     end
@@ -212,7 +212,7 @@ data: [D|ONE]
       partials = []
       llm = DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo")
-      llm.completion!({ insts: "test" }, Discourse.system_user) { |partial| partials << partial }
+      llm.generate({ insts: "test" }, user: Discourse.system_user) { |partial| partials << partial }

       expect(partials.join).to eq("test,test1,test2,test3,test4")
     end

View File

@@ -21,7 +21,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
     end
   end

-  describe "#completion!" do
+  describe "#generate" do
     let(:prompt) do
       {
         insts: <<~TEXT,
@@ -52,7 +52,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
     context "when getting the full response" do
       it "processes the prompt and return the response" do
-        llm_response = llm.completion!(prompt, user)
+        llm_response = llm.generate(prompt, user: user)

         expect(llm_response).to eq(canned_response.responses[0])
       end
@@ -62,7 +62,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
       it "processes the prompt and call the given block with the partial response" do
         llm_response = +""
-        llm.completion!(prompt, user) { |partial, cancel_fn| llm_response << partial }
+        llm.generate(prompt, user: user) { |partial, cancel_fn| llm_response << partial }

         expect(llm_response).to eq(canned_response.responses[0])
       end

View File

@@ -59,7 +59,7 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
     it "returns a generic error when the completion call fails" do
       DiscourseAi::Completions::Llm
         .any_instance
-        .expects(:completion!)
+        .expects(:generate)
         .raises(DiscourseAi::Completions::Endpoints::Base::CompletionFailed)

       post "/discourse-ai/ai-helper/suggest", params: { mode: mode, text: text_to_proofread }