FIX: AI helper not working correctly with Mixtral (#399)

* FIX: AI helper not working correctly with Mixtral

This PR introduces a new method on the generic LLM interface called #generate

This will replace the implementation of #completion!

#generate introduces a new way to pass temperature, max_tokens and stop_sequences
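
As a rough sketch of the new call shape (based on the call sites further down in this diff; `generic_prompt`, `user` and `result` are placeholders, and the parameter values are illustrative):

```ruby
llm = DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_model)

# temperature, max_tokens and stop_sequences use generic names here;
# each endpoint maps them to its own names via #normalize_model_params
llm.generate(
  generic_prompt,
  user: user,
  temperature: 0.2,
  max_tokens: 500,
  stop_sequences: ["\n</output>"],
) { |partial| result << partial }
```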

LLM implementers then need to implement #normalize_model_params to map
these generic names onto the parameter names of the LLM-specific endpoint
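
A minimal sketch of such an override, mirroring the Anthropic endpoint in this diff (temperature and stop_sequences already use Anthropic's names, so only max_tokens needs renaming):

```ruby
def normalize_model_params(model_params)
  model_params = model_params.dup

  # temperature and stop_sequences are already supported under these names
  if model_params[:max_tokens]
    model_params[:max_tokens_to_sample] = model_params.delete(:max_tokens)
  end

  model_params
end
```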

This also adds temperature and stop_sequences columns to completion_prompts,
which allows for much more robust completion prompts

* port everything over to #generate

* Fix translation

- On Anthropic this no longer emits a stray "This is your translation:" preamble
- On Mixtral this actually works

* Fix markdown table generation as well
Sam 2024-01-04 23:53:47 +11:00 committed by GitHub
parent 0483e0bb88
commit 03fc94684b
25 changed files with 217 additions and 92 deletions

View File

@ -43,7 +43,7 @@ module DiscourseAi
),
status: 200
end
rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed => e
rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed
render_json_error I18n.t("discourse_ai.ai_helper.errors.completion_request_failed"),
status: 502
end
@ -63,7 +63,7 @@ module DiscourseAi
),
status: 200
end
rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed => e
rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed
render_json_error I18n.t("discourse_ai.ai_helper.errors.completion_request_failed"),
status: 502
end
@ -111,7 +111,7 @@ module DiscourseAi
)
render json: { success: true }, status: 200
rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed => e
rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed
render_json_error I18n.t("discourse_ai.ai_helper.errors.completion_request_failed"),
status: 502
end

View File

@ -67,6 +67,8 @@ end
# created_at :datetime not null
# updated_at :datetime not null
# messages :jsonb
# temperature :integer
# stop_sequences :string is an Array
#
# Indexes
#

View File

@ -5,7 +5,10 @@ CompletionPrompt.seed do |cp|
cp.id = -301
cp.name = "translate"
cp.prompt_type = CompletionPrompt.prompt_types[:text]
cp.messages = { insts: <<~TEXT }
cp.stop_sequences = ["\n</output>", "</output>"]
cp.temperature = 0.2
cp.messages = {
insts: <<~TEXT,
I want you to act as an English translator, spelling corrector and improver. I will write to you
in any language and you will detect the language, translate it and answer in the corrected and
improved version of my text, in English. I want you to replace my simplified A0-level words and
@ -13,39 +16,44 @@ CompletionPrompt.seed do |cp|
Keep the meaning same, but make them more literary. I want you to only reply the correction,
the improvements and nothing else, do not write explanations.
You will find the text between <input></input> XML tags.
Include your translation between <output></output> XML tags.
TEXT
examples: [
["<input>Hello world</input>", "<output>Hello world</output>"],
["<input>Bonjour le monde</input>", "<output>Hello world</output>"],
],
}
end
CompletionPrompt.seed do |cp|
cp.id = -303
cp.name = "proofread"
cp.prompt_type = CompletionPrompt.prompt_types[:diff]
cp.temperature = 0
cp.stop_sequences = ["\n</output>"]
cp.messages = {
insts: <<~TEXT,
You are a markdown proofreader. You correct egregious typos and phrasing issues but keep the user's original voice.
You do not touch code blocks. I will provide you with text to proofread. If nothing needs fixing, then you will echo the text back.
Optionally, a user can specify intensity. Intensity 10 is a pedantic English teacher correcting the text.
Intensity 1 is a minimal proofreader. By default, you operate at intensity 1.
You will find the text between <input></input> XML tags.
You will ALWAYS return the corrected text between <output></output> XML tags.
TEXT
examples: [
[
"<input>![amazing car|100x100, 22%](upload://hapy.png)</input>",
"![Amazing car|100x100, 22%](upload://hapy.png)",
"<output>![Amazing car|100x100, 22%](upload://hapy.png)</output>",
],
[<<~TEXT, "The rain in Spain, stays mainly in the Plane."],
<input>
Intensity 1:
The rain in spain stays mainly in the plane.
</input>
TEXT
[
"The rain in Spain, stays mainly in the Plane.",
"The rain in Spain, stays mainly in the Plane.",
"<input>The rain in Spain, stays mainly in the Plane.</input>",
"<output>The rain in Spain, stays mainly in the Plane.</output>",
],
[<<~TEXT, <<~TEXT],
<input>
Intensity 1:
Hello,
Sometimes the logo isn't changing automatically when color scheme changes.
@ -53,13 +61,14 @@ CompletionPrompt.seed do |cp|
![Screen Recording 2023-03-17 at 18.04.22|video](upload://2rcVL0ZMxHPNtPWQbZjwufKpWVU.mov)
</input>
TEXT
<output>
Hello,
Sometimes the logo does not change automatically when the color scheme changes.
![Screen Recording 2023-03-17 at 18.04.22|video](upload://2rcVL0ZMxHPNtPWQbZjwufKpWVU.mov)
</output>
TEXT
[<<~TEXT, <<~TEXT],
<input>
Intensity 1:
Any ideas what is wrong with this peace of cod?
> This quot contains a typo
```ruby
@ -69,6 +78,7 @@ CompletionPrompt.seed do |cp|
```
</input>
TEXT
<output>
Any ideas what is wrong with this piece of code?
> This quot contains a typo
```ruby
@ -76,6 +86,7 @@ CompletionPrompt.seed do |cp|
testing.a_typo = 11
bad = "bad"
```
</output>
TEXT
],
}
@ -85,15 +96,19 @@ CompletionPrompt.seed do |cp|
cp.id = -304
cp.name = "markdown_table"
cp.prompt_type = CompletionPrompt.prompt_types[:diff]
cp.temperature = 0.5
cp.stop_sequences = ["\n</output>"]
cp.messages = {
insts: <<~TEXT,
You are a markdown table formatter, I will provide you text inside <input></input> XML tags and you will format it into a markdown table
TEXT
examples: [
["<input>sam,joe,jane\nage: 22| 10|11</input>", <<~TEXT],
<output>
| | sam | joe | jane |
|---|---|---|---|
| age | 22 | 10 | 11 |
</output>
TEXT
[<<~TEXT, <<~TEXT],
<input>
@ -102,11 +117,13 @@ CompletionPrompt.seed do |cp|
fred: height 22
</input>
TEXT
<output>
| | speed | age | height |
|---|---|---|---|
| sam | 100 | 22 | - |
| jane | - | 10 | - |
| fred | - | - | 22 |
</output>
TEXT
[<<~TEXT, <<~TEXT],
<input>
@ -114,10 +131,12 @@ CompletionPrompt.seed do |cp|
firefox 10ms (first load: 9ms)
</input>
TEXT
<output>
| Browser | Load Time (ms) | First Load Time (ms) |
|---|---|---|
| Chrome | 22 | 10 |
| Firefox | 10 | 9 |
</output>
TEXT
],
}

View File

@ -0,0 +1,8 @@
# frozen_string_literal: true
class AddParamsToCompletionPrompt < ActiveRecord::Migration[7.0]
def change
add_column :completion_prompts, :temperature, :integer
add_column :completion_prompts, :stop_sequences, :string, array: true
end
end

View File

@ -36,20 +36,26 @@ module DiscourseAi
llm = DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_model)
generic_prompt = completion_prompt.messages_with_input(input)
llm.completion!(generic_prompt, user, &block)
llm.generate(
generic_prompt,
user: user,
temperature: completion_prompt.temperature,
stop_sequences: completion_prompt.stop_sequences,
&block
)
end
def generate_and_send_prompt(completion_prompt, input, user)
completion_result = generate_prompt(completion_prompt, input, user)
result = { type: completion_prompt.prompt_type }
result[:diff] = parse_diff(input, completion_result) if completion_prompt.diff?
result[:suggestions] = (
if completion_prompt.list?
parse_list(completion_result).map { |suggestion| sanitize_result(suggestion) }
else
[sanitize_result(completion_result)]
sanitized = sanitize_result(completion_result)
result[:diff] = parse_diff(input, sanitized) if completion_prompt.diff?
[sanitized]
end
)
@ -79,25 +85,15 @@ module DiscourseAi
private
def sanitize_result(result)
tags_to_remove = %w[
<term>
</term>
<context>
</context>
<topic>
</topic>
<replyTo>
</replyTo>
<input>
</input>
<output>
</output>
<result>
</result>
]
SANITIZE_REGEX_STR =
%w[term context topic replyTo input output result]
.map { |tag| "<#{tag}>\\n?|\\n?</#{tag}>" }
.join("|")
result.dup.tap { |dup_result| tags_to_remove.each { |tag| dup_result.gsub!(tag, "") } }
SANITIZE_REGEX = Regexp.new(SANITIZE_REGEX_STR, Regexp::IGNORECASE | Regexp::MULTILINE)
def sanitize_result(result)
result.gsub(SANITIZE_REGEX, "")
end
def publish_update(channel, payload, user)

View File

@ -38,7 +38,10 @@ module DiscourseAi
You'll find the post between <input></input> XML tags.
TEXT
DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_model).completion!(prompt, user)
DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_model).generate(
prompt,
user: user,
)
end
end
end

View File

@ -31,24 +31,16 @@ module DiscourseAi
result = nil
llm = DiscourseAi::Completions::Llm.proxy(model)
key =
if model.include?("claude")
:max_tokens_to_sample
else
:max_tokens
end
prompt = {
insts: filled_system_prompt,
params: {
model => {
key => (llm.tokenizer.tokenize(search_for_text).length * 2 + 10),
:temperature => 0,
},
},
}
prompt = { insts: filled_system_prompt }
result = llm.completion!(prompt, Discourse.system_user)
result =
llm.generate(
prompt,
temperature: 0,
max_tokens: llm.tokenizer.tokenize(search_for_text).length * 2 + 10,
user: Discourse.system_user,
)
if result.strip == search_for_text.strip
user = User.find_by_username(canned_reply_user) if canned_reply_user.present?

View File

@ -115,18 +115,13 @@ module DiscourseAi
insts: "You are a helpful bot specializing in summarizing activity on Discourse sites",
input: input,
final_insts: "Here is the report I generated for you",
params: {
@model => {
temperature: 0,
},
},
}
result = +""
puts if Rails.env.development? && @debug_mode
@llm.completion!(prompt, Discourse.system_user) do |response|
@llm.generate(prompt, temperature: 0, user: Discourse.system_user) do |response|
print response if Rails.env.development? && @debug_mode
result << response
end

View File

@ -95,7 +95,7 @@ module DiscourseAi
def max_prompt_tokens
# provide a buffer of 120 tokens - our function counting is not
# 100% accurate and getting numbers to align exactly is very hard
buffer = (opts[:max_tokens_to_sample] || 2500) + 50
buffer = (opts[:max_tokens] || 2500) + 50
if tools.present?
# note this is about 100 tokens over, OpenAI have a more optimal representation

View File

@ -27,7 +27,7 @@ module DiscourseAi
if prompt[:examples]
prompt[:examples].each do |example_pair|
mixtral_prompt << "[INST] #{example_pair.first} [/INST]\n"
mixtral_prompt << "#{example_pair.second}\n"
mixtral_prompt << "#{example_pair.second}</s>\n"
end
end

View File

@ -8,8 +8,24 @@ module DiscourseAi
%w[claude-instant-1 claude-2].include?(model_name)
end
def normalize_model_params(model_params)
model_params = model_params.dup
# temperature, stop_sequences are already supported
#
if model_params[:max_tokens]
model_params[:max_tokens_to_sample] = model_params.delete(:max_tokens)
end
model_params
end
def default_options
{ max_tokens_to_sample: 2000, model: model }
{
model: model,
max_tokens_to_sample: 2_000,
stop_sequences: ["\n\nHuman:", "</function_calls>"],
}
end
def provider_id

View File

@ -13,8 +13,24 @@ module DiscourseAi
SiteSetting.ai_bedrock_region.present?
end
def normalize_model_params(model_params)
model_params = model_params.dup
# temperature, stop_sequences are already supported
#
if model_params[:max_tokens]
model_params[:max_tokens_to_sample] = model_params.delete(:max_tokens)
end
model_params
end
def default_options
{ max_tokens_to_sample: 2_000, stop_sequences: ["\n\nHuman:", "</function_calls>"] }
{
model: model,
max_tokens_to_sample: 2_000,
stop_sequences: ["\n\nHuman:", "</function_calls>"],
}
end
def provider_id

View File

@ -32,6 +32,8 @@ module DiscourseAi
end
def perform_completion!(dialect, user, model_params = {})
model_params = normalize_model_params(model_params)
@streaming_mode = block_given?
prompt = dialect.translate
@ -199,6 +201,11 @@ module DiscourseAi
attr_reader :model
# should normalize temperature, max_tokens, stop_sequences to endpoint specific values
def normalize_model_params(model_params)
raise NotImplementedError
end
def model_uri
raise NotImplementedError
end
@ -262,7 +269,7 @@ module DiscourseAi
function_buffer.at("tool_id").inner_html = tool_name
end
read_parameters =
_read_parameters =
read_function
.at("parameters")
.elements

View File

@ -16,6 +16,11 @@ module DiscourseAi
@prompt = nil
end
def normalize_model_params(model_params)
# max_tokens, temperature, stop_sequences are already supported
model_params
end
attr_reader :responses, :completions, :prompt
def perform_completion!(prompt, _user, _model_params)

View File

@ -9,7 +9,23 @@ module DiscourseAi
end
def default_options
{}
{ generationConfig: {} }
end
def normalize_model_params(model_params)
model_params = model_params.dup
if model_params[:stop_sequences]
model_params[:stopSequences] = model_params.delete(:stop_sequences)
end
if model_params[:max_tokens]
model_params[:maxOutputTokens] = model_params.delete(:max_tokens)
end
# temperature already supported
model_params
end
def provider_id
@ -27,9 +43,11 @@ module DiscourseAi
def prepare_payload(prompt, model_params, dialect)
default_options
.merge(model_params)
.merge(contents: prompt)
.tap { |payload| payload[:tools] = dialect.tools if dialect.tools.present? }
.tap do |payload|
payload[:tools] = dialect.tools if dialect.tools.present?
payload[:generationConfig].merge!(model_params) if model_params.present?
end
end
def prepare_request(payload)

View File

@ -19,6 +19,20 @@ module DiscourseAi
{ parameters: { repetition_penalty: 1.1, temperature: 0.7, return_full_text: false } }
end
def normalize_model_params(model_params)
model_params = model_params.dup
if model_params[:stop_sequences]
model_params[:stop] = model_params.delete(:stop_sequences)
end
if model_params[:max_tokens]
model_params[:max_new_tokens] = model_params.delete(:max_tokens)
end
model_params
end
def provider_id
AiApiAuditLog::Provider::HuggingFaceTextGeneration
end

View File

@ -15,6 +15,17 @@ module DiscourseAi
].include?(model_name)
end
def normalize_model_params(model_params)
model_params = model_params.dup
# max_tokens, temperature are already supported
if model_params[:stop_sequences]
model_params[:stop] = model_params.delete(:stop_sequences)
end
model_params
end
def default_options
{ model: model == "gpt-4-turbo" ? "gpt-4-1106-preview" : model }
end

View File

@ -10,6 +10,17 @@ module DiscourseAi
)
end
def normalize_model_params(model_params)
model_params = model_params.dup
# max_tokens, temperature are already supported
if model_params[:stop_sequences]
model_params[:stop] = model_params.delete(:stop_sequences)
end
model_params
end
def default_options
{ max_tokens: 2000, model: model }
end
@ -39,7 +50,6 @@ module DiscourseAi
def prepare_request(payload)
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end

View File

@ -98,11 +98,24 @@ module DiscourseAi
# </invoke>
# </function_calls>
#
def completion!(generic_prompt, user, &partial_read_blk)
model_params = generic_prompt.dig(:params, model_name) || {}
def generate(
generic_prompt,
temperature: nil,
max_tokens: nil,
stop_sequences: nil,
user:,
&partial_read_blk
)
model_params = {
temperature: temperature,
max_tokens: max_tokens,
stop_sequences: stop_sequences,
}
model_params.merge!(generic_prompt.dig(:params, model_name) || {})
model_params.keys.each { |key| model_params.delete(key) if model_params[key].nil? }
dialect = dialect_klass.new(generic_prompt, model_name, opts: model_params)
gateway.perform_completion!(dialect, user, model_params, &partial_read_blk)
end

View File

@ -112,7 +112,7 @@ module DiscourseAi
llm_response =
DiscourseAi::Completions::Llm.proxy(
SiteSetting.ai_embeddings_semantic_search_hyde_model,
).completion!(prompt, @guardian.user)
).generate(prompt, user: @guardian.user)
Nokogiri::HTML5.fragment(llm_response).at("ai")&.text&.presence || llm_response
end

View File

@ -99,7 +99,7 @@ module DiscourseAi
def summarize_single(llm, text, user, opts, &on_partial_blk)
prompt = summarization_prompt(text, opts)
llm.completion!(prompt, user, &on_partial_blk)
llm.generate(prompt, user: user, &on_partial_blk)
end
def summarize_in_chunks(llm, chunks, user, opts)
@ -107,7 +107,7 @@ module DiscourseAi
prompt = summarization_prompt(chunk[:summary], opts)
prompt[:post_insts] = "Don't use more than 400 words for the summary."
chunk[:summary] = llm.completion!(prompt, user)
chunk[:summary] = llm.generate(prompt, user: user)
chunk
end
end
@ -131,7 +131,7 @@ module DiscourseAi
</input>
TEXT
llm.completion!(prompt, user, &on_partial_blk)
llm.generate(prompt, user: user, &on_partial_blk)
end
def summarization_prompt(input, opts)

View File

@ -74,7 +74,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Mixtral do
#{prompt[:post_insts]}
[/INST] Ok </s>
[INST] #{prompt[:examples][0][0]} [/INST]
#{prompt[:examples][0][1]}
#{prompt[:examples][0][1]}</s>
[INST] #{prompt[:input]} [/INST]
TEXT

View File

@ -183,7 +183,7 @@ data: [D|ONE]
partials = []
llm = DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo")
llm.completion!({ insts: "test" }, Discourse.system_user) { |partial| partials << partial }
llm.generate({ insts: "test" }, user: Discourse.system_user) { |partial| partials << partial }
expect(partials.join).to eq("test,test2,test3,test4")
end
@ -212,7 +212,7 @@ data: [D|ONE]
partials = []
llm = DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo")
llm.completion!({ insts: "test" }, Discourse.system_user) { |partial| partials << partial }
llm.generate({ insts: "test" }, user: Discourse.system_user) { |partial| partials << partial }
expect(partials.join).to eq("test,test1,test2,test3,test4")
end

View File

@ -21,7 +21,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
end
end
describe "#completion!" do
describe "#generate" do
let(:prompt) do
{
insts: <<~TEXT,
@ -52,7 +52,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
context "when getting the full response" do
it "processes the prompt and return the response" do
llm_response = llm.completion!(prompt, user)
llm_response = llm.generate(prompt, user: user)
expect(llm_response).to eq(canned_response.responses[0])
end
@ -62,7 +62,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
it "processes the prompt and call the given block with the partial response" do
llm_response = +""
llm.completion!(prompt, user) { |partial, cancel_fn| llm_response << partial }
llm.generate(prompt, user: user) { |partial, cancel_fn| llm_response << partial }
expect(llm_response).to eq(canned_response.responses[0])
end

View File

@ -59,7 +59,7 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
it "returns a generic error when the completion call fails" do
DiscourseAi::Completions::Llm
.any_instance
.expects(:completion!)
.expects(:generate)
.raises(DiscourseAi::Completions::Endpoints::Base::CompletionFailed)
post "/discourse-ai/ai-helper/suggest", params: { mode: mode, text: text_to_proofread }