FIX: AI helper not working correctly with mixtral (#399)

* FIX: AI helper not working correctly with mixtral

This PR introduces a new function on the generic llm called #generate

This will replace the implementation of completion!

#generate introduces a new way to pass temperature, max_tokens and stop_sequences

Then LLM implementers need to implement #normalize_model_params to
ensure the generic names match the LLM specific endpoint

This also adds temperature and stop_sequences to completion_prompts,
which allows for much more robust completion prompts

* port everything over to #generate

* Fix translation

- On anthropic this no longer throws random "This is your translation:"
- On mixtral this actually works

* fix markdown table generation as well
This commit is contained in:
Sam 2024-01-04 23:53:47 +11:00 committed by GitHub
parent 0483e0bb88
commit 03fc94684b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 217 additions and 92 deletions

View File

@ -43,7 +43,7 @@ module DiscourseAi
),
status: 200
end
rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed => e
rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed
render_json_error I18n.t("discourse_ai.ai_helper.errors.completion_request_failed"),
status: 502
end
@ -63,7 +63,7 @@ module DiscourseAi
),
status: 200
end
rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed => e
rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed
render_json_error I18n.t("discourse_ai.ai_helper.errors.completion_request_failed"),
status: 502
end
@ -111,7 +111,7 @@ module DiscourseAi
)
render json: { success: true }, status: 200
rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed => e
rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed
render_json_error I18n.t("discourse_ai.ai_helper.errors.completion_request_failed"),
status: 502
end

View File

@ -67,6 +67,8 @@ end
# created_at :datetime not null
# updated_at :datetime not null
# messages :jsonb
# temperature :integer
# stop_sequences :string is an Array
#
# Indexes
#

View File

@ -5,47 +5,55 @@ CompletionPrompt.seed do |cp|
cp.id = -301
cp.name = "translate"
cp.prompt_type = CompletionPrompt.prompt_types[:text]
cp.messages = { insts: <<~TEXT }
I want you to act as an English translator, spelling corrector and improver. I will write to you
in any language and you will detect the language, translate it and answer in the corrected and
improved version of my text, in English. I want you to replace my simplified A0-level words and
sentences with more beautiful and elegant, upper level English words and sentences.
Keep the meaning same, but make them more literary. I want you to only reply the correction,
the improvements and nothing else, do not write explanations.
You will find the text between <input></input> XML tags.
TEXT
cp.stop_sequences = ["\n</output>", "</output>"]
cp.temperature = 0.2
cp.messages = {
insts: <<~TEXT,
I want you to act as an English translator, spelling corrector and improver. I will write to you
in any language and you will detect the language, translate it and answer in the corrected and
improved version of my text, in English. I want you to replace my simplified A0-level words and
sentences with more beautiful and elegant, upper level English words and sentences.
Keep the meaning same, but make them more literary. I want you to only reply the correction,
the improvements and nothing else, do not write explanations.
You will find the text between <input></input> XML tags.
Include your translation between <output></output> XML tags.
TEXT
examples: [
["<input>Hello world</input>", "<output>Hello world</output>"],
["<input>Bonjour le monde</input>", "<output>Hello world</output>"],
],
}
end
CompletionPrompt.seed do |cp|
cp.id = -303
cp.name = "proofread"
cp.prompt_type = CompletionPrompt.prompt_types[:diff]
cp.temperature = 0
cp.stop_sequences = ["\n</output>"]
cp.messages = {
insts: <<~TEXT,
You are a markdown proofreader. You correct egregious typos and phrasing issues but keep the user's original voice.
You do not touch code blocks. I will provide you with text to proofread. If nothing needs fixing, then you will echo the text back.
Optionally, a user can specify intensity. Intensity 10 is a pedantic English teacher correcting the text.
Intensity 1 is a minimal proofreader. By default, you operate at intensity 1.
You will find the text between <input></input> XML tags.
You will ALWAYS return the corrected text between <output></output> XML tags.
TEXT
examples: [
[
"<input>![amazing car|100x100, 22%](upload://hapy.png)</input>",
"![Amazing car|100x100, 22%](upload://hapy.png)",
"<output>![Amazing car|100x100, 22%](upload://hapy.png)</output>",
],
[<<~TEXT, "The rain in Spain, stays mainly in the Plane."],
<input>
Intensity 1:
The rain in spain stays mainly in the plane.
</input>
TEXT
[
"The rain in Spain, stays mainly in the Plane.",
"The rain in Spain, stays mainly in the Plane.",
"<input>The rain in Spain, stays mainly in the Plane.</input>",
"<output>The rain in Spain, stays mainly in the Plane.</output>",
],
[<<~TEXT, <<~TEXT],
<input>
Intensity 1:
Hello,
Sometimes the logo isn't changing automatically when color scheme changes.
@ -53,13 +61,14 @@ CompletionPrompt.seed do |cp|
![Screen Recording 2023-03-17 at 18.04.22|video](upload://2rcVL0ZMxHPNtPWQbZjwufKpWVU.mov)
</input>
TEXT
<output>
Hello,
Sometimes the logo does not change automatically when the color scheme changes.
![Screen Recording 2023-03-17 at 18.04.22|video](upload://2rcVL0ZMxHPNtPWQbZjwufKpWVU.mov)
</output>
TEXT
[<<~TEXT, <<~TEXT],
<input>
Intensity 1:
Any ideas what is wrong with this peace of cod?
> This quot contains a typo
```ruby
@ -69,6 +78,7 @@ CompletionPrompt.seed do |cp|
```
</input>
TEXT
<output>
Any ideas what is wrong with this piece of code?
> This quot contains a typo
```ruby
@ -76,6 +86,7 @@ CompletionPrompt.seed do |cp|
testing.a_typo = 11
bad = "bad"
```
</output>
TEXT
],
}
@ -85,15 +96,19 @@ CompletionPrompt.seed do |cp|
cp.id = -304
cp.name = "markdown_table"
cp.prompt_type = CompletionPrompt.prompt_types[:diff]
cp.temperature = 0.5
cp.stop_sequences = ["\n</output>"]
cp.messages = {
insts: <<~TEXT,
You are a markdown table formatter, I will provide you text inside <input></input> XML tags and you will format it into a markdown table
TEXT
examples: [
["<input>sam,joe,jane\nage: 22| 10|11</input>", <<~TEXT],
<output>
| | sam | joe | jane |
|---|---|---|---|
| age | 22 | 10 | 11 |
</output>
TEXT
[<<~TEXT, <<~TEXT],
<input>
@ -102,22 +117,26 @@ CompletionPrompt.seed do |cp|
fred: height 22
</input>
TEXT
<output>
| | speed | age | height |
|---|---|---|---|
| sam | 100 | 22 | - |
| jane | - | 10 | - |
| fred | - | - | 22 |
</output>
TEXT
[<<~TEXT, <<~TEXT],
<input>
chrome 22ms (first load 10ms)
firefox 10ms (first load: 9ms)
chrome 22ms (first load 10ms)
firefox 10ms (first load: 9ms)
</input>
TEXT
<output>
| Browser | Load Time (ms) | First Load Time (ms) |
|---|---|---|
| Chrome | 22 | 10 |
| Firefox | 10 | 9 |
</output>
TEXT
],
}

View File

@ -0,0 +1,8 @@
# frozen_string_literal: true
# Adds per-prompt generation parameters to completion_prompts.
class AddParamsToCompletionPrompt < ActiveRecord::Migration[7.0]
  def change
    # Temperature is fractional (the seeded prompts assign 0.2 and 0.5), so it
    # needs a float column; an integer column cannot hold those values.
    add_column :completion_prompts, :temperature, :float
    # Array of strings at which generation should stop.
    add_column :completion_prompts, :stop_sequences, :string, array: true
  end
end

View File

@ -36,20 +36,26 @@ module DiscourseAi
llm = DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_model)
generic_prompt = completion_prompt.messages_with_input(input)
llm.completion!(generic_prompt, user, &block)
llm.generate(
generic_prompt,
user: user,
temperature: completion_prompt.temperature,
stop_sequences: completion_prompt.stop_sequences,
&block
)
end
def generate_and_send_prompt(completion_prompt, input, user)
completion_result = generate_prompt(completion_prompt, input, user)
result = { type: completion_prompt.prompt_type }
result[:diff] = parse_diff(input, completion_result) if completion_prompt.diff?
result[:suggestions] = (
if completion_prompt.list?
parse_list(completion_result).map { |suggestion| sanitize_result(suggestion) }
else
[sanitize_result(completion_result)]
sanitized = sanitize_result(completion_result)
result[:diff] = parse_diff(input, sanitized) if completion_prompt.diff?
[sanitized]
end
)
@ -79,25 +85,15 @@ module DiscourseAi
private
def sanitize_result(result)
tags_to_remove = %w[
<term>
</term>
<context>
</context>
<topic>
</topic>
<replyTo>
</replyTo>
<input>
</input>
<output>
</output>
<result>
</result>
]
SANITIZE_REGEX_STR =
%w[term context topic replyTo input output result]
.map { |tag| "<#{tag}>\\n?|\\n?</#{tag}>" }
.join("|")
result.dup.tap { |dup_result| tags_to_remove.each { |tag| dup_result.gsub!(tag, "") } }
SANITIZE_REGEX = Regexp.new(SANITIZE_REGEX_STR, Regexp::IGNORECASE | Regexp::MULTILINE)
# Returns a copy of +result+ with every match of SANITIZE_REGEX removed.
# SANITIZE_REGEX matches the opening and closing forms of the term, context,
# topic, replyTo, input, output and result tags (case-insensitively), together
# with a newline directly after an opening tag or directly before a closing tag.
def sanitize_result(result)
result.gsub(SANITIZE_REGEX, "")
end
def publish_update(channel, payload, user)

View File

@ -38,7 +38,10 @@ module DiscourseAi
You'll find the post between <input></input> XML tags.
TEXT
DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_model).completion!(prompt, user)
DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_model).generate(
prompt,
user: user,
)
end
end
end

View File

@ -31,24 +31,16 @@ module DiscourseAi
result = nil
llm = DiscourseAi::Completions::Llm.proxy(model)
key =
if model.include?("claude")
:max_tokens_to_sample
else
:max_tokens
end
prompt = {
insts: filled_system_prompt,
params: {
model => {
key => (llm.tokenizer.tokenize(search_for_text).length * 2 + 10),
:temperature => 0,
},
},
}
prompt = { insts: filled_system_prompt }
result = llm.completion!(prompt, Discourse.system_user)
result =
llm.generate(
prompt,
temperature: 0,
max_tokens: llm.tokenizer.tokenize(search_for_text).length * 2 + 10,
user: Discourse.system_user,
)
if result.strip == search_for_text.strip
user = User.find_by_username(canned_reply_user) if canned_reply_user.present?

View File

@ -115,18 +115,13 @@ module DiscourseAi
insts: "You are a helpful bot specializing in summarizing activity on Discourse sites",
input: input,
final_insts: "Here is the report I generated for you",
params: {
@model => {
temperature: 0,
},
},
}
result = +""
puts if Rails.env.development? && @debug_mode
@llm.completion!(prompt, Discourse.system_user) do |response|
@llm.generate(prompt, temperature: 0, user: Discourse.system_user) do |response|
print response if Rails.env.development? && @debug_mode
result << response
end

View File

@ -95,7 +95,7 @@ module DiscourseAi
def max_prompt_tokens
# provide a buffer of 120 tokens - our function counting is not
# 100% accurate and getting numbers to align exactly is very hard
buffer = (opts[:max_tokens_to_sample] || 2500) + 50
buffer = (opts[:max_tokens] || 2500) + 50
if tools.present?
# note this is about 100 tokens over, OpenAI have a more optimal representation

View File

@ -27,7 +27,7 @@ module DiscourseAi
if prompt[:examples]
prompt[:examples].each do |example_pair|
mixtral_prompt << "[INST] #{example_pair.first} [/INST]\n"
mixtral_prompt << "#{example_pair.second}\n"
mixtral_prompt << "#{example_pair.second}</s>\n"
end
end

View File

@ -8,8 +8,24 @@ module DiscourseAi
%w[claude-instant-1 claude-2].include?(model_name)
end
# Maps the generic model params onto the names this endpoint expects.
# Only :max_tokens needs renaming (to :max_tokens_to_sample); temperature and
# stop_sequences are passed through as-is. Works on a copy, so the caller's
# hash is never modified.
def normalize_model_params(model_params)
  normalized = model_params.dup
  token_limit = normalized[:max_tokens]
  return normalized unless token_limit

  normalized.delete(:max_tokens)
  normalized[:max_tokens_to_sample] = token_limit
  normalized
end
def default_options
{ max_tokens_to_sample: 2000, model: model }
{
model: model,
max_tokens_to_sample: 2_000,
stop_sequences: ["\n\nHuman:", "</function_calls>"],
}
end
def provider_id

View File

@ -13,8 +13,24 @@ module DiscourseAi
SiteSetting.ai_bedrock_region.present?
end
# Converts generic model params into the names this endpoint expects:
# :max_tokens becomes :max_tokens_to_sample, while temperature and
# stop_sequences need no translation. The incoming hash is duplicated first,
# so it is left untouched.
def normalize_model_params(model_params)
  model_params.dup.tap do |normalized|
    if normalized[:max_tokens]
      normalized[:max_tokens_to_sample] = normalized.delete(:max_tokens)
    end
  end
end
def default_options
{ max_tokens_to_sample: 2_000, stop_sequences: ["\n\nHuman:", "</function_calls>"] }
{
model: model,
max_tokens_to_sample: 2_000,
stop_sequences: ["\n\nHuman:", "</function_calls>"],
}
end
def provider_id

View File

@ -32,6 +32,8 @@ module DiscourseAi
end
def perform_completion!(dialect, user, model_params = {})
model_params = normalize_model_params(model_params)
@streaming_mode = block_given?
prompt = dialect.translate
@ -199,6 +201,11 @@ module DiscourseAi
attr_reader :model
# Hook for translating the generic model params (temperature, max_tokens,
# stop_sequences) into the names a specific endpoint expects.
# perform_completion! calls this on the incoming model_params before doing
# anything else. This base version only raises NotImplementedError, so each
# concrete endpoint is expected to supply its own implementation.
def normalize_model_params(model_params)
raise NotImplementedError
end
def model_uri
raise NotImplementedError
end
@ -262,7 +269,7 @@ module DiscourseAi
function_buffer.at("tool_id").inner_html = tool_name
end
read_parameters =
_read_parameters =
read_function
.at("parameters")
.elements

View File

@ -16,6 +16,11 @@ module DiscourseAi
@prompt = nil
end
# Returns +model_params+ unchanged (the very same object, no copy is made):
# this endpoint needs no renaming of the generic parameter names.
def normalize_model_params(model_params)
# max_tokens, temperature, stop_sequences are already supported
model_params
end
attr_reader :responses, :completions, :prompt
def perform_completion!(prompt, _user, _model_params)

View File

@ -9,7 +9,23 @@ module DiscourseAi
end
def default_options
{}
{ generationConfig: {} }
end
# Translates the generic model params into the camelCase names used by this
# endpoint:
#   :stop_sequences -> :stopSequences
#   :max_tokens     -> :maxOutputTokens
# :temperature already has the right name and is passed through.
# Operates on a copy, so the caller's hash is not modified.
def normalize_model_params(model_params)
  model_params = model_params.dup

  if model_params[:stop_sequences]
    model_params[:stopSequences] = model_params.delete(:stop_sequences)
  end

  # Guard on :max_tokens itself. Guarding on :temperature (as before) skipped
  # the rename whenever no temperature was given, and injected a nil
  # :maxOutputTokens whenever only a temperature was given.
  if model_params[:max_tokens]
    model_params[:maxOutputTokens] = model_params.delete(:max_tokens)
  end

  model_params
end
def provider_id
@ -27,9 +43,11 @@ module DiscourseAi
def prepare_payload(prompt, model_params, dialect)
default_options
.merge(model_params)
.merge(contents: prompt)
.tap { |payload| payload[:tools] = dialect.tools if dialect.tools.present? }
.tap do |payload|
payload[:tools] = dialect.tools if dialect.tools.present?
payload[:generationConfig].merge!(model_params) if model_params.present?
end
end
def prepare_request(payload)

View File

@ -19,6 +19,20 @@ module DiscourseAi
{ parameters: { repetition_penalty: 1.1, temperature: 0.7, return_full_text: false } }
end
# Renames the generic model params to the names this endpoint expects:
#   :stop_sequences -> :stop
#   :max_tokens     -> :max_new_tokens
# A key is only renamed when its value is truthy; everything else passes
# through. The input hash is duplicated, never mutated.
def normalize_model_params(model_params)
  renames = { stop_sequences: :stop, max_tokens: :max_new_tokens }

  model_params.dup.tap do |normalized|
    renames.each do |generic_key, endpoint_key|
      next unless normalized[generic_key]

      normalized[endpoint_key] = normalized.delete(generic_key)
    end
  end
end
def provider_id
AiApiAuditLog::Provider::HuggingFaceTextGeneration
end

View File

@ -15,6 +15,17 @@ module DiscourseAi
].include?(model_name)
end
# Adapts the generic model params for this endpoint. max_tokens and
# temperature already carry the right names; only :stop_sequences is renamed
# (to :stop). Returns a modified copy and leaves the argument untouched.
def normalize_model_params(model_params)
  params = model_params.dup
  stops = params[:stop_sequences]

  if stops
    params.delete(:stop_sequences)
    params[:stop] = stops
  end

  params
end
def default_options
{ model: model == "gpt-4-turbo" ? "gpt-4-1106-preview" : model }
end

View File

@ -10,6 +10,17 @@ module DiscourseAi
)
end
# Produces the endpoint-specific version of the generic model params.
# :stop_sequences is renamed to :stop; max_tokens and temperature are
# already in the expected form. The caller's hash is copied, not modified.
def normalize_model_params(model_params)
  model_params.dup.tap do |params|
    params[:stop] = params.delete(:stop_sequences) if params[:stop_sequences]
  end
end
def default_options
{ max_tokens: 2000, model: model }
end
@ -39,7 +50,6 @@ module DiscourseAi
def prepare_request(payload)
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end

View File

@ -98,11 +98,24 @@ module DiscourseAi
# </invoke>
# </function_calls>
#
def completion!(generic_prompt, user, &partial_read_blk)
model_params = generic_prompt.dig(:params, model_name) || {}
# Runs a completion for +generic_prompt+ through the configured gateway.
#
# temperature, max_tokens and stop_sequences are the generic tuning knobs.
# Any per-model entries found under generic_prompt[:params][model_name] are
# merged on top of them, so they win on conflict. Parameters whose final
# value is nil are dropped before anything is sent on.
#
# The resulting params are handed both to the dialect (as opts:) and to the
# gateway's perform_completion!, along with +user+ and the optional block.
# Returns whatever perform_completion! returns.
def generate(
  generic_prompt,
  temperature: nil,
  max_tokens: nil,
  stop_sequences: nil,
  user:,
  &partial_read_blk
)
  prompt_overrides = generic_prompt.dig(:params, model_name) || {}
  requested = { temperature: temperature, max_tokens: max_tokens, stop_sequences: stop_sequences }
  model_params = requested.merge(prompt_overrides).compact

  dialect = dialect_klass.new(generic_prompt, model_name, opts: model_params)
  gateway.perform_completion!(dialect, user, model_params, &partial_read_blk)
end

View File

@ -112,7 +112,7 @@ module DiscourseAi
llm_response =
DiscourseAi::Completions::Llm.proxy(
SiteSetting.ai_embeddings_semantic_search_hyde_model,
).completion!(prompt, @guardian.user)
).generate(prompt, user: @guardian.user)
Nokogiri::HTML5.fragment(llm_response).at("ai")&.text&.presence || llm_response
end

View File

@ -99,7 +99,7 @@ module DiscourseAi
def summarize_single(llm, text, user, opts, &on_partial_blk)
prompt = summarization_prompt(text, opts)
llm.completion!(prompt, user, &on_partial_blk)
llm.generate(prompt, user: user, &on_partial_blk)
end
def summarize_in_chunks(llm, chunks, user, opts)
@ -107,7 +107,7 @@ module DiscourseAi
prompt = summarization_prompt(chunk[:summary], opts)
prompt[:post_insts] = "Don't use more than 400 words for the summary."
chunk[:summary] = llm.completion!(prompt, user)
chunk[:summary] = llm.generate(prompt, user: user)
chunk
end
end
@ -117,7 +117,7 @@ module DiscourseAi
prompt[:insts] = <<~TEXT
You are a summarization bot that effectively concatenates disjoint summaries, creating a cohesive narrative.
The narrative you create is in the form of one or multiple paragraphs.
Your reply MUST BE a single concatenated summary using the summaries I'll provide to you.
Your reply MUST BE a single concatenated summary using the summaries I'll provide to you.
I'm NOT interested in anything other than the concatenated summary, don't include additional text or comments.
You understand and generate Discourse forum Markdown.
You format the response, including links, using Markdown.
@ -131,7 +131,7 @@ module DiscourseAi
</input>
TEXT
llm.completion!(prompt, user, &on_partial_blk)
llm.generate(prompt, user: user, &on_partial_blk)
end
def summarization_prompt(input, opts)

View File

@ -74,7 +74,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Mixtral do
#{prompt[:post_insts]}
[/INST] Ok </s>
[INST] #{prompt[:examples][0][0]} [/INST]
#{prompt[:examples][0][1]}
#{prompt[:examples][0][1]}</s>
[INST] #{prompt[:input]} [/INST]
TEXT
@ -102,7 +102,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Mixtral do
</function_calls>
Here are the tools available:
<tools>
#{dialect.tools}</tools>
#{prompt[:post_insts]}

View File

@ -183,7 +183,7 @@ data: [D|ONE]
partials = []
llm = DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo")
llm.completion!({ insts: "test" }, Discourse.system_user) { |partial| partials << partial }
llm.generate({ insts: "test" }, user: Discourse.system_user) { |partial| partials << partial }
expect(partials.join).to eq("test,test2,test3,test4")
end
@ -212,7 +212,7 @@ data: [D|ONE]
partials = []
llm = DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo")
llm.completion!({ insts: "test" }, Discourse.system_user) { |partial| partials << partial }
llm.generate({ insts: "test" }, user: Discourse.system_user) { |partial| partials << partial }
expect(partials.join).to eq("test,test1,test2,test3,test4")
end

View File

@ -21,7 +21,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
end
end
describe "#completion!" do
describe "#generate" do
let(:prompt) do
{
insts: <<~TEXT,
@ -52,7 +52,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
context "when getting the full response" do
it "processes the prompt and return the response" do
llm_response = llm.completion!(prompt, user)
llm_response = llm.generate(prompt, user: user)
expect(llm_response).to eq(canned_response.responses[0])
end
@ -62,7 +62,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
it "processes the prompt and call the given block with the partial response" do
llm_response = +""
llm.completion!(prompt, user) { |partial, cancel_fn| llm_response << partial }
llm.generate(prompt, user: user) { |partial, cancel_fn| llm_response << partial }
expect(llm_response).to eq(canned_response.responses[0])
end

View File

@ -59,7 +59,7 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
it "returns a generic error when the completion call fails" do
DiscourseAi::Completions::Llm
.any_instance
.expects(:completion!)
.expects(:generate)
.raises(DiscourseAi::Completions::Endpoints::Base::CompletionFailed)
post "/discourse-ai/ai-helper/suggest", params: { mode: mode, text: text_to_proofread }