DEV: AI bot migration to the Llm pattern. (#343)

* DEV: AI bot migration to the Llm pattern.

We added tool and conversation-context support to the Llm service in discourse-ai#366, which means we now meet all the conditions to migrate this module.

This PR migrates the AI bot to the new pattern, so adding a new bot now requires minimal effort as long as the underlying service supports it. On top of this, we introduce the concept of a "Playground" to separate the PM-specific bits from the completion logic, which will let us use the bot in other contexts, like chat, in the future. Commands are now called tools, and all the placeholder logic was simplified to perform updates in a single place, making the flow closer to one-way.
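For context, a minimal sketch of the new flow, pieced together from the job and bot code in this diff (the `bot_user` and `post` variables are assumed to come from the caller):

    # Sketch only: a bot is a persona plus a bot user, and the Playground owns
    # the PM-specific plumbing (creating the reply post, streaming updates).
    persona = DiscourseAi::AiBot::Personas::General
    bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona.new)

    # Replying inside a PM; other contexts (e.g. chat) can get their own playground.
    DiscourseAi::AiBot::Playground.new(bot).reply_to(post)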

* Followup fixes based on testing

* Cleanup unused inference code

* FIX: text-based tools could be in the middle of a sentence

* GPT-4-turbo support

* Use new LLM API
Roman Rizzi committed 2024-01-04 10:44:07 -03:00 (via GitHub)
commit f9d7d7f5f0, parent 03fc94684b
103 changed files with 2641 additions and 5179 deletions


@@ -12,11 +12,11 @@ module DiscourseAi
           # localized for system personas
           LocalizedAiPersonaSerializer.new(persona, root: false)
         end
-      commands =
-        DiscourseAi::AiBot::Personas::Persona.all_available_commands.map do |command|
-          AiCommandSerializer.new(command, root: false)
+      tools =
+        DiscourseAi::AiBot::Personas::Persona.all_available_tools.map do |tool|
+          AiToolSerializer.new(tool, root: false)
         end
-      render json: { ai_personas: ai_personas, meta: { commands: commands } }
+      render json: { ai_personas: ai_personas, meta: { commands: tools } }
     end
     def show


@@ -8,17 +8,25 @@ module ::Jobs
       return unless bot_user = User.find_by(id: args[:bot_user_id])
       return unless post = Post.includes(:topic).find_by(id: args[:post_id])
-      kwargs = {}
-      kwargs[:user] = post.user
-      if persona_id = post.topic.custom_fields["ai_persona_id"]
-        kwargs[:persona_id] = persona_id.to_i
-      else
-        kwargs[:persona_name] = post.topic.custom_fields["ai_persona"]
-      end
       begin
-        bot = DiscourseAi::AiBot::Bot.as(bot_user, **kwargs)
-        bot.reply_to(post)
+        persona = nil
+        if persona_id = post.topic.custom_fields["ai_persona_id"]
+          persona =
+            DiscourseAi::AiBot::Personas::Persona.find_by(user: post.user, id: persona_id.to_i)
+          raise DiscourseAi::AiBot::Bot::BOT_NOT_FOUND if persona.nil?
+        end
+        if !persona && persona_name = post.topic.custom_fields["ai_persona"]
+          persona =
+            DiscourseAi::AiBot::Personas::Persona.find_by(user: post.user, name: persona_name)
+          raise DiscourseAi::AiBot::Bot::BOT_NOT_FOUND if persona.nil?
+        end
+        persona ||= DiscourseAi::AiBot::Personas::General
+        bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona.new)
+        DiscourseAi::AiBot::Playground.new(bot).reply_to(post)
       rescue DiscourseAi::AiBot::Bot::BOT_NOT_FOUND
         Rails.logger.warn(
           "Bot not found for post #{post.id} - perhaps persona was deleted or bot was disabled",


@@ -11,7 +11,7 @@ module ::Jobs
       return unless post.topic.custom_fields[DiscourseAi::AiBot::EntryPoint::REQUIRE_TITLE_UPDATE]
-      bot.update_pm_title(post)
+      DiscourseAi::AiBot::Playground.new(bot).title_playground(post)
     end
   end
 end


@@ -67,7 +67,7 @@ class AiPersona < ActiveRecord::Base
     id = self.id
     system = self.system
-    persona_class = DiscourseAi::AiBot::Personas.system_personas_by_id[self.id]
+    persona_class = DiscourseAi::AiBot::Personas::Persona.system_personas_by_id[self.id]
     if persona_class
       persona_class.define_singleton_method :allowed_group_ids do
         allowed_group_ids
@@ -90,8 +90,10 @@ class AiPersona < ActiveRecord::Base
     options = {}
-    commands =
-      self.commands.filter_map do |element|
+    tools = self.respond_to?(:commands) ? self.commands : self.tools
+    tools =
+      tools.filter_map do |element|
         inner_name = element
         current_options = nil
@@ -100,8 +102,12 @@ class AiPersona < ActiveRecord::Base
           current_options = element[1]
         end
+        # Won't migrate data yet. Let's rewrite to the tool name.
+        inner_name = inner_name.gsub("Command", "")
+        inner_name = "List#{inner_name}" if %w[Categories Tags].include?(inner_name)
         begin
-          klass = ("DiscourseAi::AiBot::Commands::#{inner_name}").constantize
+          klass = ("DiscourseAi::AiBot::Tools::#{inner_name}").constantize
           options[klass] = current_options if current_options
           klass
         rescue StandardError
@@ -143,8 +149,8 @@ class AiPersona < ActiveRecord::Base
       super(*args, **kwargs)
     end
-    define_method :commands do
-      commands
+    define_method :tools do
+      tools
     end
     define_method :options do
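The rewrite above keeps persisted command names working without a data migration; roughly, the mapping behaves like this (illustrative only, using command names from this plugin):

    %w[SearchCommand GoogleCommand CategoriesCommand TagsCommand].map do |name|
      name = name.gsub("Command", "")
      name = "List#{name}" if %w[Categories Tags].include?(name)
      "DiscourseAi::AiBot::Tools::#{name}"
    end
    # => ["DiscourseAi::AiBot::Tools::Search", "DiscourseAi::AiBot::Tools::Google",
    #     "DiscourseAi::AiBot::Tools::ListCategories", "DiscourseAi::AiBot::Tools::ListTags"]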


@@ -1,10 +1,10 @@
 # frozen_string_literal: true
-class AiCommandSerializer < ApplicationSerializer
+class AiToolSerializer < ApplicationSerializer
   attributes :options, :id, :name, :help
   def include_options?
-    object.options.present?
+    object.accepted_options.present?
   end
   def id
@@ -21,7 +21,7 @@ class AiCommandSerializer < ApplicationSerializer
   def options
     options = {}
-    object.options.each do |option|
+    object.accepted_options.each do |option|
       options[option.name] = {
         name: option.localized_name,
         description: option.localized_description,


@@ -132,6 +132,8 @@ en:
         attribution: "Image by Stable Diffusion XL"
     ai_bot:
+      placeholder_reply: "I will reply shortly..."
       personas:
         cannot_delete_system_persona: "System personas cannot be deleted, please disable it instead"
         cannot_edit_system_persona: "System personas can only be renamed, you may not edit commands or system prompt, instead disable and make a copy"
@@ -157,6 +159,7 @@ en:
           name: "DALL-E 3"
           description: "AI Bot specialized in generating images using DALL-E 3"
       topic_not_found: "Summary unavailable, topic not found!"
+      summarizing: "Summarizing topic"
       searching: "Searching for: '%{query}'"
       command_options:
         search:


@@ -1,6 +1,6 @@
 # frozen_string_literal: true
-DiscourseAi::AiBot::Personas.system_personas.each do |persona_class, id|
+DiscourseAi::AiBot::Personas::Persona.system_personas.each do |persona_class, id|
   persona = AiPersona.find_by(id: id)
   if !persona
     persona = AiPersona.new
@@ -32,7 +32,7 @@ DiscourseAi::AiBot::Personas.system_personas.each do |persona_class, id|
   persona.system = true
   instance = persona_class.new
-  persona.commands = instance.commands.map { |command| command.to_s.split("::").last }
+  persona.commands = instance.tools.map { |tool| tool.to_s.split("::").last }
   persona.system_prompt = instance.system_prompt
   persona.save!(validate: false)
 end


@@ -1,74 +0,0 @@
# frozen_string_literal: true
module DiscourseAi
module AiBot
class AnthropicBot < Bot
def self.can_reply_as?(bot_user)
bot_user.id == DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID
end
def bot_prompt_with_topic_context(post, allow_commands:)
super(post, allow_commands: allow_commands).join("\n\n") + "\n\nAssistant:"
end
def prompt_limit(allow_commands: true)
# no side channel for commands, so we can ignore allow commands
50_000 # https://console.anthropic.com/docs/prompt-design#what-is-a-prompt
end
def title_prompt(post)
super(post).join("\n\n") + "\n\nAssistant:"
end
def get_delta(partial, context)
completion = partial[:completion]
if completion&.start_with?(" ") && !context[:processed_first]
completion = completion[1..-1]
context[:processed_first] = true
end
completion
end
def tokenizer
DiscourseAi::Tokenizer::AnthropicTokenizer
end
private
def build_message(poster_username, content, system: false, function: nil)
role = poster_username == bot_user.username ? "Assistant" : "Human"
if system || function
content
else
"#{role}: #{content}"
end
end
def model_for
"claude-2"
end
def get_updated_title(prompt)
DiscourseAi::Inference::AnthropicCompletions.perform!(
prompt,
model_for,
temperature: 0.4,
max_tokens: 40,
).dig(:completion)
end
def submit_prompt(prompt, post: nil, prefer_low_cost: false, &blk)
DiscourseAi::Inference::AnthropicCompletions.perform!(
prompt,
model_for,
temperature: 0.4,
max_tokens: 3000,
post: post,
stop_sequences: ["\n\nHuman:", "</function_calls>"],
&blk
)
end
end
end
end


@@ -3,464 +3,152 @@
 module DiscourseAi
   module AiBot
     class Bot
-      class FunctionCalls
-        attr_accessor :maybe_buffer, :maybe_found, :custom
-        def initialize
-          @functions = []
-          @current_function = nil
-          @found = false
-          @cancel_completion = false
-          @maybe_buffer = +""
-          @maybe_found = false
-          @custom = false
-        end
-        def custom?
-          @custom
-        end
-        def found?
-          !@functions.empty? || @found
-        end
-        def found!
-          @found = true
-        end
-        def maybe_found?
-          @maybe_found
-        end
-        def cancel_completion?
-          @cancel_completion
-        end
-        def cancel_completion!
-          @cancel_completion = true
-        end
-        def add_function(name)
-          @current_function = { name: name, arguments: +"" }
-          @functions << @current_function
-        end
-        def add_argument_fragment(fragment)
-          @current_function[:arguments] << fragment
-        end
-        def length
-          @functions.length
-        end
-        def each
-          @functions.each { |function| yield function }
-        end
-        def to_a
-          @functions
-        end
-      end
-      attr_reader :bot_user, :persona
       BOT_NOT_FOUND = Class.new(StandardError)
       MAX_COMPLETIONS = 5
-      def self.as(bot_user, persona_id: nil, persona_name: nil, user: nil)
-        available_bots = [DiscourseAi::AiBot::OpenAiBot, DiscourseAi::AiBot::AnthropicBot]
-        bot =
-          available_bots.detect(-> { raise BOT_NOT_FOUND }) do |bot_klass|
-            bot_klass.can_reply_as?(bot_user)
-          end
-        persona = nil
-        if persona_id
-          persona = DiscourseAi::AiBot::Personas.find_by(user: user, id: persona_id)
-          raise BOT_NOT_FOUND if persona.nil?
-        end
-        if !persona && persona_name
-          persona = DiscourseAi::AiBot::Personas.find_by(user: user, name: persona_name)
-          raise BOT_NOT_FOUND if persona.nil?
-        end
-        bot.new(bot_user, persona: persona&.new)
-      end
-      def initialize(bot_user, persona: nil)
-        @bot_user = bot_user
-        @persona = persona || DiscourseAi::AiBot::Personas::General.new
-      end
-      def update_pm_title(post)
-        prompt = title_prompt(post)
-        new_title = get_updated_title(prompt).strip.split("\n").last
-        PostRevisor.new(post.topic.first_post, post.topic).revise!(
-          bot_user,
-          title: new_title.sub(/\A"/, "").sub(/"\Z/, ""),
-        )
-        post.topic.custom_fields.delete(DiscourseAi::AiBot::EntryPoint::REQUIRE_TITLE_UPDATE)
-        post.topic.save_custom_fields
-      end
-      def reply_to(
-        post,
-        total_completions: 0,
-        bot_reply_post: nil,
-        prefer_low_cost: false,
-        standalone: false
-      )
-        return if total_completions > MAX_COMPLETIONS
-        # do not allow commands when we are at the end of chain (total completions == MAX_COMPLETIONS)
-        allow_commands = (total_completions < MAX_COMPLETIONS)
-        prompt =
-          if standalone && post.post_custom_prompt
-            username, standalone_prompt = post.post_custom_prompt.custom_prompt.last
-            [build_message(username, standalone_prompt)]
-          else
-            bot_prompt_with_topic_context(post, allow_commands: allow_commands)
-          end
-        redis_stream_key = nil
-        partial_reply = +""
-        reply = +(bot_reply_post ? bot_reply_post.raw.dup : "")
-        start = Time.now
-        setup_cancel = false
-        context = {}
-        functions = FunctionCalls.new
-        submit_prompt(prompt, post: post, prefer_low_cost: prefer_low_cost) do |partial, cancel|
-          current_delta = get_delta(partial, context)
-          partial_reply << current_delta
-          if !available_functions.empty?
-            populate_functions(
-              partial: partial,
-              reply: partial_reply,
-              functions: functions,
-              current_delta: current_delta,
-              done: false,
-            )
-            cancel&.call if functions.cancel_completion?
-          end
-          if functions.maybe_buffer.present? && !functions.maybe_found?
-            reply << functions.maybe_buffer
-            functions.maybe_buffer = +""
-          end
-          reply << current_delta if !functions.found? && !functions.maybe_found?
-          if redis_stream_key && !Discourse.redis.get(redis_stream_key)
-            cancel&.call
-            bot_reply_post.update!(raw: reply, cooked: PrettyText.cook(reply)) if bot_reply_post
-          end
-          # Minor hack to skip the delay during tests.
-          next if (Time.now - start < 0.5) && !Rails.env.test?
-          if bot_reply_post
-            Discourse.redis.expire(redis_stream_key, 60)
-            start = Time.now
-            publish_update(bot_reply_post, raw: reply.dup)
-          else
-            bot_reply_post =
-              PostCreator.create!(
-                bot_user,
-                topic_id: post.topic_id,
-                raw: reply,
-                skip_validations: true,
-              )
-          end
-          if !setup_cancel && bot_reply_post
-            redis_stream_key = "gpt_cancel:#{bot_reply_post.id}"
-            Discourse.redis.setex(redis_stream_key, 60, 1)
-            setup_cancel = true
-          end
-        end
-        if !available_functions.empty?
-          populate_functions(
-            partial: nil,
-            reply: partial_reply,
-            current_delta: "",
-            functions: functions,
-            done: true,
-          )
-        end
-        if functions.maybe_buffer.present?
-          reply << functions.maybe_buffer
-          functions.maybe_buffer = +""
-        end
-        if bot_reply_post
-          publish_update(bot_reply_post, done: true)
-          bot_reply_post.revise(
-            bot_user,
-            { raw: reply },
-            skip_validations: true,
-            skip_revision: true,
-          )
-          bot_reply_post.post_custom_prompt ||= post.build_post_custom_prompt(custom_prompt: [])
-          prompt = post.post_custom_prompt.custom_prompt || []
-          truncated_reply = partial_reply
-          # TODO: we may want to move this code
-          if functions.length > 0 && partial_reply.include?("</invoke>")
-            # recover stop word potentially
-            truncated_reply =
-              partial_reply.split("</invoke>").first + "</invoke>\n</function_calls>"
-          end
-          prompt << [truncated_reply, bot_user.username] if truncated_reply.present?
-          post.post_custom_prompt.update!(custom_prompt: prompt)
-        end
-        if functions.length > 0
-          chain = false
-          standalone = false
-          functions.each do |function|
-            name, args = function[:name], function[:arguments]
-            if command_klass = available_commands.detect { |cmd| cmd.invoked?(name) }
-              command =
-                command_klass.new(
-                  bot: self,
-                  args: args,
-                  post: bot_reply_post,
-                  parent_post: post,
-                  xml_format: !functions.custom?,
-                )
-              chain_intermediate, bot_reply_post = command.invoke!
-              chain ||= chain_intermediate
-              standalone ||= command.standalone?
-            end
-          end
-          if chain
-            reply_to(
-              bot_reply_post,
-              total_completions: total_completions + 1,
-              bot_reply_post: bot_reply_post,
-              standalone: standalone,
-            )
-          end
-        end
-      rescue => e
-        if Rails.env.development?
-          p e
-          puts e.backtrace
-        end
-        raise e if Rails.env.test?
-        Discourse.warn_exception(e, message: "ai-bot: Reply failed")
-      end
-      def extra_tokens_per_message
-        0
-      end
-      def bot_prompt_with_topic_context(post, allow_commands:)
-        messages = []
-        conversation = conversation_context(post)
-        rendered_system_prompt = system_prompt(post, allow_commands: allow_commands)
-        total_prompt_tokens = tokenize(rendered_system_prompt).length + extra_tokens_per_message
-        prompt_limit = self.prompt_limit(allow_commands: allow_commands)
-        conversation.each do |raw, username, function|
-          break if total_prompt_tokens >= prompt_limit
-          tokens = tokenize(raw.to_s + username.to_s)
-          while !raw.blank? &&
-                  tokens.length + total_prompt_tokens + extra_tokens_per_message > prompt_limit
-            raw = raw[0..-100] || ""
-            tokens = tokenize(raw.to_s + username.to_s)
-          end
-          next if raw.blank?
-          total_prompt_tokens += tokens.length + extra_tokens_per_message
-          messages.unshift(build_message(username, raw, function: !!function))
-        end
-        messages.unshift(build_message(bot_user.username, rendered_system_prompt, system: true))
-        messages
-      end
-      def prompt_limit(allow_commands: false)
-        raise NotImplemented
-      end
-      def title_prompt(post)
-        prompt = <<~TEXT
-          You are titlebot. Given a topic you will figure out a title.
-          You will never respond with anything but a 7 word topic title.
-        TEXT
-        messages = [build_message(bot_user.username, prompt, system: true)]
-        messages << build_message("User", <<~TEXT)
-          Suggest a 7 word title for the following topic without quoting any of it:
-          <content>
-            #{post.topic.posts.map(&:raw).join("\n\n")[0..prompt_limit(allow_commands: false)]}
-          </content>
-        TEXT
-        messages
-      end
-      def available_commands
-        @persona.available_commands
-      end
-      def system_prompt_style!(style)
-        @style = style
-      end
-      def system_prompt(post, allow_commands:)
-        return "You are a helpful Bot" if @style == :simple
-        @persona.render_system_prompt(
-          topic: post.topic,
-          allow_commands: allow_commands,
-          render_function_instructions:
-            allow_commands && include_function_instructions_in_system_prompt?,
-        )
-      end
-      def include_function_instructions_in_system_prompt?
-        true
-      end
-      def function_list
-        @persona.function_list
-      end
-      def tokenizer
-        raise NotImplemented
-      end
-      def tokenize(text)
-        tokenizer.tokenize(text)
-      end
-      def submit_prompt(
-        prompt,
-        post:,
-        prefer_low_cost: false,
-        temperature: nil,
-        max_tokens: nil,
-        &blk
-      )
-        raise NotImplemented
-      end
-      def get_delta(partial, context)
-        raise NotImplemented
-      end
-      def populate_functions(partial:, reply:, functions:, done:, current_delta:)
-        if !done
-          search_length = "<function_calls>".length
-          index = -1
-          while index > -search_length
-            substr = reply[index..-1] || reply
-            index -= 1
-            functions.maybe_found = "<function_calls>".start_with?(substr)
-            break if functions.maybe_found?
-          end
-          functions.maybe_buffer << current_delta if functions.maybe_found?
-          functions.found! if reply.match?(/^<function_calls>/i)
-          if functions.found?
-            functions.maybe_buffer = functions.maybe_buffer.to_s.split("<")[0..-2].join("<")
-            functions.cancel_completion! if reply.match?(%r{</function_calls>}i)
-          end
-        else
-          functions_string = reply.scan(%r{(<function_calls>(.*?)</invoke>)}im)&.first&.first
-          if functions_string
-            function_list
-              .parse_prompt(functions_string + "</function_calls>")
-              .each do |function|
-                functions.add_function(function[:name])
-                functions.add_argument_fragment(function[:arguments].to_json)
-              end
-          end
-        end
-      end
-      def available_functions
-        @persona.available_functions
-      end
-      protected
-      def get_updated_title(prompt)
-        raise NotImplemented
-      end
-      def model_for(bot)
-        raise NotImplemented
-      end
-      def conversation_context(post)
-        context =
-          post
-            .topic
-            .posts
-            .includes(:user)
-            .joins("LEFT JOIN post_custom_prompts ON post_custom_prompts.post_id = posts.id")
-            .where("post_number <= ?", post.post_number)
-            .order("post_number desc")
-            .where("post_type = ?", Post.types[:regular])
-            .limit(50)
-            .pluck(:raw, :username, "post_custom_prompts.custom_prompt")
-        result = []
-        first = true
-        context.each do |raw, username, custom_prompt|
-          if custom_prompt.present?
-            if first
-              custom_prompt.reverse_each { |message| result << message }
-              first = false
-            else
-              result << custom_prompt.first
-            end
-          else
-            result << [raw, username]
-          end
-        end
-        result
-      end
-      def publish_update(bot_reply_post, payload)
-        MessageBus.publish(
-          "discourse-ai/ai-bot/topic/#{bot_reply_post.topic_id}",
-          payload.merge(post_id: bot_reply_post.id, post_number: bot_reply_post.post_number),
-          user_ids: bot_reply_post.topic.allowed_user_ids,
-        )
-      end
+      def self.as(bot_user, persona: DiscourseAi::AiBot::Personas::General.new)
+        new(bot_user, persona)
+      end
+      def initialize(bot_user, persona)
+        @bot_user = bot_user
+        @persona = persona
+      end
+      attr_reader :bot_user
+      def get_updated_title(conversation_context, post_user)
+        title_prompt = { insts: <<~TEXT, conversation_context: conversation_context }
+          You are titlebot. Given a topic, you will figure out a title.
+          You will never respond with anything but 7 word topic title.
+        TEXT
+        title_prompt[
+          :input
+        ] = "Based on our previous conversation, suggest a 7 word title without quoting any of it."
+        DiscourseAi::Completions::Llm
+          .proxy(model)
+          .generate(title_prompt, user: post_user)
+          .strip
+          .split("\n")
+          .last
+      end
+      def reply(context, &update_blk)
+        prompt = persona.craft_prompt(context)
+        total_completions = 0
+        ongoing_chain = true
+        low_cost = false
+        raw_context = []
+        while total_completions <= MAX_COMPLETIONS && ongoing_chain
+          current_model = model(prefer_low_cost: low_cost)
+          llm = DiscourseAi::Completions::Llm.proxy(current_model)
+          tool_found = false
+          llm.generate(prompt, user: context[:user]) do |partial, cancel|
+            if (tool = persona.find_tool(partial))
+              tool_found = true
+              ongoing_chain = tool.chain_next_response?
+              low_cost = tool.low_cost?
+              tool_call_id = tool.tool_call_id
+              invocation_result_json = invoke_tool(tool, llm, cancel, &update_blk).to_json
+              invocation_context = {
+                type: "tool",
+                name: tool_call_id,
+                content: invocation_result_json,
+              }
+              tool_context = {
+                type: "tool_call",
+                name: tool_call_id,
+                content: { name: tool.name, arguments: tool.parameters }.to_json,
+              }
+              prompt[:conversation_context] ||= []
+              if tool.standalone?
+                prompt[:conversation_context] = [invocation_context, tool_context]
+              else
+                prompt[:conversation_context] = [invocation_context, tool_context] +
+                  prompt[:conversation_context]
+              end
+              raw_context << [tool_context[:content], tool_call_id, "tool_call"]
+              raw_context << [invocation_result_json, tool_call_id, "tool"]
+            else
+              update_blk.call(partial, cancel, nil)
+            end
+          end
+          ongoing_chain = false if !tool_found
+          total_completions += 1
+          # do not allow tools when we are at the end of a chain (total_completions == MAX_COMPLETIONS)
+          prompt.delete(:tools) if total_completions == MAX_COMPLETIONS
+        end
+        raw_context
+      end
+      private
+      attr_reader :persona
+      def invoke_tool(tool, llm, cancel, &update_blk)
+        update_blk.call("", cancel, build_placeholder(tool.summary, ""))
+        result =
+          tool.invoke(bot_user, llm) do |progress|
+            placeholder = build_placeholder(tool.summary, progress)
+            update_blk.call("", cancel, placeholder)
+          end
+        tool_details = build_placeholder(tool.summary, tool.details, custom_raw: tool.custom_raw)
+        update_blk.call(tool_details, cancel, nil)
+        result
+      end
+      def model(prefer_low_cost: false)
+        default_model =
+          case bot_user.id
+          when DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID
+            "claude-2"
+          when DiscourseAi::AiBot::EntryPoint::GPT4_ID
+            "gpt-4"
+          when DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID
+            "gpt-4-turbo"
+          when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
+            "gpt-3.5-turbo-16k"
+          else
+            nil
+          end
+        if %w[gpt-4 gpt-4-turbo].include?(default_model) && prefer_low_cost
+          return "gpt-3.5-turbo-16k"
+        end
+        default_model
+      end
+      def tool_invocation?(partial)
+        Nokogiri::HTML5.fragment(partial).at("invoke").present?
+      end
+      def build_placeholder(summary, details, custom_raw: nil)
+        placeholder = +(<<~HTML)
+          <details>
+            <summary>#{summary}</summary>
+            <p>#{details}</p>
+          </details>
+        HTML
+        placeholder << custom_raw << "\n" if custom_raw
+        placeholder
+      end
     end
   end
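The update block passed to Bot#reply is the one-way pipe the commit message mentions: it receives (partial, cancel, placeholder) — a plain text delta while the model is writing, or a placeholder <details> block while a tool runs. A hypothetical consumer, mirroring the update_blk calls above (the render helpers are stand-ins, not plugin API):

    raw_context =
      bot.reply(context) do |partial, cancel, placeholder|
        if placeholder
          show_tool_status(placeholder) # stand-in: render the tool's <details> block
        else
          append_to_reply(partial) # stand-in: stream the text delta into the post
        end
      end
    # raw_context accumulates [content, tool_call_id, type] tuples for tool calls.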


@@ -1,45 +0,0 @@
#frozen_string_literal: true
module DiscourseAi::AiBot::Commands
class CategoriesCommand < Command
class << self
def name
"categories"
end
def desc
"Will list the categories on the current discourse instance, prefer to format with # in front of the category name"
end
def parameters
[]
end
end
def result_name
"results"
end
def description_args
{ count: @last_count || 0 }
end
def process
columns = {
name: "Name",
slug: "Slug",
description: "Description",
posts_year: "Posts Year",
posts_month: "Posts Month",
posts_week: "Posts Week",
id: "id",
parent_category_id: "parent_category_id",
}
rows = Category.where(read_restricted: false).limit(100).pluck(*columns.keys)
@last_count = rows.length
format_results(rows, columns.values)
end
end
end


@@ -1,237 +0,0 @@
#frozen_string_literal: true
module DiscourseAi
module AiBot
module Commands
class Command
CARET = "<!-- caret -->"
PROGRESS_CARET = "<!-- progress -->"
class << self
def name
raise NotImplemented
end
def invoked?(cmd_name)
cmd_name == name
end
def desc
raise NotImplemented
end
def custom_system_message
end
def parameters
raise NotImplemented
end
def options
[]
end
def help
I18n.t("discourse_ai.ai_bot.command_help.#{name}")
end
def option(name, type:)
Option.new(command: self, name: name, type: type)
end
end
attr_reader :bot_user, :bot
def initialize(bot:, args:, post: nil, parent_post: nil, xml_format: false)
@bot = bot
@bot_user = bot&.bot_user
@args = args
@post = post
@parent_post = parent_post
@xml_format = xml_format
@placeholder = +(<<~HTML).strip
<details>
<summary>#{I18n.t("discourse_ai.ai_bot.command_summary.#{self.class.name}")}</summary>
<p>
#{CARET}
</p>
</details>
#{PROGRESS_CARET}
HTML
@invoked = false
end
def persona_options
return @persona_options if @persona_options
@persona_options = HashWithIndifferentAccess.new
# during tests we may operate without a bot
return @persona_options if !self.bot
self.class.options.each do |option|
val = self.bot.persona.options.dig(self.class, option.name)
@persona_options[option.name] = val if val
end
@persona_options
end
def tokenizer
bot.tokenizer
end
def standalone?
false
end
def low_cost?
false
end
def result_name
raise NotImplemented
end
def name
raise NotImplemented
end
def process(post)
raise NotImplemented
end
def description_args
{}
end
def custom_raw
end
def chain_next_response
true
end
def show_progress(text, progress_caret: false)
return if !@post
return if !@placeholder
# during tests we may have none
caret = progress_caret ? PROGRESS_CARET : CARET
new_placeholder = @placeholder.sub(caret, text + caret)
raw = @post.raw.sub(@placeholder, new_placeholder)
@placeholder = new_placeholder
@post.revise(bot_user, { raw: raw }, skip_validations: true, skip_revision: true)
end
def localized_description
I18n.t(
"discourse_ai.ai_bot.command_description.#{self.class.name}",
self.description_args,
)
end
def invoke!
raise StandardError.new("Command can only be invoked once!") if @invoked
@invoked = true
if !@post
@post =
PostCreator.create!(
bot_user,
raw: @placeholder,
topic_id: @parent_post.topic_id,
skip_validations: true,
skip_rate_limiter: true,
)
else
@post.revise(
bot_user,
{ raw: @post.raw + "\n\n" + @placeholder + "\n\n" },
skip_validations: true,
skip_revision: true,
)
end
@post.post_custom_prompt ||= @post.build_post_custom_prompt(custom_prompt: [])
prompt = @post.post_custom_prompt.custom_prompt || []
parsed_args = JSON.parse(@args).symbolize_keys
function_results = process(**parsed_args).to_json
function_results = <<~XML if @xml_format
<function_results>
<result>
<tool_name>#{self.class.name}</tool_name>
<json>
#{function_results}
</json>
</result>
</function_results>
XML
prompt << [function_results, self.class.name, "function"]
@post.post_custom_prompt.update!(custom_prompt: prompt)
raw = +(<<~HTML)
<details>
<summary>#{I18n.t("discourse_ai.ai_bot.command_summary.#{self.class.name}")}</summary>
<p>
#{localized_description}
</p>
</details>
HTML
raw << custom_raw if custom_raw.present?
raw = @post.raw.sub(@placeholder, raw)
@post.revise(bot_user, { raw: raw }, skip_validations: true, skip_revision: true)
if chain_next_response
# somewhat annoying but whitespace was stripped in revise
# so we need to save again
@post.raw = raw
@post.save!(validate: false)
end
[chain_next_response, @post]
end
def format_results(rows, column_names = nil, args: nil)
rows = rows&.map { |row| yield row } if block_given?
if !column_names
index = -1
column_indexes = {}
rows =
rows&.map do |data|
new_row = []
data.each do |key, value|
found_index = column_indexes[key.to_s] ||= (index += 1)
new_row[found_index] = value
end
new_row
end
column_names = column_indexes.keys
end
# this is not the most efficient format
# however this is needed cause GPT 3.5 / 4 was steered using JSON
result = { column_names: column_names, rows: rows }
result[:args] = args if args
result
end
protected
attr_reader :bot_user, :args
end
end
end
end


@@ -1,122 +0,0 @@
#frozen_string_literal: true
module DiscourseAi::AiBot::Commands
class DallECommand < Command
class << self
def name
"dall_e"
end
def desc
"Renders images from supplied descriptions"
end
def parameters
[
Parameter.new(
name: "prompts",
description:
"The prompts used to generate or create or draw the image (5000 chars or less, be creative) up to 4 prompts",
type: "array",
item_type: "string",
required: true,
),
]
end
end
def result_name
"results"
end
def description_args
{ prompt: @last_prompt }
end
def chain_next_response
false
end
def custom_raw
@custom_raw
end
def process(prompts:)
# max 4 prompts
prompts = prompts.take(4)
@last_prompt = prompts[0]
show_progress(localized_description)
results = nil
# this ensures multisite safety since background threads
# generate the images
api_key = SiteSetting.ai_openai_api_key
api_url = SiteSetting.ai_openai_dall_e_3_url
threads = []
prompts.each_with_index do |prompt, index|
threads << Thread.new(prompt) do |inner_prompt|
attempts = 0
begin
DiscourseAi::Inference::OpenAiImageGenerator.perform!(
inner_prompt,
api_key: api_key,
api_url: api_url,
)
rescue => e
attempts += 1
sleep 2
retry if attempts < 3
Discourse.warn_exception(e, message: "Failed to generate image for prompt #{prompt}")
nil
end
end
end
while true
show_progress(".", progress_caret: true)
break if threads.all? { |t| t.join(2) }
end
results = threads.filter_map(&:value)
if results.blank?
return { prompts: prompts, error: "Something went wrong, could not generate image" }
end
uploads = []
results.each_with_index do |result, index|
result[:data].each do |image|
Tempfile.create("v1_txt2img_#{index}.png") do |file|
file.binmode
file.write(Base64.decode64(image[:b64_json]))
file.rewind
uploads << {
prompt: image[:revised_prompt],
upload: UploadCreator.new(file, "image.png").create_for(bot_user.id),
}
end
end
end
@custom_raw = <<~RAW
[grid]
#{
uploads
.map do |item|
"![#{item[:prompt].gsub(/\|\'\"/, "")}|512x512, 50%](#{item[:upload].short_url})"
end
.join(" ")
}
[/grid]
RAW
{ prompts: uploads.map { |item| item[:prompt] } }
end
end
end


@@ -1,54 +0,0 @@
#frozen_string_literal: true
module DiscourseAi::AiBot::Commands
class DbSchemaCommand < Command
class << self
def name
"schema"
end
def desc
"Will load schema information for specific tables in the database"
end
def parameters
[
Parameter.new(
name: "tables",
description:
"list of tables to load schema information for, comma seperated list eg: (users,posts))",
type: "string",
required: true,
),
]
end
end
def result_name
"results"
end
def description_args
{ tables: @tables.join(", ") }
end
def process(tables:)
@tables = tables.split(",").map(&:strip)
table_info = {}
DB
.query(<<~SQL, @tables)
select table_name, column_name, data_type from information_schema.columns
where table_schema = 'public'
and table_name in (?)
order by table_name
SQL
.each { |row| (table_info[row.table_name] ||= []) << "#{row.column_name} #{row.data_type}" }
schema_info =
table_info.map { |table_name, columns| "#{table_name}(#{columns.join(",")})" }.join("\n")
{ schema_info: schema_info, tables: tables }
end
end
end


@@ -1,82 +0,0 @@
#frozen_string_literal: true
module DiscourseAi::AiBot::Commands
class GoogleCommand < Command
class << self
def name
"google"
end
def desc
"Will search using Google - global internet search (supports all Google search operators)"
end
def parameters
[
Parameter.new(
name: "query",
description: "The search query",
type: "string",
required: true,
),
]
end
def custom_system_message
"You were trained on OLD data, lean on search to get up to date information from the web"
end
end
def result_name
"results"
end
def description_args
{
count: @last_num_results || 0,
query: @last_query || "",
url: "https://google.com/search?q=#{CGI.escape(@last_query || "")}",
}
end
def process(query:)
@last_query = query
show_progress(localized_description)
api_key = SiteSetting.ai_google_custom_search_api_key
cx = SiteSetting.ai_google_custom_search_cx
query = CGI.escape(query)
uri =
URI("https://www.googleapis.com/customsearch/v1?key=#{api_key}&cx=#{cx}&q=#{query}&num=10")
body = Net::HTTP.get(uri)
parse_search_json(body, query)
end
def minimize_field(result, field, max_tokens: 100)
data = result[field]
return "" if data.blank?
data = ::DiscourseAi::Tokenizer::BertTokenizer.truncate(data, max_tokens).squish
data
end
def parse_search_json(json_data, query)
parsed = JSON.parse(json_data)
results = parsed["items"]
@last_num_results = parsed.dig("searchInformation", "totalResults").to_i
format_results(results, args: query) do |result|
{
title: minimize_field(result, "title"),
link: minimize_field(result, "link"),
snippet: minimize_field(result, "snippet", max_tokens: 120),
displayLink: minimize_field(result, "displayLink"),
formattedUrl: minimize_field(result, "formattedUrl"),
}
end
end
end
end


@@ -1,135 +0,0 @@
#frozen_string_literal: true
module DiscourseAi::AiBot::Commands
class ImageCommand < Command
class << self
def name
"image"
end
def desc
"Renders an image from the description (remove all connector words, keep it to 40 words or less). Despite being a text based bot you can generate images! (when user asks to draw, paint or other synonyms try this)"
end
def parameters
[
Parameter.new(
name: "prompts",
description:
"The prompts used to generate or create or draw the image (40 words or less, be creative) up to 4 prompts",
type: "array",
item_type: "string",
required: true,
),
Parameter.new(
name: "seeds",
description:
"The seed used to generate the image (optional) - can be used to retain image style on amended prompts",
type: "array",
item_type: "integer",
),
]
end
end
def result_name
"results"
end
def description_args
{ prompt: @last_prompt }
end
def chain_next_response
false
end
def custom_raw
@custom_raw
end
def process(prompts:, seeds: nil)
# max 4 prompts
prompts = prompts[0..3]
seeds = seeds[0..3] if seeds
@last_prompt = prompts[0]
show_progress(localized_description)
results = nil
# this ensures multisite safety since background threads
# generate the images
api_key = SiteSetting.ai_stability_api_key
engine = SiteSetting.ai_stability_engine
api_url = SiteSetting.ai_stability_api_url
threads = []
prompts.each_with_index do |prompt, index|
seed = seeds ? seeds[index] : nil
threads << Thread.new(seed, prompt) do |inner_seed, inner_prompt|
attempts = 0
begin
DiscourseAi::Inference::StabilityGenerator.perform!(
inner_prompt,
engine: engine,
api_key: api_key,
api_url: api_url,
image_count: 1,
seed: inner_seed,
)
rescue => e
attempts += 1
retry if attempts < 3
Rails.logger.warn("Failed to generate image for prompt #{prompt}: #{e}")
nil
end
end
end
while true
show_progress(".", progress_caret: true)
break if threads.all? { |t| t.join(2) }
end
results = threads.map(&:value).compact
if !results.present?
return { prompts: prompts, error: "Something went wrong, could not generate image" }
end
uploads = []
results.each_with_index do |result, index|
result[:artifacts].each do |image|
Tempfile.create("v1_txt2img_#{index}.png") do |file|
file.binmode
file.write(Base64.decode64(image[:base64]))
file.rewind
uploads << {
prompt: prompts[index],
upload: UploadCreator.new(file, "image.png").create_for(bot_user.id),
seed: image[:seed],
}
end
end
end
@custom_raw = <<~RAW
[grid]
#{
uploads
.map do |item|
"![#{item[:prompt].gsub(/\|\'\"/, "")}|512x512, 50%](#{item[:upload].short_url})"
end
.join(" ")
}
[/grid]
RAW
{ prompts: uploads.map { |item| item[:prompt] }, seeds: uploads.map { |item| item[:seed] } }
end
end
end


@@ -1,23 +0,0 @@
# frozen_string_literal: true
module DiscourseAi
module AiBot
module Commands
class Option
attr_reader :command, :name, :type
def initialize(command:, name:, type:)
@command = command
@name = name.to_s
@type = type
end
def localized_name
I18n.t("discourse_ai.ai_bot.command_options.#{command.name}.#{name}.name")
end
def localized_description
I18n.t("discourse_ai.ai_bot.command_options.#{command.name}.#{name}.description")
end
end
end
end
end


@@ -1,18 +0,0 @@
# frozen_string_literal: true
module DiscourseAi
module AiBot
module Commands
class Parameter
attr_reader :item_type, :name, :description, :type, :enum, :required
def initialize(name:, description:, type:, enum: nil, required: false, item_type: nil)
@name = name
@description = description
@type = type
@enum = enum
@required = required
@item_type = item_type
end
end
end
end
end


@@ -1,77 +0,0 @@
#frozen_string_literal: true
module DiscourseAi::AiBot::Commands
class ReadCommand < Command
class << self
def name
"read"
end
def desc
"Will read a topic or a post on this Discourse instance"
end
def parameters
[
Parameter.new(
name: "topic_id",
description: "the id of the topic to read",
type: "integer",
required: true,
),
Parameter.new(
name: "post_number",
description: "the post number to read",
type: "integer",
required: false,
),
]
end
end
def description_args
{ title: @title, url: @url }
end
def process(topic_id:, post_number: nil)
not_found = { topic_id: topic_id, description: "Topic not found" }
@title = ""
topic_id = topic_id.to_i
topic = Topic.find_by(id: topic_id)
return not_found if !topic || !Guardian.new.can_see?(topic)
@title = topic.title
posts = Post.secured(Guardian.new).where(topic_id: topic_id).order(:post_number).limit(40)
@url = topic.relative_url(post_number)
posts = posts.where("post_number = ?", post_number) if post_number
content = +<<~TEXT.strip
title: #{topic.title}
TEXT
category_names = [topic.category&.parent_category&.name, topic.category&.name].compact.join(
" ",
)
content << "\ncategories: #{category_names}" if category_names.present?
if topic.tags.length > 0
tags = DiscourseTagging.filter_visible(topic.tags, Guardian.new)
content << "\ntags: #{tags.map(&:name).join(", ")}\n\n" if tags.length > 0
end
posts.each { |post| content << "\n\n#{post.username} said:\n\n#{post.raw}" }
# TODO: 16k or 100k models can handle a lot more tokens
content = tokenizer.truncate(content, 1500).squish
result = { topic_id: topic_id, content: content, complete: true }
result[:post_number] = post_number if post_number
result
end
end
end


@@ -1,232 +0,0 @@
#frozen_string_literal: true
module DiscourseAi::AiBot::Commands
class SearchCommand < Command
class << self
def name
"search"
end
def desc
"Will search topics in the current discourse instance, when rendering always prefer to link to the topics you find"
end
def options
[option(:base_query, type: :string), option(:max_results, type: :integer)]
end
def parameters
[
Parameter.new(
name: "search_query",
description:
"Specific keywords to search for, space seperated (correct bad spelling, remove connector words)",
type: "string",
),
Parameter.new(
name: "user",
description:
"Filter search results to this username (only include if user explicitly asks to filter by user)",
type: "string",
),
Parameter.new(
name: "order",
description: "search result order",
type: "string",
enum: %w[latest latest_topic oldest views likes],
),
Parameter.new(
name: "limit",
description:
"Number of results to return. Defaults to maximum number of results. Only set if absolutely necessary",
type: "integer",
),
Parameter.new(
name: "max_posts",
description:
"maximum number of posts on the topics (topics where lots of people posted)",
type: "integer",
),
Parameter.new(
name: "tags",
description:
"list of tags to search for. Use + to join with OR, use , to join with AND",
type: "string",
),
Parameter.new(
name: "category",
description: "category name to filter to",
type: "string",
),
Parameter.new(
name: "before",
description: "only topics created before a specific date YYYY-MM-DD",
type: "string",
),
Parameter.new(
name: "after",
description: "only topics created after a specific date YYYY-MM-DD",
type: "string",
),
Parameter.new(
name: "status",
description: "search for topics in a particular state",
type: "string",
enum: %w[open closed archived noreplies single_user],
),
]
end
def custom_system_message
<<~TEXT
You were trained on OLD data, lean on search to get up to date information about this forum
When searching try to SIMPLIFY search terms
Discourse search joins all terms with AND. Reduce and simplify terms to find more results.
TEXT
end
end
def result_name
"results"
end
def description_args
{
count: @last_num_results || 0,
query: @last_query || "",
url: "#{Discourse.base_path}/search?q=#{CGI.escape(@last_query || "")}",
}
end
MIN_SEMANTIC_RESULTS = 5
def max_semantic_results
max_results / 4
end
def max_results
return 20 if !bot
max_results = persona_options[:max_results].to_i
return [max_results, 100].min if max_results > 0
if bot.prompt_limit(allow_commands: false) > 30_000
60
elsif bot.prompt_limit(allow_commands: false) > 10_000
40
else
20
end
end
def process(**search_args)
limit = nil
search_string =
search_args
.map do |key, value|
if key == :search_query
value
elsif key == :limit
limit = value.to_i
nil
else
"#{key}:#{value}"
end
end
.compact
.join(" ")
@last_query = search_string
show_progress(I18n.t("discourse_ai.ai_bot.searching", query: search_string))
if persona_options[:base_query].present?
search_string = "#{search_string} #{persona_options[:base_query]}"
end
results =
Search.execute(
search_string.to_s + " status:public",
search_type: :full_page,
guardian: Guardian.new(),
)
# let's be frugal with tokens, 50 results is too much and stuff gets cut off
limit ||= max_results
limit = max_results if limit > max_results
should_try_semantic_search = SiteSetting.ai_embeddings_semantic_search_enabled
should_try_semantic_search &&= (limit == max_results)
should_try_semantic_search &&= (search_args[:search_query].present?)
limit = limit - max_semantic_results if should_try_semantic_search
posts = results&.posts || []
posts = posts[0..limit - 1]
if should_try_semantic_search
semantic_search = DiscourseAi::Embeddings::SemanticSearch.new(Guardian.new())
topic_ids = Set.new(posts.map(&:topic_id))
search = Search.new(search_string, guardian: Guardian.new)
results = nil
begin
results = semantic_search.search_for_topics(search.term)
rescue => e
Discourse.warn_exception(e, message: "Semantic search failed")
end
if results
results = search.apply_filters(results)
results.each do |post|
next if topic_ids.include?(post.topic_id)
topic_ids << post.topic_id
posts << post
break if posts.length >= max_results
end
end
end
@last_num_results = posts.length
# this is the general pattern from core
# if there are millions of hidden tags it may fail
hidden_tags = nil
if posts.blank?
{ args: search_args, rows: [], instruction: "nothing was found, expand your search" }
else
format_results(posts, args: search_args) do |post|
category_names = [
post.topic.category&.parent_category&.name,
post.topic.category&.name,
].compact.join(" > ")
row = {
title: post.topic.title,
url: Discourse.base_path + post.url,
username: post.user&.username,
excerpt: post.excerpt,
created: post.created_at,
category: category_names,
likes: post.like_count,
topic_views: post.topic.views,
topic_likes: post.topic.like_count,
topic_replies: post.topic.posts_count - 1,
}
if SiteSetting.tagging_enabled
hidden_tags ||= DiscourseTagging.hidden_tag_names
# using map over pluck to avoid n+1 (assuming caller preloading)
tags = post.topic.tags.map(&:name) - hidden_tags
row[:tags] = tags.join(", ") if tags.present?
end
row
end
end
end
end
end


@@ -1,85 +0,0 @@
#frozen_string_literal: true
module DiscourseAi::AiBot::Commands
class SearchSettingsCommand < Command
class << self
def name
"search_settings"
end
def desc
"Will search through site settings and return top 20 results"
end
def parameters
[
Parameter.new(
name: "query",
description:
"comma delimited list of settings to search for (e.g. 'setting_1,setting_2')",
type: "string",
required: true,
),
]
end
end
def result_name
"results"
end
def description_args
{ count: @last_num_results || 0, query: @last_query || "" }
end
INCLUDE_DESCRIPTIONS_MAX_LENGTH = 10
MAX_RESULTS = 200
def process(query:)
@last_query = query
@last_num_results = 0
terms = query.split(",").map(&:strip).map(&:downcase).reject(&:blank?)
found =
SiteSetting.all_settings.filter do |setting|
name = setting[:setting].to_s.downcase
description = setting[:description].to_s.downcase
plugin = setting[:plugin].to_s.downcase
search_string = "#{name} #{description} #{plugin}"
terms.any? { |term| search_string.include?(term) }
end
if found.blank?
{
args: {
query: query,
},
rows: [],
instruction: "no settings matched #{query}, expand your search",
}
else
include_descriptions = false
if found.length > MAX_RESULTS
found = found[0..MAX_RESULTS]
elsif found.length < INCLUDE_DESCRIPTIONS_MAX_LENGTH
include_descriptions = true
end
@last_num_results = found.length
format_results(found, args: { query: query }) do |setting|
result = { name: setting[:setting] }
if include_descriptions
result[:description] = setting[:description]
result[:plugin] = setting[:plugin]
end
result
end
end
end
end
end


@@ -1,154 +0,0 @@
#frozen_string_literal: true
module DiscourseAi::AiBot::Commands
MAX_CONTEXT_TOKENS = 2000
class SettingContextCommand < Command
def self.rg_installed?
if defined?(@rg_installed)
@rg_installed
else
@rg_installed =
begin
Discourse::Utils.execute_command("which", "rg")
true
rescue Discourse::Utils::CommandError
false
end
end
end
class << self
def name
"setting_context"
end
def desc
"Will provide you with full context regarding a particular site setting in Discourse"
end
def parameters
[
Parameter.new(
name: "setting_name",
description: "The name of the site setting we need context for",
type: "string",
required: true,
),
]
end
end
def result_name
"context"
end
def description_args
{ setting_name: @setting_name }
end
CODE_FILE_EXTENSIONS = "rb,js,gjs,hbs"
def process(setting_name:)
if !self.class.rg_installed?
return(
{
setting_name: setting_name,
context: "This command requires the rg command line tool to be installed on the server",
}
)
end
@setting_name = setting_name
if !SiteSetting.has_setting?(setting_name)
{ setting_name: setting_name, context: "This setting does not exist" }
else
description = SiteSetting.description(setting_name)
result = +"# #{setting_name}\n#{description}\n\n"
setting_info =
find_setting_info(setting_name, [Rails.root.join("config", "site_settings.yml").to_s])
if !setting_info
setting_info =
find_setting_info(setting_name, Dir[Rails.root.join("plugins/**/settings.yml")])
end
result << setting_info
result << "\n\n"
%w[lib app plugins].each do |dir|
path = Rails.root.join(dir).to_s
result << Discourse::Utils.execute_command(
"rg",
setting_name,
path,
"-g",
"!**/spec/**",
"-g",
"!**/dist/**",
"-g",
"*.{#{CODE_FILE_EXTENSIONS}}",
"-C",
"10",
"--color",
"never",
"--heading",
"--no-ignore",
chdir: path,
success_status_codes: [0, 1],
)
end
result.gsub!(/^#{Regexp.escape(Rails.root.to_s)}/, "")
result = tokenizer.truncate(result, MAX_CONTEXT_TOKENS)
{ setting_name: setting_name, context: result }
end
end
def find_setting_info(name, paths)
path, result = nil
paths.each do |search_path|
result =
Discourse::Utils.execute_command(
"rg",
name,
search_path,
"-g",
"*.{#{CODE_FILE_EXTENSIONS}}",
"-A",
"10",
"--color",
"never",
"--heading",
success_status_codes: [0, 1],
)
if !result.blank?
path = search_path
break
end
end
if result.blank?
nil
else
rows = result.split("\n")
leading_spaces = rows[0].match(/^\s*/)[0].length
filtered = []
rows.each do |row|
if !filtered.blank?
break if row.match(/^\s*/)[0].length <= leading_spaces
end
filtered << row
end
filtered.unshift("#{path}")
filtered.join("\n")
end
end
end
end


@@ -1,184 +0,0 @@
#frozen_string_literal: true
module DiscourseAi::AiBot::Commands
class SummarizeCommand < Command
class << self
def name
"summarize"
end
def desc
"Will summarize a topic attempting to answer question in guidance"
end
def parameters
[
Parameter.new(
name: "topic_id",
description: "The discourse topic id to summarize",
type: "integer",
required: true,
),
Parameter.new(
name: "guidance",
description: "Special guidance on how to summarize the topic",
type: "string",
),
]
end
end
def result_name
"summary"
end
def standalone?
true
end
def low_cost?
true
end
def description_args
{ url: "#{Discourse.base_path}/t/-/#{@last_topic_id}", title: @last_topic_title || "" }
end
def process(topic_id:, guidance: nil)
@last_topic_id = topic_id
topic_id = topic_id.to_i
topic = nil
if topic_id > 0
topic = Topic.find_by(id: topic_id)
topic = nil if !topic || !Guardian.new.can_see?(topic)
end
@last_summary = nil
if topic
@last_topic_title = topic.title
posts =
Post
.where(topic_id: topic.id)
.where("post_type in (?)", [Post.types[:regular], Post.types[:small_action]])
.where("not hidden")
.order(:post_number)
columns = ["posts.id", :post_number, :raw, :username]
current_post_numbers = posts.limit(5).pluck(:post_number)
current_post_numbers += posts.reorder("posts.score desc").limit(50).pluck(:post_number)
current_post_numbers += posts.reorder("post_number desc").limit(5).pluck(:post_number)
data =
Post
.where(topic_id: topic.id)
.joins(:user)
.where("post_number in (?)", current_post_numbers)
.order(:post_number)
.pluck(*columns)
@last_summary = summarize(data, guidance, topic)
end
if !@last_summary
"Say: No topic found!"
else
"Topic summarized"
end
end
def custom_raw
@last_summary || I18n.t("discourse_ai.ai_bot.topic_not_found")
end
def chain_next_response
false
end
def summarize(data, guidance, topic)
text = +""
data.each do |id, post_number, raw, username|
text << "(#{post_number} #{username} said: #{raw}"
end
summaries = []
current_section = +""
split = []
text
.split(/\s+/)
.each_slice(20) do |slice|
current_section << " "
current_section << slice.join(" ")
# somehow any more will get closer to limits
if bot.tokenize(current_section).length > 2500
split << current_section
current_section = +""
end
end
split << current_section if current_section.present?
split = split[0..3] + split[-3..-1] if split.length > 5
split.each do |section|
# TODO progress meter
summary =
generate_gpt_summary(
section,
topic: topic,
context: "Guidance: #{guidance}\nYou are summarizing the topic: #{topic.title}",
)
summaries << summary
end
if summaries.length > 1
messages = []
messages << { role: "system", content: "You are a helpful bot" }
messages << {
role: "user",
content:
"concatenated the disjoint summaries, creating a cohesive narrative:\n#{summaries.join("\n")}}",
}
bot.submit_prompt(messages, temperature: 0.6, max_tokens: 500, prefer_low_cost: true).dig(
:choices,
0,
:message,
:content,
)
else
summaries.first
end
end
def generate_gpt_summary(text, topic:, context: nil, length: nil)
length ||= 400
prompt = <<~TEXT
#{context}
Summarize the following in #{length} words:
#{text}
TEXT
system_prompt = <<~TEXT
You are a summarization bot.
You effectively summarise any text.
You condense it into a shorter version.
You understand and generate Discourse forum markdown.
Try generating links as well the format is #{topic.url}/POST_NUMBER. eg: [ref](#{topic.url}/77)
TEXT
messages = [{ role: "system", content: system_prompt }]
messages << { role: "user", content: prompt }
result =
bot.submit_prompt(messages, temperature: 0.6, max_tokens: length, prefer_low_cost: true)
result.dig(:choices, 0, :message, :content)
end
end
end


@@ -1,42 +0,0 @@
#frozen_string_literal: true
module DiscourseAi::AiBot::Commands
class TagsCommand < Command
class << self
def name
"tags"
end
def desc
"Will list the 100 most popular tags on the current discourse instance"
end
def parameters
[]
end
end
def result_name
"results"
end
def description_args
{ count: @last_count || 0 }
end
def process
column_names = { name: "Name", public_topic_count: "Topic Count" }
tags =
Tag
.where("public_topic_count > 0")
.order(public_topic_count: :desc)
.limit(100)
.pluck(*column_names.keys)
@last_count = tags.length
format_results(tags, column_names.values)
end
end
end


@@ -1,49 +0,0 @@
#frozen_string_literal: true
module DiscourseAi::AiBot::Commands
class TimeCommand < Command
class << self
def name
"time"
end
def desc
"Will generate the time in a timezone"
end
def parameters
[
Parameter.new(
name: "timezone",
description: "ALWAYS supply a Ruby compatible timezone",
type: "string",
required: true,
),
]
end
end
def result_name
"time"
end
def description_args
{ timezone: @last_timezone, time: @last_time }
end
def process(timezone:)
time =
begin
Time.now.in_time_zone(timezone)
rescue StandardError
nil
end
time = Time.now if !time
@last_timezone = timezone
@last_time = time.to_s
{ args: { timezone: timezone }, time: time.to_s }
end
end
end


@@ -50,7 +50,7 @@ module DiscourseAi
             scope.user.in_any_groups?(SiteSetting.ai_bot_allowed_groups_map)
           end,
         ) do
-          DiscourseAi::AiBot::Personas
+          DiscourseAi::AiBot::Personas::Persona
             .all(user: scope.user)
             .map do |persona|
               { id: persona.id, name: persona.name, description: persona.description }
@@ -92,32 +92,19 @@ module DiscourseAi
           include_condition: -> { SiteSetting.ai_bot_enabled && object.topic.private_message? },
         ) do
           id = topic.custom_fields["ai_persona_id"]
-          name = DiscourseAi::AiBot::Personas.find_by(user: scope.user, id: id.to_i)&.name if id
+          name =
+            DiscourseAi::AiBot::Personas::Persona.find_by(user: scope.user, id: id.to_i)&.name if id
           name || topic.custom_fields["ai_persona"]
         end
         plugin.on(:post_created) do |post|
           bot_ids = BOTS.map(&:first)
-          if post.post_type == Post.types[:regular] && post.topic.private_message? &&
-               !bot_ids.include?(post.user_id)
-            if (SiteSetting.ai_bot_allowed_groups_map & post.user.group_ids).present?
-              bot_id = post.topic.topic_allowed_users.where(user_id: bot_ids).first&.user_id
-              if bot_id
-                if post.post_number == 1
-                  post.topic.custom_fields[REQUIRE_TITLE_UPDATE] = true
-                  post.topic.save_custom_fields
-                end
-                ::Jobs.enqueue(:create_ai_reply, post_id: post.id, bot_user_id: bot_id)
-                ::Jobs.enqueue_in(
-                  5.minutes,
-                  :update_ai_bot_pm_title,
-                  post_id: post.id,
-                  bot_user_id: bot_id,
-                )
-              end
-            end
+          # Don't schedule a reply for a bot reply.
+          if !bot_ids.include?(post.user_id)
+            bot_user = post.topic.topic_allowed_users.where(user_id: bot_ids).first&.user
+            bot = DiscourseAi::AiBot::Bot.as(bot_user)
+            DiscourseAi::AiBot::Playground.new(bot).update_playground_with(post)
           end
         end


@@ -1,157 +0,0 @@
# frozen_string_literal: true
module DiscourseAi
module AiBot
class OpenAiBot < Bot
def self.can_reply_as?(bot_user)
open_ai_bot_ids = [
DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID,
DiscourseAi::AiBot::EntryPoint::GPT4_ID,
DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID,
]
open_ai_bot_ids.include?(bot_user.id)
end
def prompt_limit(allow_commands:)
# provide a buffer of 120 tokens - our function counting is not
# 100% accurate and getting numbers to align exactly is very hard
buffer = reply_params[:max_tokens] + 50
if allow_commands
# note this is about 100 tokens over, OpenAI have a more optimal representation
@function_size ||= tokenize(available_functions.to_json.to_s).length
buffer += @function_size
end
if bot_user.id == DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID
150_000 - buffer
elsif bot_user.id == DiscourseAi::AiBot::EntryPoint::GPT4_ID
8192 - buffer
else
16_384 - buffer
end
end
def reply_params
# technically we could allow GPT-3.5 16k more tokens
# but lets just keep it here for now
{ temperature: 0.4, top_p: 0.9, max_tokens: 2500 }
end
def extra_tokens_per_message
# open ai defines about 4 tokens per message of overhead
4
end
def submit_prompt(
prompt,
prefer_low_cost: false,
post: nil,
temperature: nil,
top_p: nil,
max_tokens: nil,
&blk
)
params =
reply_params.merge(
temperature: temperature,
top_p: top_p,
max_tokens: max_tokens,
) { |key, old_value, new_value| new_value.nil? ? old_value : new_value }
model = model_for(low_cost: prefer_low_cost)
params[:functions] = available_functions if available_functions.present?
DiscourseAi::Inference::OpenAiCompletions.perform!(
prompt,
model,
**params,
post: post,
&blk
)
end
def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer
end
def model_for(low_cost: false)
if low_cost || bot_user.id == DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
"gpt-3.5-turbo-16k"
elsif bot_user.id == DiscourseAi::AiBot::EntryPoint::GPT4_ID
"gpt-4"
else
# not quite released yet, once released we should replace with
# gpt-4-turbo
"gpt-4-1106-preview"
end
end
def clean_username(username)
if username.match?(/\A[a-zA-Z0-9_-]{1,64}\z/)
username
else
# not the best in the world, but this is what we have to work with
# if sites enable unicode usernames this can get messy
username.gsub(/[^a-zA-Z0-9_-]/, "_")[0..63]
end
end
def include_function_instructions_in_system_prompt?
# open ai uses a bespoke system for function calls
false
end
private
def populate_functions(partial:, reply:, functions:, done:, current_delta:)
return if !partial
fn = partial.dig(:choices, 0, :delta, :function_call)
if fn
functions.add_function(fn[:name]) if fn[:name].present?
functions.add_argument_fragment(fn[:arguments]) if !fn[:arguments].nil?
functions.custom = true
end
end
def build_message(poster_username, content, function: false, system: false)
is_bot = poster_username == bot_user.username
if function
role = "function"
elsif system
role = "system"
else
role = is_bot ? "assistant" : "user"
end
result = { role: role, content: content }
if function
result[:name] = poster_username
elsif !system && poster_username != bot_user.username && poster_username.present?
# Open AI restrict name to 64 chars and only A-Za-z._ (work around)
result[:name] = clean_username(poster_username)
end
result
end
def get_delta(partial, _context)
partial.dig(:choices, 0, :delta, :content).to_s
end
def get_updated_title(prompt)
DiscourseAi::Inference::OpenAiCompletions.perform!(
prompt,
model_for,
temperature: 0.7,
top_p: 0.9,
max_tokens: 40,
).dig(:choices, 0, :message, :content)
end
end
end
end

View File

@ -1,46 +0,0 @@
# frozen_string_literal: true
module DiscourseAi
module AiBot
module Personas
def self.system_personas
@system_personas ||= {
Personas::General => -1,
Personas::SqlHelper => -2,
Personas::Artist => -3,
Personas::SettingsExplorer => -4,
Personas::Researcher => -5,
Personas::Creative => -6,
Personas::DallE3 => -7,
}
end
def self.system_personas_by_id
@system_personas_by_id ||= system_personas.invert
end
def self.all(user:)
# this needs to be dynamic cause site settings may change
all_available_commands = Persona.all_available_commands
AiPersona.all_personas.filter do |persona|
next false if !user.in_any_groups?(persona.allowed_group_ids)
if persona.system
instance = persona.new
(
instance.required_commands == [] ||
(instance.required_commands - all_available_commands).empty?
)
else
true
end
end
end
def self.find_by(id: nil, name: nil, user:)
all(user: user).find { |persona| persona.id == id || persona.name == name }
end
end
end
end

View File

@@ -4,12 +4,12 @@ module DiscourseAi
   module AiBot
     module Personas
       class Artist < Persona
-        def commands
-          [Commands::ImageCommand]
+        def tools
+          [Tools::Image]
         end

-        def required_commands
-          [Commands::ImageCommand]
+        def required_tools
+          [Tools::Image]
         end

         def system_prompt

View File

@@ -4,7 +4,7 @@ module DiscourseAi
   module AiBot
     module Personas
       class Creative < Persona
-        def commands
+        def tools
           []
         end

View File

@@ -4,12 +4,12 @@ module DiscourseAi
   module AiBot
     module Personas
       class DallE3 < Persona
-        def commands
-          [Commands::DallECommand]
+        def tools
+          [Tools::DallE]
         end

-        def required_commands
-          [Commands::DallECommand]
+        def required_tools
+          [Tools::DallE]
         end

         def system_prompt

View File

@@ -4,15 +4,15 @@ module DiscourseAi
   module AiBot
     module Personas
       class General < Persona
-        def commands
+        def tools
           [
-            Commands::SearchCommand,
-            Commands::GoogleCommand,
-            Commands::ImageCommand,
-            Commands::ReadCommand,
-            Commands::ImageCommand,
-            Commands::CategoriesCommand,
-            Commands::TagsCommand,
+            Tools::Search,
+            Tools::Google,
+            Tools::Image,
+            Tools::Read,
+            Tools::Image,
+            Tools::ListCategories,
+            Tools::ListTags,
           ]
         end

View File

@@ -4,19 +4,84 @@ module DiscourseAi
   module AiBot
     module Personas
       class Persona
-        def self.name
+        class << self
+          def system_personas
+            @system_personas ||= {
+              Personas::General => -1,
+              Personas::SqlHelper => -2,
+              Personas::Artist => -3,
+              Personas::SettingsExplorer => -4,
+              Personas::Researcher => -5,
+              Personas::Creative => -6,
+              Personas::DallE3 => -7,
+            }
+          end
+
+          def system_personas_by_id
+            @system_personas_by_id ||= system_personas.invert
+          end
+
+          def all(user:)
+            # listing tools has to be dynamic cause site settings may change
+            AiPersona.all_personas.filter do |persona|
+              next false if !user.in_any_groups?(persona.allowed_group_ids)
+
+              if persona.system
+                instance = persona.new
+                (
+                  instance.required_tools == [] ||
+                    (instance.required_tools - all_available_tools).empty?
+                )
+              else
+                true
+              end
+            end
+          end
+
+          def find_by(id: nil, name: nil, user:)
+            all(user: user).find { |persona| persona.id == id || persona.name == name }
+          end
+
+          def name
             I18n.t("discourse_ai.ai_bot.personas.#{to_s.demodulize.underscore}.name")
           end
-        def self.description
+
+          def description
             I18n.t("discourse_ai.ai_bot.personas.#{to_s.demodulize.underscore}.description")
           end
-        def commands
+
+          def all_available_tools
+            tools = [
+              Tools::ListCategories,
+              Tools::Time,
+              Tools::Search,
+              Tools::Summarize,
+              Tools::Read,
+              Tools::DbSchema,
+              Tools::SearchSettings,
+              Tools::Summarize,
+              Tools::SettingContext,
+            ]
+
+            tools << Tools::ListTags if SiteSetting.tagging_enabled
+            tools << Tools::Image if SiteSetting.ai_stability_api_key.present?
+            tools << Tools::DallE if SiteSetting.ai_openai_api_key.present?
+            if SiteSetting.ai_google_custom_search_api_key.present? &&
+                 SiteSetting.ai_google_custom_search_cx.present?
+              tools << Tools::Google
+            end
+
+            tools
+          end
+        end
+
+        def tools
           []
         end
-        def required_commands
+
+        def required_tools
           []
         end
@@ -24,104 +89,55 @@ module DiscourseAi
           {}
         end

-        def render_commands(render_function_instructions:)
-          return +"" if available_commands.empty?
-
-          result = +""
-          if render_function_instructions
-            result << "\n"
-            result << function_list.system_prompt
-            result << "\n"
-          end
-          result << available_commands.map(&:custom_system_message).compact.join("\n")
-          result
+        def available_tools
+          self.class.all_available_tools.filter { |tool| tools.include?(tool) }
         end

-        def render_system_prompt(
-          topic: nil,
-          render_function_instructions: true,
-          allow_commands: true
-        )
-          substitutions = {
-            site_url: Discourse.base_url,
-            site_title: SiteSetting.title,
-            site_description: SiteSetting.site_description,
-            time: Time.zone.now,
-          }
-
-          substitutions[:participants] = topic.allowed_users.map(&:username).join(", ") if topic
-
-          prompt =
+        def craft_prompt(context)
+          system_insts =
             system_prompt.gsub(/\{(\w+)\}/) do |match|
-              found = substitutions[match[1..-2].to_sym]
+              found = context[match[1..-2].to_sym]
               found.nil? ? match : found.to_s
             end

-          if allow_commands
-            prompt += render_commands(render_function_instructions: render_function_instructions)
-          end
-
-          prompt
-        end
-
-        def available_commands
-          return @available_commands if @available_commands
-
-          @available_commands = all_available_commands.filter { |cmd| commands.include?(cmd) }
-        end
-
-        def available_functions
-          # note if defined? can be a problem in test
-          # this can never be nil so it is safe
-          return @available_functions if @available_functions
-
-          functions = []
-
-          functions =
-            available_commands.map do |command|
-              function =
-                DiscourseAi::Inference::Function.new(name: command.name, description: command.desc)
-              command.parameters.each { |parameter| function.add_parameter(parameter) }
-              function
-            end
-
-          @available_functions = functions
-        end
-
-        def function_list
-          return @function_list if @function_list
-
-          @function_list = DiscourseAi::Inference::FunctionList.new
-          available_functions.each { |function| @function_list << function }
-          @function_list
-        end
-
-        def self.all_available_commands
-          all_commands = [
-            Commands::CategoriesCommand,
-            Commands::TimeCommand,
-            Commands::SearchCommand,
-            Commands::SummarizeCommand,
-            Commands::ReadCommand,
-            Commands::DbSchemaCommand,
-            Commands::SearchSettingsCommand,
-            Commands::SummarizeCommand,
-            Commands::SettingContextCommand,
-          ]
-
-          all_commands << Commands::TagsCommand if SiteSetting.tagging_enabled
-          all_commands << Commands::ImageCommand if SiteSetting.ai_stability_api_key.present?
-          all_commands << Commands::DallECommand if SiteSetting.ai_openai_api_key.present?
-          if SiteSetting.ai_google_custom_search_api_key.present? &&
-               SiteSetting.ai_google_custom_search_cx.present?
-            all_commands << Commands::GoogleCommand
-          end
-          all_commands
-        end
-
-        def all_available_commands
-          @cmds ||= self.class.all_available_commands
+          insts = <<~TEXT
+          #{system_insts}
+          #{available_tools.map(&:custom_system_message).compact_blank.join("\n")}
+          TEXT
+
+          { insts: insts }.tap do |prompt|
+            prompt[:tools] = available_tools.map(&:signature) if available_tools
+            prompt[:conversation_context] = context[:conversation_context] if context[
+              :conversation_context
+            ]
+          end
+        end
+
+        def find_tool(partial)
+          parsed_function = Nokogiri::HTML5.fragment(partial)
+          function_id = parsed_function.at("tool_id")&.text
+          function_name = parsed_function.at("tool_name")&.text
+          return false if function_name.nil?
+
+          tool_klass = available_tools.find { |c| c.signature.dig(:name) == function_name }
+          return false if tool_klass.nil?
+
+          arguments =
+            tool_klass.signature[:parameters]
+              .to_a
+              .reduce({}) do |memo, p|
+                argument = parsed_function.at(p[:name])&.text
+                next(memo) unless argument
+
+                memo[p[:name].to_sym] = argument
+                memo
+              end
+
+          tool_klass.new(
+            arguments,
+            tool_call_id: function_id,
+            persona_options: options[tool_klass].to_h,
+          )
         end
       end
     end
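
For reference, find_tool consumes the XML-ish fragment the endpoints buffer for tool calls. A minimal sketch of the round trip, assuming the tag layout implied by the parser above (the concrete payload and argument value are invented):

partial = <<~XML
  <invoke>
    <tool_name>google</tool_name>
    <tool_id>google</tool_id>
    <parameters>
      <query>discourse ai status</query>
    </parameters>
  </invoke>
XML

tool = persona.find_tool(partial)
# => a Tools::Google instance with parameters[:query] == "discourse ai status",
#    or false when no <tool_name> matches one of the persona's available tools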

View File

@@ -4,12 +4,12 @@ module DiscourseAi
   module AiBot
     module Personas
       class Researcher < Persona
-        def commands
-          [Commands::GoogleCommand]
+        def tools
+          [Tools::Google]
         end

-        def required_commands
-          [Commands::GoogleCommand]
+        def required_tools
+          [Tools::Google]
         end

         def system_prompt

View File

@@ -4,15 +4,8 @@ module DiscourseAi
   module AiBot
     module Personas
       class SettingsExplorer < Persona
-        def commands
-          all_available_commands
-        end
-
-        def all_available_commands
-          [
-            DiscourseAi::AiBot::Commands::SettingContextCommand,
-            DiscourseAi::AiBot::Commands::SearchSettingsCommand,
-          ]
+        def tools
+          [Tools::SettingContext, Tools::SearchSettings]
         end

         def system_prompt

View File

@@ -27,12 +27,8 @@ module DiscourseAi
           @schema = schema
         end

-        def commands
-          all_available_commands
-        end
-
-        def all_available_commands
-          [DiscourseAi::AiBot::Commands::DbSchemaCommand]
+        def tools
+          [Tools::DbSchema]
         end

         def system_prompt

lib/ai_bot/playground.rb (new file, 228 lines)
View File

@ -0,0 +1,228 @@
# frozen_string_literal: true
module DiscourseAi
module AiBot
class Playground
# An abstraction to manage the bot and topic interactions.
# The bot will take care of completions while this class updates the topic title
# and stream replies.
REQUIRE_TITLE_UPDATE = "discourse-ai-title-update"
def initialize(bot)
@bot = bot
end
def update_playground_with(post)
if can_attach?(post) && bot.bot_user
schedule_playground_titling(post, bot.bot_user)
schedule_bot_reply(post, bot.bot_user)
end
end
def conversation_context(post)
# Pay attention to the `post_number <= ?` here.
# We want to inject the last post as context because they are translated differently.
context =
post
.topic
.posts
.includes(:user)
.joins("LEFT JOIN post_custom_prompts ON post_custom_prompts.post_id = posts.id")
.where("post_number <= ?", post.post_number)
.order("post_number desc")
.where("post_type = ?", Post.types[:regular])
.limit(50)
.pluck(:raw, :username, "post_custom_prompts.custom_prompt")
result = []
first = true
context.each do |raw, username, custom_prompt|
custom_prompt_translation =
Proc.new do |message|
# We can't keep backwards-compatibility for stored functions.
# Tool syntax requires a tool_call_id which we don't have.
if message[2] != "function"
custom_context = {
content: message[0],
type: message[2].present? ? message[2] : "assistant",
}
custom_context[:name] = message[1] if custom_context[:type] != "assistant"
result << custom_context
end
end
if custom_prompt.present?
if first
custom_prompt.reverse_each(&custom_prompt_translation)
first = false
else
tool_call_and_tool = custom_prompt.first(2)
tool_call_and_tool.reverse_each(&custom_prompt_translation)
end
else
context = {
content: raw,
type: (available_bot_usernames.include?(username) ? "assistant" : "user"),
}
context[:name] = username if context[:type] == "user"
result << context
end
end
result
end
def title_playground(post)
context = conversation_context(post)
bot
.get_updated_title(context, post.user)
.tap do |new_title|
PostRevisor.new(post.topic.first_post, post.topic).revise!(
bot.bot_user,
title: new_title.sub(/\A"/, "").sub(/"\Z/, ""),
)
post.topic.custom_fields.delete(DiscourseAi::AiBot::EntryPoint::REQUIRE_TITLE_UPDATE)
post.topic.save_custom_fields
end
end
def reply_to(post)
reply = +""
start = Time.now
context = {
site_url: Discourse.base_url,
site_title: SiteSetting.title,
site_description: SiteSetting.site_description,
time: Time.zone.now,
participants: post.topic.allowed_users.map(&:username).join(", "),
conversation_context: conversation_context(post),
user: post.user,
}
reply_post =
PostCreator.create!(
bot.bot_user,
topic_id: post.topic_id,
raw: I18n.t("discourse_ai.ai_bot.placeholder_reply"),
skip_validations: true,
)
redis_stream_key = "gpt_cancel:#{reply_post.id}"
Discourse.redis.setex(redis_stream_key, 60, 1)
new_custom_prompts =
bot.reply(context) do |partial, cancel, placeholder|
reply << partial
raw = reply.dup
raw << "\n\n" << placeholder if placeholder.present?
if !Discourse.redis.get(redis_stream_key)
cancel&.call
reply_post.update!(raw: reply, cooked: PrettyText.cook(reply))
end
# Minor hack to skip the delay during tests.
if placeholder.blank?
next if (Time.now - start < 0.5) && !Rails.env.test?
start = Time.now
end
Discourse.redis.expire(redis_stream_key, 60)
publish_update(reply_post, raw: raw)
end
return if reply.blank?
reply_post.tap do |bot_reply|
publish_update(bot_reply, done: true)
bot_reply.revise(
bot.bot_user,
{ raw: reply },
skip_validations: true,
skip_revision: true,
)
bot_reply.post_custom_prompt ||= bot_reply.build_post_custom_prompt(custom_prompt: [])
prompt = bot_reply.post_custom_prompt.custom_prompt || []
prompt.concat(new_custom_prompts)
prompt << [reply, bot.bot_user.username]
bot_reply.post_custom_prompt.update!(custom_prompt: prompt)
end
end
private
attr_reader :bot
def can_attach?(post)
return false if bot.bot_user.nil?
return false if post.post_type != Post.types[:regular]
return false unless post.topic.private_message?
return false if (SiteSetting.ai_bot_allowed_groups_map & post.user.group_ids).blank?
true
end
def schedule_playground_titling(post, bot_user)
if post.post_number == 1
post.topic.custom_fields[REQUIRE_TITLE_UPDATE] = true
post.topic.save_custom_fields
end
::Jobs.enqueue_in(
5.minutes,
:update_ai_bot_pm_title,
post_id: post.id,
bot_user_id: bot_user.id,
)
end
def schedule_bot_reply(post, bot_user)
::Jobs.enqueue(:create_ai_reply, post_id: post.id, bot_user_id: bot_user.id)
end
def context(topic)
{
site_url: Discourse.base_url,
site_title: SiteSetting.title,
site_description: SiteSetting.site_description,
time: Time.zone.now,
participants: topic.allowed_users.map(&:username).join(", "),
}
end
def publish_update(bot_reply_post, payload)
MessageBus.publish(
"discourse-ai/ai-bot/topic/#{bot_reply_post.topic_id}",
payload.merge(post_id: bot_reply_post.id, post_number: bot_reply_post.post_number),
user_ids: bot_reply_post.topic.allowed_user_ids,
)
end
def available_bot_usernames
@bot_usernames ||= DiscourseAi::AiBot::EntryPoint::BOTS.map(&:second)
end
def clean_username(username)
if username.match?(/\A[a-zA-Z0-9_-]{1,64}\z/)
username
else
# not the best in the world, but this is what we have to work with
# if sites enable unicode usernames this can get messy
username.gsub(/[^a-zA-Z0-9_-]/, "_")[0..63]
end
end
end
end
end
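
Taken together, the Playground keeps all PM plumbing out of the bot itself. A rough usage sketch based on the code above (the bot user id and persona choice are illustrative; in production the plugin's post_created hook and the jobs drive this):

bot_user = User.find_by(id: DiscourseAi::AiBot::EntryPoint::GPT4_ID)
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: DiscourseAi::AiBot::Personas::General.new)

playground = DiscourseAi::AiBot::Playground.new(bot)
playground.update_playground_with(post) # schedules the titling and reply jobs
playground.reply_to(post)               # streams the bot reply into a new post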

lib/ai_bot/tools/dall_e.rb (new file, 125 lines)
View File

@ -0,0 +1,125 @@
# frozen_string_literal: true
module DiscourseAi
module AiBot
module Tools
class DallE < Tool
def self.signature
{
name: name,
description: "Renders images from supplied descriptions",
parameters: [
{
name: "prompts",
description:
"The prompts used to generate or create or draw the image (5000 chars or less, be creative) up to 4 prompts",
type: "array",
item_type: "string",
required: true,
},
],
}
end
def self.name
"dall_e"
end
def prompts
parameters[:prompts]
end
def chain_next_response?
false
end
def invoke(bot_user, _llm)
# max 4 prompts
max_prompts = prompts.take(4)
progress = +""
yield(progress)
results = nil
# this ensures multisite safety since background threads
# generate the images
api_key = SiteSetting.ai_openai_api_key
api_url = SiteSetting.ai_openai_dall_e_3_url
threads = []
max_prompts.each_with_index do |prompt, index|
threads << Thread.new(prompt) do |inner_prompt|
attempts = 0
begin
DiscourseAi::Inference::OpenAiImageGenerator.perform!(
inner_prompt,
api_key: api_key,
api_url: api_url,
)
rescue => e
attempts += 1
sleep 2
retry if attempts < 3
Discourse.warn_exception(
e,
message: "Failed to generate image for prompt #{prompt}",
)
nil
end
end
end
while true
progress << "."
yield(progress)
break if threads.all? { |t| t.join(2) }
end
results = threads.filter_map(&:value)
if results.blank?
return { prompts: max_prompts, error: "Something went wrong, could not generate image" }
end
uploads = []
results.each_with_index do |result, index|
result[:data].each do |image|
Tempfile.create("v1_txt2img_#{index}.png") do |file|
file.binmode
file.write(Base64.decode64(image[:b64_json]))
file.rewind
uploads << {
prompt: image[:revised_prompt],
upload: UploadCreator.new(file, "image.png").create_for(bot_user.id),
}
end
end
end
self.custom_raw = <<~RAW
[grid]
#{
uploads
.map do |item|
"![#{item[:prompt].gsub(/\|\'\"/, "")}|512x512, 50%](#{item[:upload].short_url})"
end
.join(" ")
}
[/grid]
RAW
{ prompts: uploads.map { |item| item[:prompt] } }
end
protected
def description_args
{ prompt: prompts.first }
end
end
end
end
end
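
A side note on the threading pattern shared by both image tools: Thread#join(2) returns the thread when it finished inside the two-second window and nil on timeout, so the polling loop above keeps appending progress dots until every generation thread has completed. A generic sketch of the same pattern, not tied to this file:

threads = 3.times.map { |i| Thread.new { sleep(i); i } }

progress = +""
loop do
  progress << "."
  # a streaming callback such as yield(progress) would go here
  break if threads.all? { |t| t.join(2) } # join(2) => thread or nil
end

threads.map(&:value) # => [0, 1, 2]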

View File

@ -0,0 +1,62 @@
# frozen_string_literal: true
module DiscourseAi
module AiBot
module Tools
class DbSchema < Tool
def self.signature
{
name: name,
description: "Will load schema information for specific tables in the database",
parameters: [
{
name: "tables",
description:
"list of tables to load schema information for, comma seperated list eg: (users,posts))",
type: "string",
required: true,
},
],
}
end
def self.name
"schema"
end
def tables
parameters[:tables]
end
def invoke(_bot_user, _llm)
tables_arr = tables.split(",").map(&:strip)
table_info = {}
DB
.query(<<~SQL, tables_arr)
select table_name, column_name, data_type from information_schema.columns
where table_schema = 'public'
and table_name in (?)
order by table_name
SQL
.each do |row|
(table_info[row.table_name] ||= []) << "#{row.column_name} #{row.data_type}"
end
schema_info =
table_info
.map { |table_name, columns| "#{table_name}(#{columns.join(",")})" }
.join("\n")
{ schema_info: schema_info, tables: tables }
end
protected
def description_args
# `tables` is already the raw comma-separated string per the signature, so pass it through
{ tables: tables }
end
end
end
end
end
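
To make the return shape concrete, a hedged example (the table and column output is invented, and invoke does query the database):

tool = DiscourseAi::AiBot::Tools::DbSchema.new({ tables: "users,posts" })
tool.invoke(nil, nil)
# => { schema_info: "users(id integer,username character varying)\nposts(id integer,raw text)",
#      tables: "users,posts" }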

View File

@ -0,0 +1,85 @@
# frozen_string_literal: true
module DiscourseAi
module AiBot
module Tools
class Google < Tool
def self.signature
{
name: name,
description:
"Will search using Google - global internet search (supports all Google search operators)",
parameters: [
{ name: "query", description: "The search query", type: "string", required: true },
],
}
end
def self.custom_system_message
"You were trained on OLD data, lean on search to get up to date information from the web"
end
def self.name
"google"
end
def query
parameters[:query].to_s
end
def invoke(bot_user, llm)
yield("") # Triggers placeholder update
api_key = SiteSetting.ai_google_custom_search_api_key
cx = SiteSetting.ai_google_custom_search_cx
escaped_query = CGI.escape(query)
uri =
URI(
"https://www.googleapis.com/customsearch/v1?key=#{api_key}&cx=#{cx}&q=#{escaped_query}&num=10",
)
body = Net::HTTP.get(uri)
parse_search_json(body, escaped_query, llm)
end
attr_reader :results_count
protected
def description_args
{
count: results_count || 0,
query: query,
url: "https://google.com/search?q=#{CGI.escape(query)}",
}
end
private
def minimize_field(result, field, llm, max_tokens: 100)
data = result[field]
return "" if data.blank?
llm.tokenizer.truncate(data, max_tokens).squish
end
def parse_search_json(json_data, escaped_query, llm)
parsed = JSON.parse(json_data)
results = parsed["items"]
@results_count = parsed.dig("searchInformation", "totalResults").to_i
format_results(results, args: escaped_query) do |result|
{
title: minimize_field(result, "title", llm),
link: minimize_field(result, "link", llm),
snippet: minimize_field(result, "snippet", llm, max_tokens: 120),
displayLink: minimize_field(result, "displayLink", llm),
formattedUrl: minimize_field(result, "formattedUrl", llm),
}
end
end
end
end
end
end

lib/ai_bot/tools/image.rb (new file, 144 lines)
View File

@ -0,0 +1,144 @@
# frozen_string_literal: true
module DiscourseAi
module AiBot
module Tools
class Image < Tool
def self.signature
{
name: name,
description:
"Renders an image from the description (remove all connector words, keep it to 40 words or less). Despite being a text based bot you can generate images! (when user asks to draw, paint or other synonyms try this)",
parameters: [
{
name: "prompts",
description:
"The prompts used to generate or create or draw the image (40 words or less, be creative) up to 4 prompts",
type: "array",
item_type: "string",
required: true,
},
{
name: "seeds",
description:
"The seed used to generate the image (optional) - can be used to retain image style on amended prompts",
type: "array",
item_type: "integer",
required: true,
},
],
}
end
def self.name
"image"
end
def prompts
JSON.parse(parameters[:prompts].to_s)
end
def seeds
parameters[:seeds]
end
def chain_next_response?
false
end
def invoke(bot_user, _llm)
# max 4 prompts
selected_prompts = prompts.take(4)
# `seeds = seeds.take(4) if seeds` would shadow the method with a nil local,
# so stash the truncated list in a separate local instead
selected_seeds = seeds&.take(4)
progress = +""
yield(progress)
results = nil
# this ensures multisite safety since background threads
# generate the images
api_key = SiteSetting.ai_stability_api_key
engine = SiteSetting.ai_stability_engine
api_url = SiteSetting.ai_stability_api_url
threads = []
selected_prompts.each_with_index do |prompt, index|
seed = selected_seeds ? selected_seeds[index] : nil
threads << Thread.new(seed, prompt) do |inner_seed, inner_prompt|
attempts = 0
begin
DiscourseAi::Inference::StabilityGenerator.perform!(
inner_prompt,
engine: engine,
api_key: api_key,
api_url: api_url,
image_count: 1,
seed: inner_seed,
)
rescue => e
attempts += 1
retry if attempts < 3
Rails.logger.warn("Failed to generate image for prompt #{prompt}: #{e}")
nil
end
end
end
while true
progress << "."
yield(progress)
break if threads.all? { |t| t.join(2) }
end
results = threads.map(&:value).compact
if !results.present?
return { prompts: prompts, error: "Something went wrong, could not generate image" }
end
uploads = []
results.each_with_index do |result, index|
result[:artifacts].each do |image|
Tempfile.create("v1_txt2img_#{index}.png") do |file|
file.binmode
file.write(Base64.decode64(image[:base64]))
file.rewind
uploads << {
prompt: prompts[index],
upload: UploadCreator.new(file, "image.png").create_for(bot_user.id),
seed: image[:seed],
}
end
end
end
@custom_raw = <<~RAW
[grid]
#{
uploads
.map do |item|
"![#{item[:prompt].gsub(/\|\'\"/, "")}|512x512, 50%](#{item[:upload].short_url})"
end
.join(" ")
}
[/grid]
RAW
{
prompts: uploads.map { |item| item[:prompt] },
seeds: uploads.map { |item| item[:seed] },
}
end
protected
def description_args
{ prompt: prompts.first }
end
end
end
end
end

View File

@ -0,0 +1,46 @@
# frozen_string_literal: true
module DiscourseAi
module AiBot
module Tools
class ListCategories < Tool
def self.signature
{
name: name,
description:
"Will list the categories on the current discourse instance, prefer to format with # in front of the category name",
}
end
def self.name
"categories"
end
def invoke(_bot_user, _llm)
columns = {
name: "Name",
slug: "Slug",
description: "Description",
posts_year: "Posts Year",
posts_month: "Posts Month",
posts_week: "Posts Week",
id: "id",
parent_category_id: "parent_category_id",
}
rows = Category.where(read_restricted: false).limit(100).pluck(*columns.keys)
@last_count = rows.length
{ rows: rows, column_names: columns.values }
end
private
def description_args
{ count: @last_count || 0 }
end
end
end
end
end

View File

@ -0,0 +1,41 @@
# frozen_string_literal: true
module DiscourseAi
module AiBot
module Tools
class ListTags < Tool
def self.signature
{
name: name,
description: "Will list the 100 most popular tags on the current discourse instance",
}
end
def self.name
"tags"
end
def invoke(_bot_user, _llm)
column_names = { name: "Name", public_topic_count: "Topic Count" }
tags =
Tag
.where("public_topic_count > 0")
.order(public_topic_count: :desc)
.limit(100)
.pluck(*column_names.keys)
@last_count = tags.length
format_results(tags, column_names.values)
end
private
def description_args
{ count: @last_count || 0 }
end
end
end
end
end

View File

@ -0,0 +1,25 @@
# frozen_string_literal: true
module DiscourseAi
module AiBot
module Tools
class Option
attr_reader :tool, :name, :type
def initialize(tool:, name:, type:)
@tool = tool
@name = name.to_s
@type = type
end
def localized_name
I18n.t("discourse_ai.ai_bot.command_options.#{tool.signature[:name]}.#{name}.name")
end
def localized_description
I18n.t("discourse_ai.ai_bot.command_options.#{tool.signature[:name]}.#{name}.description")
end
end
end
end
end

lib/ai_bot/tools/read.rb (new file, 90 lines)
View File

@ -0,0 +1,90 @@
# frozen_string_literal: true
module DiscourseAi
module AiBot
module Tools
class Read < Tool
def self.signature
{
name: name,
description: "Will read a topic or a post on this Discourse instance",
parameters: [
{
name: "topic_id",
description: "the id of the topic to read",
type: "integer",
required: true,
},
{
name: "post_number",
description: "the post number to read",
type: "integer",
required: true,
},
],
}
end
def self.name
"read"
end
attr_reader :title, :url
def topic_id
parameters[:topic_id]
end
def post_number
parameters[:post_number]
end
def invoke(_bot_user, llm)
not_found = { topic_id: topic_id, description: "Topic not found" }
@title = ""
topic = Topic.find_by(id: topic_id.to_i)
return not_found if !topic || !Guardian.new.can_see?(topic)
@title = topic.title
posts = Post.secured(Guardian.new).where(topic_id: topic_id).order(:post_number).limit(40)
@url = topic.relative_url(post_number)
posts = posts.where("post_number = ?", post_number) if post_number
content = +<<~TEXT.strip
title: #{topic.title}
TEXT
category_names = [
topic.category&.parent_category&.name,
topic.category&.name,
].compact.join(" ")
content << "\ncategories: #{category_names}" if category_names.present?
if topic.tags.length > 0
tags = DiscourseTagging.filter_visible(topic.tags, Guardian.new)
content << "\ntags: #{tags.map(&:name).join(", ")}\n\n" if tags.length > 0
end
posts.each { |post| content << "\n\n#{post.username} said:\n\n#{post.raw}" }
# TODO: 16k or 100k models can handle a lot more tokens
content = llm.tokenizer.truncate(content, 1500).squish
result = { topic_id: topic_id, content: content, complete: true }
result[:post_number] = post_number if post_number
result
end
protected
def description_args
{ title: title, url: url }
end
end
end
end
end

lib/ai_bot/tools/search.rb (new file, 223 lines)
View File

@ -0,0 +1,223 @@
# frozen_string_literal: true
module DiscourseAi
module AiBot
module Tools
class Search < Tool
MIN_SEMANTIC_RESULTS = 5
class << self
def signature
{
name: name,
description:
"Will search topics in the current discourse instance, when rendering always prefer to link to the topics you find",
parameters: [
{
name: "search_query",
description:
"Specific keywords to search for, space seperated (correct bad spelling, remove connector words)",
type: "string",
},
{
name: "user",
description:
"Filter search results to this username (only include if user explicitly asks to filter by user)",
type: "string",
},
{
name: "order",
description: "search result order",
type: "string",
enum: %w[latest latest_topic oldest views likes],
},
{
name: "limit",
description:
"limit number of results returned (generally prefer to just keep to default)",
type: "integer",
},
{
name: "max_posts",
description:
"maximum number of posts on the topics (topics where lots of people posted)",
type: "integer",
},
{
name: "tags",
description:
"list of tags to search for. Use + to join with OR, use , to join with AND",
type: "string",
},
{ name: "category", description: "category name to filter to", type: "string" },
{
name: "before",
description: "only topics created before a specific date YYYY-MM-DD",
type: "string",
},
{
name: "after",
description: "only topics created after a specific date YYYY-MM-DD",
type: "string",
},
{
name: "status",
description: "search for topics in a particular state",
type: "string",
enum: %w[open closed archived noreplies single_user],
},
],
}
end
def name
"search"
end
def custom_system_message
<<~TEXT
You were trained on OLD data, lean on search to get up to date information about this forum
When searching try to SIMPLIFY search terms
Discourse search joins all terms with AND. Reduce and simplify terms to find more results.
TEXT
end
def accepted_options
[option(:base_query, type: :string), option(:max_results, type: :integer)]
end
end
def search_args
parameters.slice(:user, :order, :max_posts, :tags, :before, :after, :status)
end
def invoke(bot_user, llm)
search_string =
search_args.reduce(+parameters[:search_query].to_s) do |memo, (key, value)|
# `return` here would exit `invoke` entirely; just skip blank values
next(memo) if value.blank?
memo << " " << "#{key}:#{value}"
end
@last_query = search_string
yield(I18n.t("discourse_ai.ai_bot.searching", query: search_string))
if options[:base_query].present?
search_string = "#{search_string} #{options[:base_query]}"
end
results =
::Search.execute(
search_string.to_s + " status:public",
search_type: :full_page,
guardian: Guardian.new(),
)
# let's be frugal with tokens, 50 results is too much and stuff gets cut off
max_results = calculate_max_results(llm)
results_limit = parameters[:limit] || max_results
results_limit = max_results if parameters[:limit].to_i > max_results
should_try_semantic_search =
SiteSetting.ai_embeddings_semantic_search_enabled && results_limit == max_results &&
parameters[:search_query].present?
max_semantic_results = max_results / 4
results_limit = results_limit - max_semantic_results if should_try_semantic_search
posts = results&.posts || []
posts = posts[0..results_limit.to_i - 1]
if should_try_semantic_search
semantic_search = DiscourseAi::Embeddings::SemanticSearch.new(Guardian.new())
topic_ids = Set.new(posts.map(&:topic_id))
search = ::Search.new(search_string, guardian: Guardian.new)
results = nil
begin
results = semantic_search.search_for_topics(search.term)
rescue => e
Discourse.warn_exception(e, message: "Semantic search failed")
end
if results
results = search.apply_filters(results)
results.each do |post|
next if topic_ids.include?(post.topic_id)
topic_ids << post.topic_id
posts << post
break if posts.length >= max_results
end
end
end
@last_num_results = posts.length
# this is the general pattern from core
# if there are millions of hidden tags it may fail
hidden_tags = nil
if posts.blank?
{ args: parameters, rows: [], instruction: "nothing was found, expand your search" }
else
format_results(posts, args: parameters) do |post|
category_names = [
post.topic.category&.parent_category&.name,
post.topic.category&.name,
].compact.join(" > ")
row = {
title: post.topic.title,
url: Discourse.base_path + post.url,
username: post.user&.username,
excerpt: post.excerpt,
created: post.created_at,
category: category_names,
likes: post.like_count,
topic_views: post.topic.views,
topic_likes: post.topic.like_count,
topic_replies: post.topic.posts_count - 1,
}
if SiteSetting.tagging_enabled
hidden_tags ||= DiscourseTagging.hidden_tag_names
# using map over pluck to avoid n+1 (assuming caller preloading)
tags = post.topic.tags.map(&:name) - hidden_tags
row[:tags] = tags.join(", ") if tags.present?
end
row
end
end
end
protected
def description_args
{
count: @last_num_results || 0,
query: @last_query || "",
url: "#{Discourse.base_path}/search?q=#{CGI.escape(@last_query || "")}",
}
end
private
def calculate_max_results(llm)
max_results = options[:max_results].to_i
return [max_results, 100].min if max_results > 0
if llm.max_prompt_tokens > 30_000
60
elsif llm.max_prompt_tokens > 10_000
40
else
20
end
end
end
end
end
end
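
The accepted_options declared above are the per-persona knobs consumed through Tool#options. A hedged sketch of wiring them in (the option values are invented):

tool =
  DiscourseAi::AiBot::Tools::Search.new(
    { search_query: "deploy failed" },
    persona_options: { "base_query" => "category:ops", "max_results" => 20 },
  )

tool.options[:base_query]  # => "category:ops", appended to the search string in invoke
tool.options[:max_results] # => 20, caps calculate_max_results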

View File

@ -0,0 +1,88 @@
# frozen_string_literal: true
module DiscourseAi
module AiBot
module Tools
class SearchSettings < Tool
INCLUDE_DESCRIPTIONS_MAX_LENGTH = 10
MAX_RESULTS = 200
def self.signature
{
name: name,
description: "Will search through site settings and return top 20 results",
parameters: [
{
name: "query",
description:
"comma delimited list of settings to search for (e.g. 'setting_1,setting_2')",
type: "string",
required: true,
},
],
}
end
def self.name
"search_settings"
end
def query
parameters[:query].to_s
end
def invoke(_bot_user, _llm)
@last_num_results = 0
terms = query.split(",").map(&:strip).map(&:downcase).reject(&:blank?)
found =
SiteSetting.all_settings.filter do |setting|
name = setting[:setting].to_s.downcase
description = setting[:description].to_s.downcase
plugin = setting[:plugin].to_s.downcase
search_string = "#{name} #{description} #{plugin}"
terms.any? { |term| search_string.include?(term) }
end
if found.blank?
{
args: {
query: query,
},
rows: [],
instruction: "no settings matched #{query}, expand your search",
}
else
include_descriptions = false
if found.length > MAX_RESULTS
# exclusive range, otherwise we return MAX_RESULTS + 1 settings
found = found[0...MAX_RESULTS]
elsif found.length < INCLUDE_DESCRIPTIONS_MAX_LENGTH
include_descriptions = true
end
@last_num_results = found.length
format_results(found, args: { query: query }) do |setting|
result = { name: setting[:setting] }
if include_descriptions
result[:description] = setting[:description]
result[:plugin] = setting[:plugin]
end
result
end
end
end
protected
def description_args
{ count: @last_num_results || 0, query: parameters[:query].to_s }
end
end
end
end
end

View File

@ -0,0 +1,160 @@
# frozen_string_literal: true
module DiscourseAi
module AiBot
module Tools
class SettingContext < Tool
MAX_CONTEXT_TOKENS = 2000
CODE_FILE_EXTENSIONS = "rb,js,gjs,hbs"
class << self
def rg_installed?
if defined?(@rg_installed)
@rg_installed
else
@rg_installed =
begin
Discourse::Utils.execute_command("which", "rg")
true
rescue Discourse::Utils::CommandError
false
end
end
end
def signature
{
name: name,
description:
"Will provide you with full context regarding a particular site setting in Discourse",
parameters: [
{
name: "setting_name",
description: "The name of the site setting we need context for",
type: "string",
required: true,
},
],
}
end
def name
"setting_context"
end
end
def setting_name
parameters[:setting_name]
end
def invoke(_bot_user, llm)
if !self.class.rg_installed?
return(
{
setting_name: setting_name,
context:
"This command requires the rg command line tool to be installed on the server",
}
)
end
if !SiteSetting.has_setting?(setting_name)
{ setting_name: setting_name, context: "This setting does not exist" }
else
description = SiteSetting.description(setting_name)
result = +"# #{setting_name}\n#{description}\n\n"
setting_info =
find_setting_info(setting_name, [Rails.root.join("config", "site_settings.yml").to_s])
if !setting_info
setting_info =
find_setting_info(setting_name, Dir[Rails.root.join("plugins/**/settings.yml")])
end
result << setting_info
result << "\n\n"
%w[lib app plugins].each do |dir|
path = Rails.root.join(dir).to_s
result << Discourse::Utils.execute_command(
"rg",
setting_name,
path,
"-g",
"!**/spec/**",
"-g",
"!**/dist/**",
"-g",
"*.{#{CODE_FILE_EXTENSIONS}}",
"-C",
"10",
"--color",
"never",
"--heading",
"--no-ignore",
chdir: path,
success_status_codes: [0, 1],
)
end
result.gsub!(/^#{Regexp.escape(Rails.root.to_s)}/, "")
result = llm.tokenizer.truncate(result, MAX_CONTEXT_TOKENS)
{ setting_name: setting_name, context: result }
end
end
private
def find_setting_info(name, paths)
path, result = nil
paths.each do |search_path|
result =
Discourse::Utils.execute_command(
"rg",
name,
search_path,
"-g",
"*.{#{CODE_FILE_EXTENSIONS}}",
"-A",
"10",
"--color",
"never",
"--heading",
success_status_codes: [0, 1],
)
if !result.blank?
path = search_path
break
end
end
if result.blank?
nil
else
rows = result.split("\n")
leading_spaces = rows[0].match(/^\s*/)[0].length
filtered = []
rows.each do |row|
if !filtered.blank?
break if row.match(/^\s*/)[0].length <= leading_spaces
end
filtered << row
end
filtered.unshift("#{path}")
filtered.join("\n")
end
end
def description_args
parameters.slice(:setting_name)
end
end
end
end
end

View File

@ -0,0 +1,183 @@
# frozen_string_literal: true
module DiscourseAi
module AiBot
module Tools
class Summarize < Tool
def self.signature
{
name: name,
description: "Will summarize a topic attempting to answer question in guidance",
parameters: [
{
name: "topic_id",
description: "The discourse topic id to summarize",
type: "integer",
required: true,
},
{
name: "guidance",
description: "Special guidance on how to summarize the topic",
type: "string",
},
],
}
end
def self.name
"summary"
end
def topic_id
parameters[:topic_id].to_i
end
def guidance
parameters[:guidance]
end
def chain_next_response?
false
end
def standalone?
true
end
def low_cost?
true
end
def custom_raw
@last_summary || I18n.t("discourse_ai.ai_bot.topic_not_found")
end
def invoke(bot_user, llm, &progress_blk)
topic = nil
if topic_id > 0
topic = Topic.find_by(id: topic_id)
topic = nil if !topic || !Guardian.new.can_see?(topic)
end
@last_summary = nil
if topic
@last_topic_title = topic.title
posts =
Post
.where(topic_id: topic.id)
.where("post_type in (?)", [Post.types[:regular], Post.types[:small_action]])
.where("not hidden")
.order(:post_number)
columns = ["posts.id", :post_number, :raw, :username]
current_post_numbers = posts.limit(5).pluck(:post_number)
current_post_numbers += posts.reorder("posts.score desc").limit(50).pluck(:post_number)
current_post_numbers += posts.reorder("post_number desc").limit(5).pluck(:post_number)
data =
Post
.where(topic_id: topic.id)
.joins(:user)
.where("post_number in (?)", current_post_numbers)
.order(:post_number)
.pluck(*columns)
@last_summary = summarize(data, topic, guidance, bot_user, llm, &progress_blk)
end
if !@last_summary
"Say: No topic found!"
else
"Topic summarized"
end
end
protected
def description_args
{ url: "#{Discourse.base_path}/t/-/#{@last_topic_id}", title: @last_topic_title || "" }
end
private
def summarize(data, topic, guidance, bot_user, llm, &progress_blk)
text = +""
data.each do |id, post_number, raw, username|
text << "(#{post_number} #{username} said: #{raw}"
end
summaries = []
current_section = +""
split = []
text
.split(/\s+/)
.each_slice(20) do |slice|
current_section << " "
current_section << slice.join(" ")
# somehow any more will get closer to limits
if llm.tokenizer.tokenize(current_section).length > 2500
split << current_section
current_section = +""
end
end
split << current_section if current_section.present?
split = split[0..3] + split[-3..-1] if split.length > 5
progress = +I18n.t("discourse_ai.ai_bot.summarizing")
progress_blk.call(progress)
split.each do |section|
progress << "."
progress_blk.call(progress)
prompt = section_prompt(topic, section, guidance)
summary = llm.generate(prompt, temperature: 0.6, max_tokens: 400, user: bot_user)
summaries << summary
end
if summaries.length > 1
progress << "."
progress_blk.call(progress)
concatenation_prompt = {
insts: "You are a helpful bot",
input:
"Concatenate these disjoint summaries, creating a cohesive narrative:\n#{summaries.join("\n")}",
}
llm.generate(concatenation_prompt, temperature: 0.6, max_tokens: 500, user: bot_user)
else
summaries.first
end
end
def section_prompt(topic, text, guidance)
insts = <<~TEXT
You are a summarization bot.
You effectively summarise any text.
You condense it into a shorter version.
You understand and generate Discourse forum markdown.
Try generating links as well the format is #{topic.url}/POST_NUMBER. eg: [ref](#{topic.url}/77)
TEXT
{ insts: insts, input: <<~TEXT }
Guidance: #{guidance}
You are summarizing the topic: #{topic.title}
Summarize the following in 400 words:
#{text}
TEXT
end
end
end
end
end

lib/ai_bot/tools/time.rb (new file, 52 lines)
View File

@ -0,0 +1,52 @@
# frozen_string_literal: true
module DiscourseAi
module AiBot
module Tools
class Time < Tool
def self.signature
{
name: name,
description: "Will generate the time in a timezone",
parameters: [
{
name: "timezone",
description: "ALWAYS supply a Ruby compatible timezone",
type: "string",
required: true,
},
],
}
end
def self.name
"time"
end
def timezone
parameters[:timezone].to_s
end
def invoke(_bot_user, _llm)
time =
begin
::Time.now.in_time_zone(timezone)
rescue StandardError
nil
end
time = ::Time.now if !time
@last_time = time.to_s
{ args: { timezone: timezone }, time: time.to_s }
end
private
def description_args
{ timezone: timezone, time: @last_time }
end
end
end
end
end

lib/ai_bot/tools/tool.rb (new file, 124 lines)
View File

@ -0,0 +1,124 @@
# frozen_string_literal: true
module DiscourseAi
module AiBot
module Tools
class Tool
class << self
def signature
raise NotImplementedError
end
def name
raise NotImplementedError
end
def accepted_options
[]
end
def option(name, type:)
Option.new(tool: self, name: name, type: type)
end
def help
I18n.t("discourse_ai.ai_bot.command_help.#{signature[:name]}")
end
def custom_system_message
nil
end
end
attr_accessor :custom_raw
def initialize(parameters, tool_call_id: "", persona_options: {})
@parameters = parameters
@tool_call_id = tool_call_id
@persona_options = persona_options
end
attr_reader :parameters, :tool_call_id
def name
self.class.name
end
def summary
I18n.t("discourse_ai.ai_bot.command_summary.#{name}")
end
def details
I18n.t("discourse_ai.ai_bot.command_description.#{name}", description_args)
end
def help
I18n.t("discourse_ai.ai_bot.command_help.#{name}")
end
def options
self
.class
.accepted_options
.reduce(HashWithIndifferentAccess.new) do |memo, option|
val = @persona_options[option.name]
memo[option.name] = val if val
memo
end
end
def chain_next_response?
true
end
def standalone?
false
end
def low_cost?
false
end
protected
def accepted_options
[]
end
def option(name, type:)
Option.new(tool: self, name: name, type: type)
end
def description_args
{}
end
def format_results(rows, column_names = nil, args: nil)
rows = rows&.map { |row| yield row } if block_given?
if !column_names
index = -1
column_indexes = {}
rows =
rows&.map do |data|
new_row = []
data.each do |key, value|
found_index = column_indexes[key.to_s] ||= (index += 1)
new_row[found_index] = value
end
new_row
end
column_names = column_indexes.keys
end
# this is not the most efficient format
# however this is needed cause GPT 3.5 / 4 was steered using JSON
result = { column_names: column_names, rows: rows }
result[:args] = args if args
result
end
end
end
end
end
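
To pin down the format_results contract, a small worked example (the data is invented):

rows = [{ name: "ruby", count: 3 }, { name: "rails", count: 5 }]
format_results(rows) { |row| row }
# => { column_names: ["name", "count"], rows: [["ruby", 3], ["rails", 5]] }

Hash rows are flattened into positional arrays under a shared column_names header, which keeps the JSON the model sees compact.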

View File

@@ -46,10 +46,9 @@ module DiscourseAi
         prompt[:tools].map do |t|
           tool = t.dup

-          if tool[:parameters]
-            tool[:parameters] = t[:parameters].reduce(
-              { type: "object", properties: {}, required: [] },
-            ) do |memo, p|
+          tool[:parameters] = t[:parameters]
+            .to_a
+            .reduce({ type: "object", properties: {}, required: [] }) do |memo, p|
               name = p[:name]
               memo[:required] << name if p[:required]
@@ -58,7 +57,6 @@ module DiscourseAi
               memo[:properties][name][:items] = { type: p[:item_type] } if p[:item_type]
               memo
             end
-          end

           { type: "function", function: tool }
         end
@@ -71,9 +69,12 @@ module DiscourseAi
         trimmed_context.reverse.map do |context|
           if context[:type] == "tool_call"
+            function = JSON.parse(context[:content], symbolize_names: true)
+            function[:arguments] = function[:arguments].to_json
+
             {
               role: "assistant",
-              tool_calls: [{ type: "function", function: context[:content], id: context[:name] }],
+              tool_calls: [{ type: "function", function: function, id: context[:name] }],
             }
           else
             translated = context.slice(:content)
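
For clarity, this is roughly what the new translation emits for a one-parameter tool signature (shape inferred from the reduce above; the values are invented):

# input, as produced by Tool.signature:
#   { name: "google", description: "...", parameters: [
#       { name: "query", type: "string", description: "...", required: true } ] }
#
# output, as sent to the OpenAI API:
#   { type: "function",
#     function: {
#       name: "google", description: "...",
#       parameters: {
#         type: "object",
#         properties: { "query" => { type: "string", description: "..." } },
#         required: ["query"] } } }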

View File

@@ -39,12 +39,12 @@ module DiscourseAi
       def conversation_context
         return "" if prompt[:conversation_context].blank?

-        trimmed_context = trim_context(prompt[:conversation_context])
+        clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" }
+        trimmed_context = trim_context(clean_context)

         trimmed_context
           .reverse
           .reduce(+"") do |memo, context|
-            next(memo) if context[:type] == "tool_call"
             memo << (context[:type] == "user" ? "Human:" : "Assistant:")

             if context[:type] == "tool"

View File

@@ -97,6 +97,13 @@ module DiscourseAi
           message_tokens = calculate_message_token(dupped_context)

+          # Don't trim tool call metadata.
+          if context[:type] == "tool_call"
+            current_token_count += calculate_message_token(context) + per_message_overhead
+            memo << context
+            next(memo)
+          end
+
           # Trimming content to make sure we respect token limit.
           while dupped_context[:content].present? &&
                   message_tokens + current_token_count + per_message_overhead > prompt_limit

View File

@@ -39,12 +39,13 @@ module DiscourseAi
       def conversation_context
         return "" if prompt[:conversation_context].blank?

-        trimmed_context = trim_context(prompt[:conversation_context])
+        clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" }
+        trimmed_context = trim_context(clean_context)

        trimmed_context
          .reverse
          .reduce(+"") do |memo, context|
-            next(memo) if context[:type] == "tool_call"
            if context[:type] == "tool"
              memo << <<~TEXT
                [INST]

View File

@@ -39,12 +39,12 @@ module DiscourseAi
       def conversation_context
         return "" if prompt[:conversation_context].blank?

-        trimmed_context = trim_context(prompt[:conversation_context])
+        clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" }
+        trimmed_context = trim_context(clean_context)

        trimmed_context
          .reverse
          .reduce(+"") do |memo, context|
-            next(memo) if context[:type] == "tool_call"
            memo << "[INST] " if context[:type] == "user"

            if context[:type] == "tool"

View File

@@ -36,12 +36,12 @@ module DiscourseAi
       def conversation_context
         return "" if prompt[:conversation_context].blank?

-        trimmed_context = trim_context(prompt[:conversation_context])
+        clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" }
+        trimmed_context = trim_context(clean_context)

        trimmed_context
          .reverse
          .reduce(+"") do |memo, context|
-            next(memo) if context[:type] == "tool_call"
            memo << (context[:type] == "user" ? "### User:" : "### Assistant:")

            if context[:type] == "tool"

View File

@@ -23,7 +23,7 @@ module DiscourseAi
       def default_options
         {
           model: model,
-          max_tokens_to_sample: 2_000,
+          max_tokens_to_sample: 3_000,
           stop_sequences: ["\n\nHuman:", "</function_calls>"],
         }
       end

View File

@@ -26,11 +26,7 @@ module DiscourseAi
       end

       def default_options
-        {
-          model: model,
-          max_tokens_to_sample: 2_000,
-          stop_sequences: ["\n\nHuman:", "</function_calls>"],
-        }
+        { max_tokens_to_sample: 3_000, stop_sequences: ["\n\nHuman:", "</function_calls>"] }
       end

       def provider_id

View File

@@ -87,6 +87,8 @@ module DiscourseAi
             return response_data
           end

+          has_tool = false
+
           begin
             cancelled = false
             cancel = lambda { cancelled = true }
@@ -129,17 +131,19 @@ module DiscourseAi
                   partial = extract_completion_from(raw_partial)
                   next if response_data.empty? && partial.blank?
                   next if partial.nil?
-                  partials_raw << partial.to_s

-                  # Skip yield for tools. We'll buffer and yield later.
-                  if has_tool?(partials_raw)
+                  # Stop streaming the response as soon as you find a tool.
+                  # We'll buffer and yield it later.
+                  has_tool = true if has_tool?(partials_raw)
+
+                  if has_tool
                     function_buffer = add_to_buffer(function_buffer, partials_raw, partial)
                   else
                     response_data << partial
                     yield partial, cancel if partial
                   end
+
+                  partials_raw << partial.to_s
                 rescue JSON::ParserError
                   leftover = redo_chunk
                   json_error = true
@@ -158,7 +162,7 @@ module DiscourseAi
           end

           # Once we have the full response, try to return the tool as a XML doc.
-          if has_tool?(partials_raw)
+          if has_tool
             if function_buffer.at("tool_name").text.present?
               invocation = +function_buffer.at("function_calls").to_s
               invocation << "\n"
@@ -264,7 +268,7 @@ module DiscourseAi
           read_function = Nokogiri::HTML5.fragment(raw_data)

-          if tool_name = read_function.at("tool_name").text
+          if tool_name = read_function.at("tool_name")&.text
             function_buffer.at("tool_name").inner_html = tool_name
             function_buffer.at("tool_id").inner_html = tool_name
           end
@@ -272,7 +276,8 @@ module DiscourseAi
           _read_parameters =
             read_function
               .at("parameters")
-              .elements
+              &.elements
+              .to_a
               .each do |elem|
                 if paramenter = function_buffer.at(elem.name)&.text
                   function_buffer.at(elem.name).inner_html = paramenter

View File

@@ -20,7 +20,10 @@ module DiscourseAi
     def self.with_prepared_responses(responses)
       @canned_response = DiscourseAi::Completions::Endpoints::CannedResponse.new(responses)

-      yield(@canned_response).tap { @canned_response = nil }
+      yield(@canned_response)
+    ensure
+      # Don't leak prepared response if there's an exception.
+      @canned_response = nil
     end

     def self.proxy(model_name)
@@ -119,9 +122,15 @@ module DiscourseAi
       gateway.perform_completion!(dialect, user, model_params, &partial_read_blk)
     end

+    def max_prompt_tokens
+      dialect_klass.new({}, model_name).max_prompt_tokens
+    end
+
+    attr_reader :model_name
+
     private

-    attr_reader :dialect_klass, :gateway, :model_name
+    attr_reader :dialect_klass, :gateway
   end
 end
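
The ensure change matters mostly for specs: the canned response must not leak into later tests when the block raises. A hypothetical usage sketch (the model name is illustrative):

DiscourseAi::Completions::Llm.with_prepared_responses(["canned reply"]) do
  llm = DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo")
  llm.generate({ insts: "You are a helpful bot", input: "hi" }, user: Discourse.system_user)
  # even if an assertion raises here, @canned_response is cleared on exit
end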

View File

@ -1,168 +0,0 @@
# frozen_string_literal: true
require "base64"
require "json"
require "aws-eventstream"
require "aws-sigv4"
module ::DiscourseAi
module Inference
class AmazonBedrockInference
CompletionFailed = Class.new(StandardError)
TIMEOUT = 60
def self.perform!(
prompt,
model = "anthropic.claude-v2",
temperature: nil,
top_p: nil,
top_k: nil,
max_tokens: 20_000,
user_id: nil,
stop_sequences: nil,
tokenizer: Tokenizer::AnthropicTokenizer
)
raise CompletionFailed if model.blank?
raise CompletionFailed if !SiteSetting.ai_bedrock_access_key_id.present?
raise CompletionFailed if !SiteSetting.ai_bedrock_secret_access_key.present?
raise CompletionFailed if !SiteSetting.ai_bedrock_region.present?
signer =
Aws::Sigv4::Signer.new(
access_key_id: SiteSetting.ai_bedrock_access_key_id,
region: SiteSetting.ai_bedrock_region,
secret_access_key: SiteSetting.ai_bedrock_secret_access_key,
service: "bedrock",
)
log = nil
response_data = +""
response_raw = +""
url_api =
"https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/#{model}/"
if block_given?
url_api = url_api + "invoke-with-response-stream"
else
url_api = url_api + "invoke"
end
url = URI(url_api)
headers = { "content-type" => "application/json", "Accept" => "*/*" }
payload = { prompt: prompt }
payload[:top_p] = top_p if top_p
payload[:top_k] = top_k if top_k
payload[:max_tokens_to_sample] = max_tokens || 2000
payload[:temperature] = temperature if temperature
payload[:stop_sequences] = stop_sequences if stop_sequences
Net::HTTP.start(
url.host,
url.port,
use_ssl: true,
read_timeout: TIMEOUT,
open_timeout: TIMEOUT,
write_timeout: TIMEOUT,
) do |http|
request = Net::HTTP::Post.new(url)
request_body = payload.to_json
request.body = request_body
signed_request =
signer.sign_request(
req: request,
http_method: request.method,
url: url,
body: request.body,
)
request.initialize_http_header(headers.merge!(signed_request.headers))
http.request(request) do |response|
if response.code.to_i != 200
Rails.logger.error(
"BedRockInference: status: #{response.code.to_i} - body: #{response.body}",
)
raise CompletionFailed
end
log =
AiApiAuditLog.create!(
provider_id: AiApiAuditLog::Provider::Anthropic,
raw_request_payload: request_body,
user_id: user_id,
)
if !block_given?
response_body = response.read_body
parsed_response = JSON.parse(response_body, symbolize_names: true)
log.update!(
raw_response_payload: response_body,
request_tokens: tokenizer.size(prompt),
response_tokens: tokenizer.size(parsed_response[:completion]),
)
return parsed_response
end
begin
cancelled = false
cancel = lambda { cancelled = true }
decoder = Aws::EventStream::Decoder.new
response.read_body do |chunk|
if cancelled
http.finish
return
end
response_raw << chunk
begin
message = decoder.decode_chunk(chunk)
partial =
message
.first
.payload
.string
.then { JSON.parse(_1) }
.dig("bytes")
.then { Base64.decode64(_1) }
.then { JSON.parse(_1, symbolize_names: true) }
next if !partial
response_data << partial[:completion].to_s
yield partial, cancel if partial[:completion]
rescue JSON::ParserError,
Aws::EventStream::Errors::MessageChecksumError,
Aws::EventStream::Errors::PreludeChecksumError => e
Rails.logger.error("BedrockInference: #{e}")
end
rescue IOError
raise if !cancelled
end
end
return response_data
end
ensure
if block_given?
log.update!(
raw_response_payload: response_data,
request_tokens: tokenizer.size(prompt),
response_tokens: tokenizer.size(response_data),
)
end
if Rails.env.development? && log
puts "BedrockInference: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}"
end
end
end
end
end
end

View File

@ -1,158 +0,0 @@
# frozen_string_literal: true
module ::DiscourseAi
module Inference
class AnthropicCompletions
CompletionFailed = Class.new(StandardError)
TIMEOUT = 60
def self.perform!(
prompt,
model = "claude-2",
temperature: nil,
top_p: nil,
max_tokens: nil,
user_id: nil,
stop_sequences: nil,
post: nil,
&blk
)
# HACK to get around the fact that they have different APIs
# we will introduce a proper LLM abstraction layer to handle these shenanigans later this year
if model == "claude-2" && SiteSetting.ai_bedrock_access_key_id.present? &&
SiteSetting.ai_bedrock_secret_access_key.present? &&
SiteSetting.ai_bedrock_region.present?
return(
AmazonBedrockInference.perform!(
prompt,
temperature: temperature,
top_p: top_p,
max_tokens: max_tokens,
user_id: user_id,
stop_sequences: stop_sequences,
&blk
)
)
end
log = nil
response_data = +""
response_raw = +""
url = URI("https://api.anthropic.com/v1/complete")
headers = {
"anthropic-version" => "2023-06-01",
"x-api-key" => SiteSetting.ai_anthropic_api_key,
"content-type" => "application/json",
}
payload = { model: model, prompt: prompt }
payload[:top_p] = top_p if top_p
payload[:max_tokens_to_sample] = max_tokens || 2000
payload[:temperature] = temperature if temperature
payload[:stream] = true if block_given?
payload[:stop_sequences] = stop_sequences if stop_sequences
Net::HTTP.start(
url.host,
url.port,
use_ssl: true,
read_timeout: TIMEOUT,
open_timeout: TIMEOUT,
write_timeout: TIMEOUT,
) do |http|
request = Net::HTTP::Post.new(url, headers)
request_body = payload.to_json
request.body = request_body
http.request(request) do |response|
if response.code.to_i != 200
Rails.logger.error(
"AnthropicCompletions: status: #{response.code.to_i} - body: #{response.body}",
)
raise CompletionFailed
end
log =
AiApiAuditLog.create!(
provider_id: AiApiAuditLog::Provider::Anthropic,
raw_request_payload: request_body,
user_id: user_id,
post_id: post&.id,
topic_id: post&.topic_id,
)
if !block_given?
response_body = response.read_body
parsed_response = JSON.parse(response_body, symbolize_names: true)
log.update!(
raw_response_payload: response_body,
request_tokens: DiscourseAi::Tokenizer::AnthropicTokenizer.size(prompt),
response_tokens:
DiscourseAi::Tokenizer::AnthropicTokenizer.size(parsed_response[:completion]),
)
return parsed_response
end
begin
cancelled = false
cancel = lambda { cancelled = true }
response.read_body do |chunk|
if cancelled
http.finish
return
end
response_raw << chunk
chunk
.split("\n")
.each do |line|
data = line.split("data: ", 2)[1]
next if !data
if !cancelled
begin
partial = JSON.parse(data, symbolize_names: true)
response_data << partial[:completion].to_s
# ping has no data... do not yield it
yield partial, cancel if partial[:completion]
rescue JSON::ParserError
nil
# TODO: carry the leftover chunk over to the next iteration
end
end
end
rescue IOError
raise if !cancelled
end
end
return response_data
end
ensure
if block_given?
log.update!(
raw_response_payload: response_raw,
request_tokens: DiscourseAi::Tokenizer::AnthropicTokenizer.size(prompt),
response_tokens: DiscourseAi::Tokenizer::AnthropicTokenizer.size(response_data),
)
end
if Rails.env.development? && log
puts "AnthropicCompletions: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}"
end
end
def self.try_parse(data)
JSON.parse(data, symbolize_names: true)
rescue JSON::ParserError
nil
end
end
end
end
end

View File

@ -1,69 +0,0 @@
# frozen_string_literal: true
module ::DiscourseAi
module Inference
class Function
attr_reader :name, :description, :parameters, :type
def initialize(name:, description:, type: nil)
@name = name
@description = description
@type = type || "object"
@parameters = []
end
def add_parameter(parameter = nil, **kwargs)
if parameter
add_parameter_kwargs(
name: parameter.name,
type: parameter.type,
description: parameter.description,
required: parameter.required,
enum: parameter.enum,
item_type: parameter.item_type,
)
else
add_parameter_kwargs(**kwargs)
end
end
def add_parameter_kwargs(
name:,
type:,
description:,
enum: nil,
required: false,
item_type: nil
)
param = { name: name, type: type, description: description, required: required }
param[:enum] = enum if enum
param[:item_type] = item_type if item_type
@parameters << param
end
def to_json(*args)
as_json.to_json(*args)
end
def as_json
required_params = []
properties = {}
parameters.each do |parameter|
definition = { type: parameter[:type], description: parameter[:description] }
definition[:enum] = parameter[:enum] if parameter[:enum]
definition[:items] = { type: parameter[:item_type] } if parameter[:item_type]
required_params << parameter[:name] if parameter[:required]
properties[parameter[:name]] = definition
end
params = { type: @type, properties: properties }
params[:required] = required_params if required_params.present?
{ name: name, description: description, parameters: params }
end
end
end
end
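
To make the removed Function API concrete, here is a small usage sketch showing how a tool definition serializes into the OpenAI-style schema produced by as_json (the get_weather tool and its parameters are illustrative, not from the plugin; run inside the plugin since as_json relies on ActiveSupport's present?):

function =
  DiscourseAi::Inference::Function.new(
    name: "get_weather",
    description: "Fetch the weather for a location",
  )
function.add_parameter(
  name: "location",
  type: "string",
  description: "City name",
  required: true,
)
function.add_parameter(name: "unit", type: "string", description: "Unit", enum: %w[c f])

function.as_json
# => {
#      name: "get_weather",
#      description: "Fetch the weather for a location",
#      parameters: {
#        type: "object",
#        properties: {
#          "location" => { type: "string", description: "City name" },
#          "unit" => { type: "string", description: "Unit", enum: ["c", "f"] },
#        },
#        required: ["location"],
#      },
#    }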

View File

@ -1,119 +0,0 @@
# frozen_string_literal: true
module ::DiscourseAi
module Inference
class FunctionList
def initialize
@functions = []
end
def <<(function)
@functions << function
end
def parse_prompt(prompt)
xml = prompt.sub(%r{<function_calls>(.*)</function_calls>}m, '\1')
if xml.present?
parsed = []
Nokogiri
.XML(xml)
.xpath("//invoke")
.each do |invoke_node|
function = { name: invoke_node.xpath(".//tool_name").text, arguments: {} }
parsed << function
invoke_node
.xpath("//parameters")
.children
.each do |parameters_node|
if parameters_node.is_a?(Nokogiri::XML::Element) && name = parameters_node.name
function[:arguments][name.to_sym] = parameters_node.text
end
end
end
coerce_arguments!(parsed)
end
end
def coerce_arguments!(parsed)
parsed.each do |function_call|
arguments = function_call[:arguments]
function = @functions.find { |f| f.name == function_call[:name] }
next if !function
arguments.each do |name, value|
parameter = function.parameters.find { |p| p[:name].to_s == name.to_s }
if !parameter
arguments.delete(name)
next
end
type = parameter[:type]
if type == "array"
begin
arguments[name] = JSON.parse(value)
rescue JSON::ParserError
# maybe LLM chose a different shape for the array
arguments[name] = value.to_s.split("\n").map(&:strip).reject(&:blank?)
end
elsif type == "integer"
arguments[name] = value.to_i
elsif type == "float"
arguments[name] = value.to_f
end
end
end
parsed
end
def system_prompt
tools = +""
@functions.each do |function|
parameters = +""
if function.parameters.present?
parameters << "\n"
function.parameters.each do |parameter|
parameters << <<~PARAMETER
<parameter>
<name>#{parameter[:name]}</name>
<type>#{parameter[:type]}</type>
<description>#{parameter[:description]}</description>
<required>#{parameter[:required]}</required>
PARAMETER
parameters << "<options>#{parameter[:enum].join(",")}</options>\n" if parameter[:enum]
parameters << "</parameter>\n"
end
end
tools << <<~TOOLS
<tool_description>
<tool_name>#{function.name}</tool_name>
<description>#{function.description}</description>
<parameters>#{parameters}</parameters>
</tool_description>
TOOLS
end
<<~PROMPT
In this environment you have access to a set of tools you can use to answer the user's question.
You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>
Here are the tools available:
<tools>
#{tools}</tools>
PROMPT
end
end
end
end
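
parse_prompt above round-trips the XML block the system prompt instructs the model to emit, and coerce_arguments! casts the string values back to the declared parameter types. A hedged sketch of the expected shape, reusing the Function class from the previous file (the search tool and its arguments are illustrative; this assumes the plugin's Rails environment for Nokogiri and present?):

list = DiscourseAi::Inference::FunctionList.new
search = DiscourseAi::Inference::Function.new(name: "search", description: "Search")
search.add_parameter(name: "search_query", type: "string", description: "Query")
search.add_parameter(name: "max_results", type: "integer", description: "Limit")
list << search

reply = <<~XML
  <function_calls>
  <invoke>
  <tool_name>search</tool_name>
  <parameters>
  <search_query>hello world</search_query>
  <max_results>3</max_results>
  </parameters>
  </invoke>
  </function_calls>
XML

list.parse_prompt(reply)
# => [{ name: "search", arguments: { search_query: "hello world", max_results: 3 } }]
# note: max_results arrives as the string "3" and is coerced to an Integer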

View File

@ -1,146 +0,0 @@
# frozen_string_literal: true
module ::DiscourseAi
module Inference
class HuggingFaceTextGeneration
CompletionFailed = Class.new(StandardError)
TIMEOUT = 120
def self.perform!(
prompt,
model,
temperature: 0.7,
top_p: nil,
top_k: nil,
typical_p: nil,
max_tokens: 2000,
repetition_penalty: 1.1,
user_id: nil,
tokenizer: DiscourseAi::Tokenizer::Llama2Tokenizer,
token_limit: nil
)
raise CompletionFailed if model.blank?
url = URI(SiteSetting.ai_hugging_face_api_url)
headers = { "Content-Type" => "application/json" }
if SiteSetting.ai_hugging_face_api_key.present?
headers["Authorization"] = "Bearer #{SiteSetting.ai_hugging_face_api_key}"
end
token_limit = token_limit || SiteSetting.ai_hugging_face_token_limit
parameters = {}
payload = { inputs: prompt, parameters: parameters }
prompt_size = tokenizer.size(prompt)
parameters[:top_p] = top_p if top_p
parameters[:top_k] = top_k if top_k
parameters[:typical_p] = typical_p if typical_p
parameters[:max_new_tokens] = token_limit - prompt_size
parameters[:temperature] = temperature if temperature
parameters[:repetition_penalty] = repetition_penalty if repetition_penalty
parameters[:return_full_text] = false
payload[:stream] = true if block_given?
Net::HTTP.start(
url.host,
url.port,
use_ssl: url.scheme == "https",
read_timeout: TIMEOUT,
open_timeout: TIMEOUT,
write_timeout: TIMEOUT,
) do |http|
request = Net::HTTP::Post.new(url, headers)
request_body = payload.to_json
request.body = request_body
http.request(request) do |response|
if response.code.to_i != 200
Rails.logger.error(
"HuggingFaceTextGeneration: status: #{response.code.to_i} - body: #{response.body}",
)
raise CompletionFailed
end
log =
AiApiAuditLog.create!(
provider_id: AiApiAuditLog::Provider::HuggingFaceTextGeneration,
raw_request_payload: request_body,
user_id: user_id,
)
if !block_given?
response_body = response.read_body
parsed_response = JSON.parse(response_body, symbolize_names: true)
log.update!(
raw_response_payload: response_body,
request_tokens: tokenizer.size(prompt),
response_tokens: tokenizer.size(parsed_response.first[:generated_text]),
)
return parsed_response
end
response_data = +""
begin
cancelled = false
cancel = lambda { cancelled = true }
response_raw = +""
response.read_body do |chunk|
if cancelled
http.finish
return
end
response_raw << chunk
chunk
.split("\n")
.each do |line|
data = line.split("data:", 2)[1]
next if !data || data.squish == "[DONE]"
if !cancelled
begin
# partial contains the entire payload so far
partial = JSON.parse(data, symbolize_names: true)
# this is the last chunk and contains the full response
next if partial[:token][:special] == true
response_data << partial[:token][:text].to_s
yield partial, cancel
rescue JSON::ParserError
nil
end
end
end
rescue IOError
raise if !cancelled
ensure
log.update!(
raw_response_payload: response_raw,
request_tokens: tokenizer.size(prompt),
response_tokens: tokenizer.size(response_data),
)
end
end
return response_data
end
end
def self.try_parse(data)
JSON.parse(data, symbolize_names: true)
rescue JSON::ParserError
nil
end
end
end
end
end
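
Note how the client above budgets generation: max_new_tokens is the configured token limit minus the tokenized prompt size, so a longer prompt automatically shrinks the room left for the reply. A toy illustration of that arithmetic (the whitespace tokenizer is a stand-in for DiscourseAi::Tokenizer::Llama2Tokenizer):

tokenizer = ->(text) { text.split(/\s+/).length } # stand-in tokenizer

token_limit = 4_096
prompt = "Summarize the following discussion about tags"
prompt_size = tokenizer.call(prompt) # => 6

max_new_tokens = token_limit - prompt_size # => 4090
# A 2,000-token prompt against the same limit would leave only 2,096.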

View File

@ -1,194 +0,0 @@
# frozen_string_literal: true
module ::DiscourseAi
module Inference
class OpenAiCompletions
TIMEOUT = 60
DEFAULT_RETRIES = 3
DEFAULT_RETRY_TIMEOUT_SECONDS = 3
RETRY_TIMEOUT_BACKOFF_MULTIPLIER = 3
CompletionFailed = Class.new(StandardError)
def self.perform!(
messages,
model,
temperature: nil,
top_p: nil,
max_tokens: nil,
functions: nil,
user_id: nil,
retries: DEFAULT_RETRIES,
retry_timeout: DEFAULT_RETRY_TIMEOUT_SECONDS,
post: nil,
&blk
)
log = nil
response_data = +""
response_raw = +""
url =
if model.include?("gpt-4")
if model.include?("turbo") || model.include?("1106-preview")
URI(SiteSetting.ai_openai_gpt4_turbo_url)
elsif model.include?("32k")
URI(SiteSetting.ai_openai_gpt4_32k_url)
else
URI(SiteSetting.ai_openai_gpt4_url)
end
else
if model.include?("16k")
URI(SiteSetting.ai_openai_gpt35_16k_url)
else
URI(SiteSetting.ai_openai_gpt35_url)
end
end
headers = { "Content-Type" => "application/json" }
if url.host.include?("azure")
headers["api-key"] = SiteSetting.ai_openai_api_key
else
headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
end
if SiteSetting.ai_openai_organization.present?
headers["OpenAI-Organization"] = SiteSetting.ai_openai_organization
end
payload = { model: model, messages: messages }
payload[:temperature] = temperature if temperature
payload[:top_p] = top_p if top_p
payload[:max_tokens] = max_tokens if max_tokens
payload[:functions] = functions if functions
payload[:stream] = true if block_given?
Net::HTTP.start(
url.host,
url.port,
use_ssl: true,
read_timeout: TIMEOUT,
open_timeout: TIMEOUT,
write_timeout: TIMEOUT,
) do |http|
request = Net::HTTP::Post.new(url, headers)
request_body = payload.to_json
request.body = request_body
http.request(request) do |response|
if retries > 0 && response.code.to_i == 429
sleep(retry_timeout)
retries -= 1
retry_timeout *= RETRY_TIMEOUT_BACKOFF_MULTIPLIER
return(
perform!(
messages,
model,
temperature: temperature,
top_p: top_p,
max_tokens: max_tokens,
functions: functions,
user_id: user_id,
retries: retries,
retry_timeout: retry_timeout,
&blk
)
)
elsif response.code.to_i != 200
Rails.logger.error(
"OpenAiCompletions: status: #{response.code.to_i} - body: #{response.body}",
)
raise CompletionFailed, "status: #{response.code.to_i} - body: #{response.body}"
end
log =
AiApiAuditLog.create!(
provider_id: AiApiAuditLog::Provider::OpenAI,
raw_request_payload: request_body,
user_id: user_id,
post_id: post&.id,
topic_id: post&.topic_id,
)
if !blk
response_body = response.read_body
parsed_response = JSON.parse(response_body, symbolize_names: true)
log.update!(
raw_response_payload: response_body,
request_tokens: parsed_response.dig(:usage, :prompt_tokens),
response_tokens: parsed_response.dig(:usage, :completion_tokens),
)
return parsed_response
end
begin
cancelled = false
cancel = lambda { cancelled = true }
leftover = ""
response.read_body do |chunk|
if cancelled
http.finish
break
end
response_raw << chunk
if (leftover + chunk).length < "data: [DONE]".length
leftover += chunk
next
end
(leftover + chunk)
.split("\n")
.each do |line|
data = line.split("data: ", 2)[1]
next if !data || data == "[DONE]"
next if cancelled
partial = nil
begin
partial = JSON.parse(data, symbolize_names: true)
leftover = ""
rescue JSON::ParserError
leftover = line
end
if partial
response_data << partial.dig(:choices, 0, :delta, :content).to_s
response_data << partial.dig(:choices, 0, :delta, :function_call).to_s
blk.call(partial, cancel)
end
end
rescue IOError
raise if !cancelled
end
end
return response_data
end
end
ensure
if log && block_given?
request_tokens = DiscourseAi::Tokenizer::OpenAiTokenizer.size(extract_prompt(messages))
response_tokens = DiscourseAi::Tokenizer::OpenAiTokenizer.size(response_data)
log.update!(
raw_response_payload: response_raw,
request_tokens: request_tokens,
response_tokens: response_tokens,
)
end
if log && Rails.env.development?
puts "OpenAiCompletions: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}"
end
end
def self.extract_prompt(messages)
messages.map { |message| message[:content] || message["content"] || "" }.join("\n")
end
end
end
end
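
The streaming branch above keeps a leftover buffer so that an SSE frame split across two network chunks gets parsed once the remainder arrives. A self-contained sketch of that buffering strategy, simplified to fake chunks with no cancellation handling:

require "json"

chunks = [
  %Q|data: {"choices":[{"delta":{"content":"Hel"}}]}\ndata: {"choices":[{"delta":|,
  %Q|{"content":"lo"}}]}\ndata: [DONE]\n|,
]

response_data = +""
leftover = ""

chunks.each do |chunk|
  (leftover + chunk).split("\n").each do |line|
    data = line.split("data: ", 2)[1]
    next if !data || data == "[DONE]"

    begin
      partial = JSON.parse(data, symbolize_names: true)
      leftover = ""
      response_data << partial.dig(:choices, 0, :delta, :content).to_s
    rescue JSON::ParserError
      # incomplete frame: keep the raw line and retry with the next chunk
      leftover = line
    end
  end
end

response_data # => "Hello"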

View File

@ -58,9 +58,7 @@ after_initialize do
end end
if Rails.env.test? if Rails.env.test?
require_relative "spec/support/openai_completions_inference_stubs"
require_relative "spec/support/anthropic_completion_stubs"
require_relative "spec/support/stable_diffusion_stubs"
require_relative "spec/support/embeddings_generation_stubs" require_relative "spec/support/embeddings_generation_stubs"
require_relative "spec/support/stable_diffusion_stubs"
end end
end end

View File

@ -64,7 +64,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
.to_return(status: 200, body: chunks) .to_return(status: 200, body: chunks)
end end
let(:tool_deltas) { ["<function", <<~REPLY] } let(:tool_deltas) { ["Let me use a tool for that<function", <<~REPLY] }
_calls> _calls>
<invoke> <invoke>
<tool_name>get_weather</tool_name> <tool_name>get_weather</tool_name>

View File

@ -1,183 +0,0 @@
# frozen_string_literal: true
module ::DiscourseAi
module AiBot
describe AnthropicBot do
def bot_user
User.find(EntryPoint::CLAUDE_V2_ID)
end
before do
SiteSetting.ai_bot_enabled_chat_bots = "claude-2"
SiteSetting.ai_bot_enabled = true
end
let(:bot) { described_class.new(bot_user) }
fab!(:post)
describe "system message" do
it "includes the full command framework" do
prompt = bot.system_prompt(post, allow_commands: true)
expect(prompt).to include("read")
expect(prompt).to include("search_query")
end
end
it "does not include half parsed function calls in reply" do
completion1 = "<function"
completion2 = <<~REPLY
_calls>
<invoke>
<tool_name>search</tool_name>
<parameters>
<search_query>hello world</search_query>
</parameters>
</invoke>
</function_calls>
junk
REPLY
completion1 = { completion: completion1 }.to_json
completion2 = { completion: completion2 }.to_json
completion3 = { completion: "<func" }.to_json
request_number = 0
last_body = nil
stub_request(:post, "https://api.anthropic.com/v1/complete").with(
body:
lambda do |body|
last_body = body
request_number == 2
end,
).to_return(status: 200, body: lambda { |request| +"data: #{completion3}" })
stub_request(:post, "https://api.anthropic.com/v1/complete").with(
body:
lambda do |body|
request_number += 1
request_number == 1
end,
).to_return(
status: 200,
body: lambda { |request| +"data: #{completion1}\ndata: #{completion2}" },
)
bot.reply_to(post)
post.topic.reload
raw = post.topic.ordered_posts.last.raw
prompt = JSON.parse(last_body)["prompt"]
# function call is bundled into the Assistant prompt
expect(prompt.split("Human:").length).to eq(2)
# this should be stripped
expect(prompt).not_to include("junk")
expect(raw).to end_with("<func")
# leading <function_call> should be stripped
expect(raw).to start_with("\n\n<details")
end
it "does not include Assistant: in front of the system prompt" do
prompt = nil
stub_request(:post, "https://api.anthropic.com/v1/complete").with(
body:
lambda do |body|
json = JSON.parse(body)
prompt = json["prompt"]
true
end,
).to_return(
status: 200,
body: lambda { |request| +"data: " << { completion: "Hello World" }.to_json },
)
bot.reply_to(post)
expect(prompt).not_to be_nil
expect(prompt).not_to start_with("Assistant:")
end
describe "parsing a reply prompt" do
it "can correctly predict that a completion needs to be cancelled" do
functions = DiscourseAi::AiBot::Bot::FunctionCalls.new
# note: the Anthropic API has a silly leading space; we need to make sure we can handle it
prompt = +<<~REPLY.strip
<function_calls>
<invoke>
<tool_name>search</tool_name>
<parameters>
<search_query>hello world</search_query>
<random_stuff>77</random_stuff>
</parameters>
</invoke>
</function_calls
REPLY
bot.populate_functions(
partial: nil,
reply: prompt,
functions: functions,
done: false,
current_delta: "",
)
expect(functions.found?).to eq(true)
expect(functions.cancel_completion?).to eq(false)
prompt << ">"
bot.populate_functions(
partial: nil,
reply: prompt,
functions: functions,
done: true,
current_delta: "",
)
expect(functions.found?).to eq(true)
expect(functions.to_a.length).to eq(1)
expect(functions.to_a).to eq(
[{ name: "search", arguments: "{\"search_query\":\"hello world\"}" }],
)
end
end
describe "#update_with_delta" do
describe "get_delta" do
it "can properly remove first leading space" do
context = {}
reply = +""
reply << bot.get_delta({ completion: " Hello" }, context)
reply << bot.get_delta({ completion: " World" }, context)
expect(reply).to eq("Hello World")
end
it "can properly remove Assistant prefix" do
context = {}
reply = +""
reply << bot.get_delta({ completion: "Hello " }, context)
expect(reply).to eq("Hello ")
reply << bot.get_delta({ completion: "world" }, context)
expect(reply).to eq("Hello world")
end
end
end
end
end
end

View File

@ -1,116 +1,16 @@
# frozen_string_literal: true # frozen_string_literal: true
class FakeBot < DiscourseAi::AiBot::Bot RSpec.describe DiscourseAi::AiBot::Bot do
class Tokenizer subject(:bot) { described_class.as(bot_user) }
def tokenize(text)
text.split(" ")
end
end
def tokenizer
Tokenizer.new
end
def prompt_limit(allow_commands: false)
10_000
end
def build_message(poster_username, content, system: false, function: nil)
role = poster_username == bot_user.username ? "Assistant" : "Human"
"#{role}: #{content}"
end
def submit_prompt(prompt, post: nil, prefer_low_cost: false)
rows = @responses.shift
rows.each { |data| yield data, lambda {} }
end
def get_delta(partial, context)
partial
end
def add_response(response)
@responses ||= []
@responses << response
end
end
describe FakeBot do
before do before do
SiteSetting.ai_bot_enabled_chat_bots = "gpt-4" SiteSetting.ai_bot_enabled_chat_bots = "gpt-4"
SiteSetting.ai_bot_enabled = true SiteSetting.ai_bot_enabled = true
end end
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT4_ID) } let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT4_ID) }
fab!(:post) { Fabricate(:post, raw: "hello world") }
it "can handle command truncation for long messages" do let!(:user) { Fabricate(:user) }
bot = FakeBot.new(bot_user)
tags_command = <<~TEXT
<function_calls>
<invoke>
<tool_name>tags</tool_name>
</invoke>
</function_calls>
TEXT
bot.add_response(["hello this is a big test I am testing 123\n", "#{tags_command}\nabc"])
bot.add_response(["this is the reply"])
bot.reply_to(post)
reply = post.topic.posts.order(:post_number).last
expect(reply.raw).not_to include("abc")
expect(reply.post_custom_prompt.custom_prompt.to_s).not_to include("abc")
expect(reply.post_custom_prompt.custom_prompt.length).to eq(3)
expect(reply.post_custom_prompt.custom_prompt[0][0]).to eq(
"hello this is a big test I am testing 123\n#{tags_command.strip}",
)
end
it "can handle command truncation for short bot messages" do
bot = FakeBot.new(bot_user)
tags_command = <<~TEXT
_calls>
<invoke>
<tool_name>tags</tool_name>
</invoke>
</function_calls>
TEXT
bot.add_response(["hello\n<function", "#{tags_command}\nabc"])
bot.add_response(["this is the reply"])
bot.reply_to(post)
reply = post.topic.posts.order(:post_number).last
expect(reply.raw).not_to include("abc")
expect(reply.post_custom_prompt.custom_prompt.to_s).not_to include("abc")
expect(reply.post_custom_prompt.custom_prompt.length).to eq(3)
expect(reply.post_custom_prompt.custom_prompt[0][0]).to eq(
"hello\n<function#{tags_command.strip}",
)
# we don't want function leftovers
expect(reply.raw).to start_with("hello\n\n<details>")
end
end
describe DiscourseAi::AiBot::Bot do
before do
SiteSetting.ai_bot_enabled_chat_bots = "gpt-4"
SiteSetting.ai_bot_enabled = true
end
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT4_ID) }
let(:bot) { described_class.as(bot_user) }
fab!(:user) { Fabricate(:user) }
let!(:pm) do let!(:pm) do
Fabricate( Fabricate(
:private_message_topic, :private_message_topic,
@ -122,101 +22,39 @@ describe DiscourseAi::AiBot::Bot do
], ],
) )
end end
let!(:first_post) { Fabricate(:post, topic: pm, user: user, raw: "This is a reply by the user") } let!(:pm_post) { Fabricate(:post, topic: pm, user: user, raw: "Does my site has tags?") }
let!(:second_post) do
Fabricate(:post, topic: pm, user: user, raw: "This is a second reply by the user")
end
describe "#system_prompt" do let(:function_call) { <<~TEXT }
it "includes relevant context in system prompt" do Let me try using a function to get more info:<function_calls>
bot.system_prompt_style!(:standard) <invoke>
<tool_name>categories</tool_name>
</invoke>
</function_calls>
TEXT
SiteSetting.title = "My Forum" let(:response) { "As expected, your forum has multiple tags" }
SiteSetting.site_description = "My Forum Description"
system_prompt = bot.system_prompt(second_post, allow_commands: true) let(:llm_responses) { [function_call, response] }
expect(system_prompt).to include(SiteSetting.title) describe "#reply" do
expect(system_prompt).to include(SiteSetting.site_description) context "when using function chaining" do
it "yields a loading placeholder while proceeds to invoke the command" do
tool = DiscourseAi::AiBot::Tools::ListCategories.new({})
partial_placeholder = +(<<~HTML)
<details>
<summary>#{tool.summary}</summary>
<p></p>
</details>
HTML
expect(system_prompt).to include(user.username) context = {}
DiscourseAi::Completions::Llm.with_prepared_responses(llm_responses) do
bot.reply(context) do |_bot_reply_post, cancel, placeholder|
expect(placeholder).to eq(partial_placeholder) if placeholder
end end
end end
describe "#reply_to" do
it "can respond to a search command" do
bot.system_prompt_style!(:simple)
expected_response = {
function_call: {
name: "search",
arguments: { query: "test search" }.to_json,
},
}
prompt = bot.bot_prompt_with_topic_context(second_post, allow_commands: true)
req_opts = bot.reply_params.merge({ functions: bot.available_functions, stream: true })
OpenAiCompletionsInferenceStubs.stub_streamed_response(
prompt,
[expected_response],
model: bot.model_for,
req_opts: req_opts,
)
result =
DiscourseAi::AiBot::Commands::SearchCommand
.new(bot: nil, args: nil)
.process(query: "test search")
.to_json
prompt << { role: "function", content: result, name: "search" }
OpenAiCompletionsInferenceStubs.stub_streamed_response(
prompt,
[content: "I found nothing, sorry"],
model: bot.model_for,
req_opts: req_opts,
)
bot.reply_to(second_post)
last = second_post.topic.posts.order("id desc").first
expect(last.raw).to include("<details>")
expect(last.raw).to include("<summary>Search</summary>")
expect(last.raw).not_to include("translation missing")
expect(last.raw).to include("I found nothing")
expect(last.post_custom_prompt.custom_prompt).to eq(
[[result, "search", "function"], ["I found nothing, sorry", bot_user.username]],
)
log = AiApiAuditLog.find_by(post_id: second_post.id)
expect(log).to be_present
end end
end end
describe "#update_pm_title" do
let(:expected_response) { "This is a suggested title" }
before { SiteSetting.min_personal_message_post_length = 5 }
it "updates the title using bot suggestions" do
OpenAiCompletionsInferenceStubs.stub_response(
bot.title_prompt(second_post),
expected_response,
model: bot.model_for,
req_opts: {
temperature: 0.7,
top_p: 0.9,
max_tokens: 40,
},
)
bot.update_pm_title(second_post)
expect(pm.reload.title).to eq(expected_response)
end
end end
end end

View File

@ -1,13 +0,0 @@
#frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Commands::CategoriesCommand do
describe "#generate_categories_info" do
it "can generate correct info" do
Fabricate(:category, name: "america", posts_year: 999)
info = DiscourseAi::AiBot::Commands::CategoriesCommand.new(bot: nil, args: nil).process
expect(info.to_s).to include("america")
expect(info.to_s).to include("999")
end
end
end

View File

@ -1,36 +0,0 @@
#frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Commands::Command do
let(:command) { DiscourseAi::AiBot::Commands::GoogleCommand.new(bot: nil, args: nil) }
before { SiteSetting.ai_bot_enabled = true }
describe "#format_results" do
it "can generate efficient tables of data" do
rows = [1, 2, 3, 4, 5]
column_names = %w[first second third]
formatted =
command.format_results(rows, column_names) { |row| ["row ¦ 1", row + 1, "a|b,\nc"] }
expect(formatted[:column_names].length).to eq(3)
expect(formatted[:rows].length).to eq(5)
expect(formatted.to_s).to include("a|b,\\nc")
end
it "can also generate results by returning hash per row" do
rows = [1, 2, 3, 4, 5]
column_names = %w[first second third]
formatted =
command.format_results(rows, column_names) { |row| ["row ¦ 1", row + 1, "a|b,\nc"] }
formatted2 =
command.format_results(rows) do |row|
{ first: "row ¦ 1", second: row + 1, third: "a|b,\nc" }
end
expect(formatted).to eq(formatted2)
end
end
end

View File

@ -1,207 +0,0 @@
#frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Commands::SearchCommand do
before { SearchIndexer.enable }
after { SearchIndexer.disable }
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
fab!(:admin)
fab!(:parent_category) { Fabricate(:category, name: "animals") }
fab!(:category) { Fabricate(:category, parent_category: parent_category, name: "amazing-cat") }
fab!(:tag_funny) { Fabricate(:tag, name: "funny") }
fab!(:tag_sad) { Fabricate(:tag, name: "sad") }
fab!(:tag_hidden) { Fabricate(:tag, name: "hidden") }
fab!(:staff_tag_group) do
tag_group = Fabricate.build(:tag_group, name: "Staff only", tag_names: ["hidden"])
tag_group.permissions = [
[Group::AUTO_GROUPS[:staff], TagGroupPermission.permission_types[:full]],
]
tag_group.save!
tag_group
end
fab!(:topic_with_tags) do
Fabricate(:topic, category: category, tags: [tag_funny, tag_sad, tag_hidden])
end
before { SiteSetting.ai_bot_enabled = true }
it "can properly list options" do
options = described_class.options.sort_by(&:name)
expect(options.length).to eq(2)
expect(options.first.name.to_s).to eq("base_query")
expect(options.first.localized_name).not_to include("Translation missing:")
expect(options.first.localized_description).not_to include("Translation missing:")
expect(options.second.name.to_s).to eq("max_results")
expect(options.second.localized_name).not_to include("Translation missing:")
expect(options.second.localized_description).not_to include("Translation missing:")
end
describe "#process" do
it "can retreive options from persona correctly" do
persona =
Fabricate(
:ai_persona,
allowed_group_ids: [Group::AUTO_GROUPS[:admins]],
commands: [["SearchCommand", { "base_query" => "#funny" }]],
)
Group.refresh_automatic_groups!
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona_id: persona.id, user: admin)
search_post = Fabricate(:post, topic: topic_with_tags)
bot_post = Fabricate(:post)
search = described_class.new(bot: bot, post: bot_post, args: nil)
results = search.process(order: "latest")
expect(results[:rows].length).to eq(1)
search_post.topic.tags = []
search_post.topic.save!
# no longer has the tag funny
results = search.process(order: "latest")
expect(results[:rows].length).to eq(0)
end
it "can handle no results" do
post1 = Fabricate(:post, topic: topic_with_tags)
search = described_class.new(bot: nil, post: post1, args: nil)
results = search.process(query: "order:fake ABDDCDCEDGDG")
expect(results[:args]).to eq({ query: "order:fake ABDDCDCEDGDG" })
expect(results[:rows]).to eq([])
end
describe "semantic search" do
let(:query) { "this is an expanded search" }
after { DiscourseAi::Embeddings::SemanticSearch.clear_cache_for(query) }
it "supports semantic search when enabled" do
SiteSetting.ai_embeddings_semantic_search_enabled = true
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
WebMock.stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
status: 200,
body: JSON.dump(OpenAiCompletionsInferenceStubs.response(query)),
)
hyde_embedding = [0.049382, 0.9999]
EmbeddingsGenerationStubs.discourse_service(
SiteSetting.ai_embeddings_model,
query,
hyde_embedding,
)
post1 = Fabricate(:post, topic: topic_with_tags)
search = described_class.new(bot: nil, post: post1, args: nil)
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2
.any_instance
.expects(:asymmetric_topics_similarity_search)
.returns([post1.topic_id])
results =
DiscourseAi::Completions::Llm.with_prepared_responses(["<ai>#{query}</ai>"]) do
search.process(search_query: "hello world, sam", status: "public")
end
expect(results[:args]).to eq({ search_query: "hello world, sam", status: "public" })
expect(results[:rows].length).to eq(1)
end
end
it "supports subfolder properly" do
Discourse.stubs(:base_path).returns("/subfolder")
post1 = Fabricate(:post, topic: topic_with_tags)
search = described_class.new(bot: nil, post: post1, args: nil)
results = search.process(limit: 1, user: post1.user.username)
expect(results[:rows].to_s).to include("/subfolder" + post1.url)
end
it "returns rich topic information" do
post1 = Fabricate(:post, like_count: 1, topic: topic_with_tags)
search = described_class.new(bot: nil, post: post1, args: nil)
post1.topic.update!(views: 100, posts_count: 2, like_count: 10)
results = search.process(user: post1.user.username)
row = results[:rows].first
category = row[results[:column_names].index("category")]
expect(category).to eq("animals > amazing-cat")
tags = row[results[:column_names].index("tags")]
expect(tags).to eq("funny, sad")
likes = row[results[:column_names].index("likes")]
expect(likes).to eq(1)
username = row[results[:column_names].index("username")]
expect(username).to eq(post1.user.username)
likes = row[results[:column_names].index("topic_likes")]
expect(likes).to eq(10)
views = row[results[:column_names].index("topic_views")]
expect(views).to eq(100)
replies = row[results[:column_names].index("topic_replies")]
expect(replies).to eq(1)
end
it "scales results to number of tokens" do
SiteSetting.ai_bot_enabled_chat_bots = "gpt-3.5-turbo|gpt-4|claude-2"
post1 = Fabricate(:post)
gpt_3_5_turbo =
DiscourseAi::AiBot::Bot.as(User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID))
gpt4 = DiscourseAi::AiBot::Bot.as(User.find(DiscourseAi::AiBot::EntryPoint::GPT4_ID))
claude = DiscourseAi::AiBot::Bot.as(User.find(DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID))
expect(described_class.new(bot: claude, post: post1, args: nil).max_results).to eq(60)
expect(described_class.new(bot: gpt_3_5_turbo, post: post1, args: nil).max_results).to eq(40)
expect(described_class.new(bot: gpt4, post: post1, args: nil).max_results).to eq(20)
persona =
Fabricate(
:ai_persona,
commands: [["SearchCommand", { "max_results" => 6 }]],
enabled: true,
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
)
Group.refresh_automatic_groups!
custom_bot = DiscourseAi::AiBot::Bot.as(bot_user, persona_id: persona.id, user: admin)
expect(described_class.new(bot: custom_bot, post: post1, args: nil).max_results).to eq(6)
end
it "can handle limits" do
post1 = Fabricate(:post, topic: topic_with_tags)
_post2 = Fabricate(:post, user: post1.user)
_post3 = Fabricate(:post, user: post1.user)
# search has no built-in support for limit: so handle it from the outside
search = described_class.new(bot: nil, post: post1, args: nil)
results = search.process(limit: 2, user: post1.user.username)
expect(results[:rows].length).to eq(2)
# just searching for everything
results = search.process(order: "latest_topic")
expect(results[:rows].length).to be > 1
end
end
end

View File

@ -1,29 +0,0 @@
#frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Commands::SearchSettingsCommand do
let(:search) { described_class.new(bot: nil, args: nil) }
describe "#process" do
it "can handle no results" do
results = search.process(query: "this will not exist frogs")
expect(results[:args]).to eq({ query: "this will not exist frogs" })
expect(results[:rows]).to eq([])
end
it "can return more many settings with no descriptions if there are lots of hits" do
results = search.process(query: "a")
expect(results[:rows].length).to be > 30
expect(results[:rows][0].length).to eq(1)
end
it "can return descriptions if there are few matches" do
results =
search.process(query: "this will not be found!@,default_locale,ai_bot_enabled_chat_bots")
expect(results[:rows].length).to eq(2)
expect(results[:rows][0][1]).not_to eq(nil)
end
end
end

View File

@ -1,42 +0,0 @@
#frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Commands::SummarizeCommand do
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:bot) { DiscourseAi::AiBot::OpenAiBot.new(bot_user) }
before { SiteSetting.ai_bot_enabled = true }
describe "#process" do
it "can generate correct info" do
post = Fabricate(:post)
WebMock.stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
status: 200,
body: JSON.dump({ choices: [{ message: { content: "summary stuff" } }] }),
)
summarizer = described_class.new(bot: bot, args: nil, post: post)
info = summarizer.process(topic_id: post.topic_id, guidance: "why did it happen?")
expect(info).to include("Topic summarized")
expect(summarizer.custom_raw).to include("summary stuff")
expect(summarizer.chain_next_response).to eq(false)
end
it "protects hidden data" do
category = Fabricate(:category)
category.set_permissions({})
category.save!
topic = Fabricate(:topic, category_id: category.id)
post = Fabricate(:post, topic: topic)
summarizer = described_class.new(bot: bot, post: post, args: nil)
info = summarizer.process(topic_id: post.topic_id, guidance: "why did it happen?")
expect(info).not_to include(post.raw)
expect(summarizer.custom_raw).to eq(I18n.t("discourse_ai.ai_bot.topic_not_found"))
end
end
end

View File

@ -1,17 +0,0 @@
#frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Commands::TagsCommand do
describe "#process" do
it "can generate correct info" do
SiteSetting.tagging_enabled = true
Fabricate(:tag, name: "america", public_topic_count: 100)
Fabricate(:tag, name: "not_here", public_topic_count: 0)
info = DiscourseAi::AiBot::Commands::TagsCommand.new(bot: nil, args: nil).process
expect(info.to_s).to include("america")
expect(info.to_s).not_to include("not_here")
end
end
end

View File

@ -1,11 +1,7 @@
# frozen_string_literal: true # frozen_string_literal: true
RSpec.describe Jobs::CreateAiReply do RSpec.describe Jobs::CreateAiReply do
before do before { SiteSetting.ai_bot_enabled = true }
# got to do this because we include times in the system message
freeze_time
SiteSetting.ai_bot_enabled = true
end
describe "#execute" do describe "#execute" do
fab!(:topic) { Fabricate(:topic) } fab!(:topic) { Fabricate(:topic) }
@ -17,95 +13,15 @@ RSpec.describe Jobs::CreateAiReply do
before { SiteSetting.min_personal_message_post_length = 5 } before { SiteSetting.min_personal_message_post_length = 5 }
context "when chatting with the Open AI bot" do it "adds a reply from the bot" do
let(:deltas) { expected_response.split(" ").map { |w| { content: "#{w} " } } } DiscourseAi::Completions::Llm.with_prepared_responses([expected_response]) do
before do
bot_user = User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID)
bot = DiscourseAi::AiBot::Bot.as(bot_user)
# time needs to be frozen so time in prompt does not drift
freeze_time
OpenAiCompletionsInferenceStubs.stub_streamed_response(
DiscourseAi::AiBot::OpenAiBot.new(bot_user).bot_prompt_with_topic_context(
post,
allow_commands: true,
),
deltas,
model: bot.model_for,
req_opts: {
temperature: 0.4,
top_p: 0.9,
max_tokens: 2500,
functions: bot.available_functions,
stream: true,
},
)
end
it "adds a reply from the GPT bot" do
subject.execute(
post_id: topic.first_post.id,
bot_user_id: DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID,
)
expect(topic.posts.last.raw).to eq(expected_response)
end
it "streams the reply on the fly to the client through MB" do
messages =
MessageBus.track_publish("discourse-ai/ai-bot/topic/#{topic.id}") do
subject.execute( subject.execute(
post_id: topic.first_post.id, post_id: topic.first_post.id,
bot_user_id: DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID, bot_user_id: DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID,
) )
end end
done_signal = messages.pop
expect(messages.length).to eq(deltas.length)
messages.each_with_index do |m, idx|
expect(m.data[:raw]).to eq(deltas[0..(idx + 1)].map { |d| d[:content] }.join)
end
expect(done_signal.data[:done]).to eq(true)
end
end
context "when chatting with Claude from Anthropic" do
let(:claude_response) { "#{expected_response}" }
let(:deltas) { claude_response.split(" ").map { |w| "#{w} " } }
before do
SiteSetting.ai_bot_enabled_chat_bots = "claude-2"
bot_user = User.find(DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID)
AnthropicCompletionStubs.stub_streamed_response(
DiscourseAi::AiBot::AnthropicBot.new(bot_user).bot_prompt_with_topic_context(
post,
allow_commands: true,
),
deltas,
model: "claude-2",
req_opts: {
max_tokens_to_sample: 3000,
temperature: 0.4,
stream: true,
stop_sequences: ["\n\nHuman:", "</function_calls>"],
},
)
end
it "adds a reply from the Claude bot" do
subject.execute(
post_id: topic.first_post.id,
bot_user_id: DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID,
)
expect(topic.posts.last.raw).to eq(expected_response) expect(topic.posts.last.raw).to eq(expected_response)
end end
end end
end
end end

View File

@ -12,23 +12,6 @@ RSpec.describe Jobs::UpdateAiBotPmTitle do
it "will properly update title on bot PMs" do it "will properly update title on bot PMs" do
SiteSetting.ai_bot_allowed_groups = Group::AUTO_GROUPS[:staff] SiteSetting.ai_bot_allowed_groups = Group::AUTO_GROUPS[:staff]
Jobs.run_immediately!
WebMock
.stub_request(:post, "https://api.anthropic.com/v1/complete")
.with(body: /You are a helpful Discourse assistant/)
.to_return(status: 200, body: "data: {\"completion\": \"Hello back at you\"}", headers: {})
WebMock
.stub_request(:post, "https://api.anthropic.com/v1/complete")
.with(body: /Suggest a 7 word title/)
.to_return(
status: 200,
body: "{\"completion\": \"A great title would be:\n\nMy amazing title\n\n\"}",
headers: {
},
)
post = post =
create_post( create_post(
user: user, user: user,
@ -38,11 +21,20 @@ RSpec.describe Jobs::UpdateAiBotPmTitle do
target_usernames: bot_user.username, target_usernames: bot_user.username,
) )
title_result = "A great title would be:\n\nMy amazing title\n\n"
DiscourseAi::Completions::Llm.with_prepared_responses([title_result]) do
subject.execute(bot_user_id: bot_user.id, post_id: post.id)
expect(post.reload.topic.title).to eq("My amazing title") expect(post.reload.topic.title).to eq("My amazing title")
end
WebMock.reset! another_title = "I'm a different title"
Jobs::UpdateAiBotPmTitle.new.execute(bot_user_id: bot_user.id, post_id: post.id) DiscourseAi::Completions::Llm.with_prepared_responses([another_title]) do
# should be a no-op because title is updated subject.execute(bot_user_id: bot_user.id, post_id: post.id)
expect(post.reload.topic.title).to eq("My amazing title")
end
end end
end end

View File

@ -1,80 +0,0 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::OpenAiBot do
describe "#bot_prompt_with_topic_context" do
fab!(:topic) { Fabricate(:topic) }
def post_body(post_number)
"This is post #{post_number}"
end
def bot_user
User.find(DiscourseAi::AiBot::EntryPoint::GPT4_ID)
end
subject { described_class.new(bot_user) }
before do
SiteSetting.ai_bot_enabled_chat_bots = "gpt-4"
SiteSetting.ai_bot_enabled = true
end
context "when cleaning usernames" do
it "can properly clean usernames so OpenAI allows it" do
expect(subject.clean_username("test test")).to eq("test_test")
expect(subject.clean_username("test.test")).to eq("test_test")
expect(subject.clean_username("test😀test")).to eq("test_test")
end
end
context "when the topic has one post" do
fab!(:post_1) { Fabricate(:post, topic: topic, raw: post_body(1), post_number: 1) }
it "includes it in the prompt" do
prompt_messages = subject.bot_prompt_with_topic_context(post_1, allow_commands: true)
post_1_message = prompt_messages[-1]
expect(post_1_message[:role]).to eq("user")
expect(post_1_message[:content]).to eq(post_body(1))
expect(post_1_message[:name]).to eq(post_1.user.username)
end
end
context "when prompt gets very long" do
fab!(:post_1) { Fabricate(:post, topic: topic, raw: "test " * 6000, post_number: 1) }
it "trims the prompt" do
prompt_messages = subject.bot_prompt_with_topic_context(post_1, allow_commands: true)
# trimming is tricky... it needs to account for system message as
# well... just make sure we trim for now
expect(prompt_messages[-1][:content].length).to be < post_1.raw.length
end
end
context "when the topic has multiple posts" do
let!(:post_1) { Fabricate(:post, topic: topic, raw: post_body(1), post_number: 1) }
let!(:post_2) do
Fabricate(:post, topic: topic, user: bot_user, raw: post_body(2), post_number: 2)
end
let!(:post_3) { Fabricate(:post, topic: topic, raw: post_body(3), post_number: 3) }
it "includes them in the prompt respecting the post number order" do
prompt_messages = subject.bot_prompt_with_topic_context(post_3, allow_commands: true)
# negative indexes because we may have grounding prompts
expect(prompt_messages[-3][:role]).to eq("user")
expect(prompt_messages[-3][:content]).to eq(post_body(1))
expect(prompt_messages[-3][:name]).to eq(post_1.username)
expect(prompt_messages[-2][:role]).to eq("assistant")
expect(prompt_messages[-2][:content]).to eq(post_body(2))
expect(prompt_messages[-1][:role]).to eq("user")
expect(prompt_messages[-1][:content]).to eq(post_body(3))
expect(prompt_messages[-1][:name]).to eq(post_3.username)
end
end
end
end

View File

@ -1,11 +1,11 @@
#frozen_string_literal: true #frozen_string_literal: true
class TestPersona < DiscourseAi::AiBot::Personas::Persona class TestPersona < DiscourseAi::AiBot::Personas::Persona
def commands def tools
[ [
DiscourseAi::AiBot::Commands::TagsCommand, DiscourseAi::AiBot::Tools::ListTags,
DiscourseAi::AiBot::Commands::SearchCommand, DiscourseAi::AiBot::Tools::Search,
DiscourseAi::AiBot::Commands::ImageCommand, DiscourseAi::AiBot::Tools::Image,
] ]
end end
@ -37,41 +37,36 @@ module DiscourseAi::AiBot::Personas
AiPersona.persona_cache.flush! AiPersona.persona_cache.flush!
end end
fab!(:user) let(:context) do
{
it "can disable commands" do site_url: Discourse.base_url,
persona = TestPersona.new site_title: "test site title",
site_description: "test site description",
rendered = persona.render_system_prompt(topic: topic_with_users, allow_commands: false) time: Time.zone.now,
participants: topic_with_users.allowed_users.map(&:username).join(", "),
expect(rendered).not_to include("!tags") }
expect(rendered).not_to include("!search")
end end
fab!(:user)
it "renders the system prompt" do it "renders the system prompt" do
freeze_time freeze_time
SiteSetting.title = "test site title" rendered = persona.craft_prompt(context)
SiteSetting.site_description = "test site description"
rendered = expect(rendered[:insts]).to include(Discourse.base_url)
persona.render_system_prompt(topic: topic_with_users, render_function_instructions: true) expect(rendered[:insts]).to include("test site title")
expect(rendered[:insts]).to include("test site description")
expect(rendered[:insts]).to include("joe, jane")
expect(rendered[:insts]).to include(Time.zone.now.to_s)
tools = rendered[:tools]
expect(tools.find { |t| t[:name] == "search" }).to be_present
expect(tools.find { |t| t[:name] == "tags" }).to be_present
expect(rendered).to include(Discourse.base_url)
expect(rendered).to include("test site title")
expect(rendered).to include("test site description")
expect(rendered).to include("joe, jane")
expect(rendered).to include(Time.zone.now.to_s)
expect(rendered).to include("<tool_name>search</tool_name>")
expect(rendered).to include("<tool_name>tags</tool_name>")
# needs to be configured so it is not available # needs to be configured so it is not available
expect(rendered).not_to include("<tool_name>image</tool_name>") expect(tools.find { |t| t[:name] == "image" }).to be_nil
rendered =
persona.render_system_prompt(topic: topic_with_users, render_function_instructions: false)
expect(rendered).not_to include("<tool_name>search</tool_name>")
expect(rendered).not_to include("<tool_name>tags</tool_name>")
end end
describe "custom personas" do describe "custom personas" do
@ -88,31 +83,29 @@ module DiscourseAi::AiBot::Personas
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]], allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
) )
custom_persona = DiscourseAi::AiBot::Personas.all(user: user).last custom_persona = DiscourseAi::AiBot::Personas::Persona.all(user: user).last
expect(custom_persona.name).to eq("zzzpun_bot") expect(custom_persona.name).to eq("zzzpun_bot")
expect(custom_persona.description).to eq("you write puns") expect(custom_persona.description).to eq("you write puns")
instance = custom_persona.new instance = custom_persona.new
expect(instance.commands).to eq([DiscourseAi::AiBot::Commands::ImageCommand]) expect(instance.tools).to eq([DiscourseAi::AiBot::Tools::Image])
expect(instance.render_system_prompt(render_function_instructions: true)).to eq( expect(instance.craft_prompt(context).dig(:insts)).to eq("you are pun bot\n\n")
"you are pun bot",
)
# should update # should update
persona.update!(name: "zzzpun_bot2") persona.update!(name: "zzzpun_bot2")
custom_persona = DiscourseAi::AiBot::Personas.all(user: user).last custom_persona = DiscourseAi::AiBot::Personas::Persona.all(user: user).last
expect(custom_persona.name).to eq("zzzpun_bot2") expect(custom_persona.name).to eq("zzzpun_bot2")
# can be disabled # can be disabled
persona.update!(enabled: false) persona.update!(enabled: false)
last_persona = DiscourseAi::AiBot::Personas.all(user: user).last last_persona = DiscourseAi::AiBot::Personas::Persona.all(user: user).last
expect(last_persona.name).not_to eq("zzzpun_bot2") expect(last_persona.name).not_to eq("zzzpun_bot2")
persona.update!(enabled: true) persona.update!(enabled: true)
# no groups have access # no groups have access
persona.update!(allowed_group_ids: []) persona.update!(allowed_group_ids: [])
last_persona = DiscourseAi::AiBot::Personas.all(user: user).last last_persona = DiscourseAi::AiBot::Personas::Persona.all(user: user).last
expect(last_persona.name).not_to eq("zzzpun_bot2") expect(last_persona.name).not_to eq("zzzpun_bot2")
end end
end end
@ -127,7 +120,7 @@ module DiscourseAi::AiBot::Personas
SiteSetting.ai_google_custom_search_cx = "abc123" SiteSetting.ai_google_custom_search_cx = "abc123"
# should be ordered by priority and then alpha # should be ordered by priority and then alpha
expect(DiscourseAi::AiBot::Personas.all(user: user)).to eq( expect(DiscourseAi::AiBot::Personas::Persona.all(user: user)).to eq(
[General, Artist, Creative, Researcher, SettingsExplorer, SqlHelper], [General, Artist, Creative, Researcher, SettingsExplorer, SqlHelper],
) )
@ -135,18 +128,18 @@ module DiscourseAi::AiBot::Personas
SiteSetting.ai_stability_api_key = "" SiteSetting.ai_stability_api_key = ""
SiteSetting.ai_google_custom_search_api_key = "" SiteSetting.ai_google_custom_search_api_key = ""
expect(DiscourseAi::AiBot::Personas.all(user: user)).to contain_exactly( expect(DiscourseAi::AiBot::Personas::Persona.all(user: user)).to contain_exactly(
General, General,
SqlHelper, SqlHelper,
SettingsExplorer, SettingsExplorer,
Creative, Creative,
) )
AiPersona.find(DiscourseAi::AiBot::Personas.system_personas[General]).update!( AiPersona.find(DiscourseAi::AiBot::Personas::Persona.system_personas[General]).update!(
enabled: false, enabled: false,
) )
expect(DiscourseAi::AiBot::Personas.all(user: user)).to contain_exactly( expect(DiscourseAi::AiBot::Personas::Persona.all(user: user)).to contain_exactly(
SqlHelper, SqlHelper,
SettingsExplorer, SettingsExplorer,
Creative, Creative,

View File

@ -6,6 +6,6 @@ RSpec.describe DiscourseAi::AiBot::Personas::Researcher do
end end
it "renders schema" do it "renders schema" do
expect(researcher.commands).to eq([DiscourseAi::AiBot::Commands::GoogleCommand]) expect(researcher.tools).to eq([DiscourseAi::AiBot::Tools::Google])
end end
end end

View File

@ -6,18 +6,15 @@ RSpec.describe DiscourseAi::AiBot::Personas::SettingsExplorer do
end end
it "renders schema" do it "renders schema" do
prompt = settings_explorer.render_system_prompt prompt = settings_explorer.system_prompt
# check we do not render plugin settings # check we do not render plugin settings
expect(prompt).not_to include("ai_bot_enabled_personas") expect(prompt).not_to include("ai_bot_enabled_personas")
expect(prompt).to include("site_description") expect(prompt).to include("site_description")
expect(settings_explorer.available_commands).to eq( expect(settings_explorer.tools).to eq(
[ [DiscourseAi::AiBot::Tools::SettingContext, DiscourseAi::AiBot::Tools::SearchSettings],
DiscourseAi::AiBot::Commands::SettingContextCommand,
DiscourseAi::AiBot::Commands::SearchSettingsCommand,
],
) )
end end
end end

View File

@ -6,12 +6,12 @@ RSpec.describe DiscourseAi::AiBot::Personas::SqlHelper do
end end
it "renders schema" do it "renders schema" do
prompt = sql_helper.render_system_prompt prompt = sql_helper.system_prompt
expect(prompt).to include("posts(") expect(prompt).to include("posts(")
expect(prompt).to include("topics(") expect(prompt).to include("topics(")
expect(prompt).not_to include("translation_key") # not a priority table expect(prompt).not_to include("translation_key") # not a priority table
expect(prompt).to include("user_api_keys") # not a priority table expect(prompt).to include("user_api_keys") # not a priority table
expect(sql_helper.available_commands).to eq([DiscourseAi::AiBot::Commands::DbSchemaCommand]) expect(sql_helper.tools).to eq([DiscourseAi::AiBot::Tools::DbSchema])
end end
end end

View File

@ -0,0 +1,168 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Playground do
subject(:playground) { described_class.new(bot) }
before do
SiteSetting.ai_bot_enabled_chat_bots = "gpt-4"
SiteSetting.ai_bot_enabled = true
end
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT4_ID) }
let(:bot) { DiscourseAi::AiBot::Bot.as(bot_user) }
fab!(:user) { Fabricate(:user) }
let!(:pm) do
Fabricate(
:private_message_topic,
title: "This is my special PM",
user: user,
topic_allowed_users: [
Fabricate.build(:topic_allowed_user, user: user),
Fabricate.build(:topic_allowed_user, user: bot_user),
],
)
end
let!(:first_post) do
Fabricate(:post, topic: pm, user: user, post_number: 1, raw: "This is a reply by the user")
end
let!(:second_post) do
Fabricate(:post, topic: pm, user: bot_user, post_number: 2, raw: "This is a bot reply")
end
let!(:third_post) do
Fabricate(
:post,
topic: pm,
user: user,
post_number: 3,
raw: "This is a second reply by the user",
)
end
describe "#title_playground" do
let(:expected_response) { "This is a suggested title" }
before { SiteSetting.min_personal_message_post_length = 5 }
it "updates the title using bot suggestions" do
DiscourseAi::Completions::Llm.with_prepared_responses([expected_response]) do
playground.title_playground(third_post)
expect(pm.reload.title).to eq(expected_response)
end
end
end
describe "#reply_to" do
it "streams the bot reply through MB and create a new post in the PM with a cooked responses" do
expected_bot_response =
"Hello this is a bot and what you just said is an interesting question"
DiscourseAi::Completions::Llm.with_prepared_responses([expected_bot_response]) do
messages =
MessageBus.track_publish("discourse-ai/ai-bot/topic/#{pm.id}") do
playground.reply_to(third_post)
end
done_signal = messages.pop
expect(done_signal.data[:done]).to eq(true)
messages.each_with_index do |m, idx|
expect(m.data[:raw]).to eq(expected_bot_response[0..idx])
end
expect(pm.reload.posts.last.cooked).to eq(PrettyText.cook(expected_bot_response))
end
end
end
describe "#conversation_context" do
it "includes previous posts ordered by post_number" do
context = playground.conversation_context(third_post)
expect(context).to contain_exactly(
*[
{ type: "user", name: user.username, content: third_post.raw },
{ type: "assistant", content: second_post.raw },
{ type: "user", name: user.username, content: first_post.raw },
],
)
end
it "only include regular posts" do
first_post.update!(post_type: Post.types[:whisper])
context = playground.conversation_context(third_post)
expect(context).to contain_exactly(
*[
{ type: "user", name: user.username, content: third_post.raw },
{ type: "assistant", content: second_post.raw },
],
)
end
context "with custom prompts" do
it "When post custom prompt is present, we use that instead of the post content" do
custom_prompt = [
[
{ args: { timezone: "Buenos Aires" }, time: "2023-12-14 17:24:00 -0300" }.to_json,
"time",
"tool",
],
[
{ name: "time", arguments: { name: "time", timezone: "Buenos Aires" } }.to_json,
"time",
"tool_call",
],
["I replied this thanks to the time command", bot_user.username],
]
PostCustomPrompt.create!(post: second_post, custom_prompt: custom_prompt)
context = playground.conversation_context(third_post)
expect(context).to contain_exactly(
*[
{ type: "user", name: user.username, content: third_post.raw },
{ type: "assistant", content: custom_prompt.third.first },
{ type: "tool_call", content: custom_prompt.second.first, name: "time" },
{ type: "tool", name: "time", content: custom_prompt.first.first },
{ type: "user", name: user.username, content: first_post.raw },
],
)
end
end
it "include replies generated from tools only once" do
custom_prompt = [
[
{ args: { timezone: "Buenos Aires" }, time: "2023-12-14 17:24:00 -0300" }.to_json,
"time",
"tool",
],
[
{ name: "time", arguments: { name: "time", timezone: "Buenos Aires" }.to_json }.to_json,
"time",
"tool_call",
],
["I replied this thanks to the time command", bot_user.username],
]
PostCustomPrompt.create!(post: second_post, custom_prompt: custom_prompt)
PostCustomPrompt.create!(post: first_post, custom_prompt: custom_prompt)
context = playground.conversation_context(third_post)
expect(context).to contain_exactly(
*[
{ type: "user", name: user.username, content: third_post.raw },
{ type: "assistant", content: custom_prompt.third.first },
{ type: "tool_call", content: custom_prompt.second.first, name: "time" },
{ type: "tool", name: "time", content: custom_prompt.first.first },
{ type: "tool_call", content: custom_prompt.second.first, name: "time" },
{ type: "tool", name: "time", content: custom_prompt.first.first },
],
)
end
end
end

View File

@ -1,8 +1,13 @@
#frozen_string_literal: true #frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Commands::DallECommand do RSpec.describe DiscourseAi::AiBot::Tools::DallE do
subject(:dall_e) { described_class.new({ prompts: prompts }) }
let(:prompts) { ["a pink cow", "a red cow"] }
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) } let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:bot) { DiscourseAi::AiBot::OpenAiBot.new(bot_user) } let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
let(:progress_blk) { Proc.new {} }
before { SiteSetting.ai_bot_enabled = true } before { SiteSetting.ai_bot_enabled = true }
@ -17,7 +22,6 @@ RSpec.describe DiscourseAi::AiBot::Commands::DallECommand do
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==" "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
data = [{ b64_json: image, revised_prompt: "a pink cow 1" }] data = [{ b64_json: image, revised_prompt: "a pink cow 1" }]
prompts = ["a pink cow", "a red cow"]
WebMock WebMock
.stub_request(:post, SiteSetting.ai_openai_dall_e_3_url) .stub_request(:post, SiteSetting.ai_openai_dall_e_3_url)
@ -30,14 +34,12 @@ RSpec.describe DiscourseAi::AiBot::Commands::DallECommand do
end end
.to_return(status: 200, body: { data: data }.to_json) .to_return(status: 200, body: { data: data }.to_json)
image = described_class.new(bot: bot, post: post, args: nil) info = dall_e.invoke(bot_user, llm, &progress_blk).to_json
info = image.process(prompts: prompts).to_json
expect(JSON.parse(info)).to eq("prompts" => ["a pink cow 1", "a pink cow 1"]) expect(JSON.parse(info)).to eq("prompts" => ["a pink cow 1", "a pink cow 1"])
expect(image.custom_raw).to include("upload://") expect(subject.custom_raw).to include("upload://")
expect(image.custom_raw).to include("[grid]") expect(subject.custom_raw).to include("[grid]")
expect(image.custom_raw).to include("a pink cow 1") expect(subject.custom_raw).to include("a pink cow 1")
end end
it "can generate correct info" do it "can generate correct info" do
@ -49,7 +51,6 @@ RSpec.describe DiscourseAi::AiBot::Commands::DallECommand do
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==" "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
data = [{ b64_json: image, revised_prompt: "a pink cow 1" }] data = [{ b64_json: image, revised_prompt: "a pink cow 1" }]
prompts = ["a pink cow", "a red cow"]
WebMock WebMock
.stub_request(:post, "https://api.openai.com/v1/images/generations") .stub_request(:post, "https://api.openai.com/v1/images/generations")
@ -60,14 +61,12 @@ RSpec.describe DiscourseAi::AiBot::Commands::DallECommand do
end end
.to_return(status: 200, body: { data: data }.to_json) .to_return(status: 200, body: { data: data }.to_json)
image = described_class.new(bot: bot, post: post, args: nil) info = dall_e.invoke(bot_user, llm, &progress_blk).to_json
info = image.process(prompts: prompts).to_json
expect(JSON.parse(info)).to eq("prompts" => ["a pink cow 1", "a pink cow 1"]) expect(JSON.parse(info)).to eq("prompts" => ["a pink cow 1", "a pink cow 1"])
expect(image.custom_raw).to include("upload://") expect(subject.custom_raw).to include("upload://")
expect(image.custom_raw).to include("[grid]") expect(subject.custom_raw).to include("[grid]")
expect(image.custom_raw).to include("a pink cow 1") expect(subject.custom_raw).to include("a pink cow 1")
end end
end end
end end
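The DALL-E spec above captures the new tool contract: construct with a parameters hash, call `invoke(bot_user, llm, &progress_blk)`, and read any markdown to append from `custom_raw`. A hypothetical tool following that contract; the `Echo` class and everything in it are invented for illustration and are not part of the plugin:

# Hypothetical example tool, for illustration only.
module DiscourseAi
  module AiBot
    module Tools
      class Echo
        attr_reader :custom_raw

        def initialize(parameters, persona_options: {})
          @parameters = parameters
          @persona_options = persona_options
        end

        # bot_user and llm mirror the arguments the specs pass to #invoke;
        # the optional block receives progress updates, like &progress_blk.
        def invoke(bot_user, llm)
          yield "echoing..." if block_given?
          @custom_raw = "> #{@parameters[:text]}"
          { text: @parameters[:text] } # hash result, handed back to the model
        end
      end
    end
  end
end

Invocation then looks like the specs above: `DiscourseAi::AiBot::Tools::Echo.new({ text: "hi" }).invoke(bot_user, llm) { |progress| }`.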

View File

@ -1,10 +1,14 @@
#frozen_string_literal: true #frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Commands::DbSchemaCommand do RSpec.describe DiscourseAi::AiBot::Tools::DbSchema do
let(:command) { DiscourseAi::AiBot::Commands::DbSchemaCommand.new(bot: nil, args: nil) } let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
before { SiteSetting.ai_bot_enabled = true }
describe "#process" do describe "#process" do
it "returns rich schema for tables" do it "returns rich schema for tables" do
result = command.process(tables: "posts,topics") result = described_class.new({ tables: "posts,topics" }).invoke(bot_user, llm)
expect(result[:schema_info]).to include("raw text") expect(result[:schema_info]).to include("raw text")
expect(result[:schema_info]).to include("views integer") expect(result[:schema_info]).to include("views integer")
expect(result[:schema_info]).to include("posts") expect(result[:schema_info]).to include("posts")

View File

@ -1,8 +1,11 @@
#frozen_string_literal: true #frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Commands::GoogleCommand do RSpec.describe DiscourseAi::AiBot::Tools::Google do
subject(:search) { described_class.new({ query: "some search term" }) }
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) } let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:bot) { DiscourseAi::AiBot::OpenAiBot.new(bot_user) } let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
let(:progress_blk) { Proc.new {} }
before { SiteSetting.ai_bot_enabled = true } before { SiteSetting.ai_bot_enabled = true }
@ -20,10 +23,9 @@ RSpec.describe DiscourseAi::AiBot::Commands::GoogleCommand do
"https://www.googleapis.com/customsearch/v1?cx=cx&key=abc&num=10&q=some%20search%20term", "https://www.googleapis.com/customsearch/v1?cx=cx&key=abc&num=10&q=some%20search%20term",
).to_return(status: 200, body: json_text, headers: {}) ).to_return(status: 200, body: json_text, headers: {})
google = described_class.new(bot: nil, post: post, args: {}.to_json) info = search.invoke(bot_user, llm, &progress_blk).to_json
info = google.process(query: "some search term").to_json
expect(google.description_args[:count]).to eq(0) expect(search.results_count).to eq(0)
expect(info).to_not include("oops") expect(info).to_not include("oops")
end end
@ -61,23 +63,14 @@ RSpec.describe DiscourseAi::AiBot::Commands::GoogleCommand do
"https://www.googleapis.com/customsearch/v1?cx=cx&key=abc&num=10&q=some%20search%20term", "https://www.googleapis.com/customsearch/v1?cx=cx&key=abc&num=10&q=some%20search%20term",
).to_return(status: 200, body: json_text, headers: {}) ).to_return(status: 200, body: json_text, headers: {})
google = info = search.invoke(bot_user, llm, &progress_blk).to_json
described_class.new(bot: bot, post: post, args: { query: "some search term" }.to_json)
info = google.process(query: "some search term").to_json expect(search.results_count).to eq(2)
expect(google.description_args[:count]).to eq(2)
expect(info).to include("title1") expect(info).to include("title1")
expect(info).to include("snippet1") expect(info).to include("snippet1")
expect(info).to include("some+search+term") expect(info).to include("some+search+term")
expect(info).to include("title2") expect(info).to include("title2")
expect(info).to_not include("oops") expect(info).to_not include("oops")
google.invoke!
expect(post.reload.raw).to include("some search term")
expect { google.invoke! }.to raise_error(StandardError)
end end
end end
end end

View File

@ -1,8 +1,14 @@
#frozen_string_literal: true #frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do RSpec.describe DiscourseAi::AiBot::Tools::Image do
subject(:tool) { described_class.new({ prompts: prompts, seeds: [99, 32] }) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
let(:progress_blk) { Proc.new {} }
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) } let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:bot) { DiscourseAi::AiBot::OpenAiBot.new(bot_user) }
let(:prompts) { ["a pink cow", "a red cow"] }
before { SiteSetting.ai_bot_enabled = true } before { SiteSetting.ai_bot_enabled = true }
@ -17,7 +23,6 @@ RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==" "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
artifacts = [{ base64: image, seed: 99 }] artifacts = [{ base64: image, seed: 99 }]
prompts = ["a pink cow", "a red cow"]
WebMock WebMock
.stub_request( .stub_request(
@ -31,15 +36,13 @@ RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do
end end
.to_return(status: 200, body: { artifacts: artifacts }.to_json) .to_return(status: 200, body: { artifacts: artifacts }.to_json)
image = described_class.new(bot: bot, post: post, args: nil) info = tool.invoke(bot_user, llm, &progress_blk).to_json
info = image.process(prompts: prompts).to_json
expect(JSON.parse(info)).to eq("prompts" => ["a pink cow", "a red cow"], "seeds" => [99, 99]) expect(JSON.parse(info)).to eq("prompts" => ["a pink cow", "a red cow"], "seeds" => [99, 99])
expect(image.custom_raw).to include("upload://") expect(tool.custom_raw).to include("upload://")
expect(image.custom_raw).to include("[grid]") expect(tool.custom_raw).to include("[grid]")
expect(image.custom_raw).to include("a pink cow") expect(tool.custom_raw).to include("a pink cow")
expect(image.custom_raw).to include("a red cow") expect(tool.custom_raw).to include("a red cow")
end end
end end
end end

View File

@ -0,0 +1,19 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::ListCategories do
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
before { SiteSetting.ai_bot_enabled = true }
describe "#process" do
it "list available categories" do
Fabricate(:category, name: "america", posts_year: 999)
info = described_class.new({}).invoke(bot_user, llm).to_s
expect(info).to include("america")
expect(info).to include("999")
end
end
end

View File

@ -0,0 +1,23 @@
#frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::ListTags do
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
before do
SiteSetting.ai_bot_enabled = true
SiteSetting.tagging_enabled = true
end
describe "#process" do
it "can generate correct info" do
Fabricate(:tag, name: "america", public_topic_count: 100)
Fabricate(:tag, name: "not_here", public_topic_count: 0)
info = described_class.new({}).invoke(bot_user, llm)
expect(info.to_s).to include("america")
expect(info.to_s).not_to include("not_here")
end
end
end

View File

@ -1,8 +1,10 @@
#frozen_string_literal: true #frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Commands::ReadCommand do RSpec.describe DiscourseAi::AiBot::Tools::Read do
subject(:tool) { described_class.new({ topic_id: topic_with_tags.id }) }
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) } let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:bot) { DiscourseAi::AiBot::Bot.as(bot_user) } let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
fab!(:parent_category) { Fabricate(:category, name: "animals") } fab!(:parent_category) { Fabricate(:category, name: "animals") }
fab!(:category) { Fabricate(:category, parent_category: parent_category, name: "amazing-cat") } fab!(:category) { Fabricate(:category, parent_category: parent_category, name: "amazing-cat") }
@ -32,9 +34,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::ReadCommand do
Fabricate(:post, topic: topic_with_tags, raw: "hello there") Fabricate(:post, topic: topic_with_tags, raw: "hello there")
Fabricate(:post, topic: topic_with_tags, raw: "mister sam") Fabricate(:post, topic: topic_with_tags, raw: "mister sam")
read = described_class.new(bot: bot, args: nil) results = tool.invoke(bot_user, llm)
results = read.process(topic_id: topic_id)
expect(results[:topic_id]).to eq(topic_id) expect(results[:topic_id]).to eq(topic_id)
expect(results[:content]).to include("hello") expect(results[:content]).to include("hello")
@ -44,10 +44,8 @@ RSpec.describe DiscourseAi::AiBot::Commands::ReadCommand do
expect(results[:content]).to include("sad") expect(results[:content]).to include("sad")
expect(results[:content]).to include("animals") expect(results[:content]).to include("animals")
expect(results[:content]).not_to include("hidden") expect(results[:content]).not_to include("hidden")
expect(read.description_args).to eq( expect(tool.title).to eq(topic_with_tags.title)
title: topic_with_tags.title, expect(tool.url).to eq(topic_with_tags.relative_url)
url: topic_with_tags.relative_url,
)
end end
end end
end end
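The Read spec illustrates another piece of the migration: tools now expose `title` and `url` accessors directly instead of packing them into `description_args`. A usage sketch; it assumes `topic`, `bot_user`, and `llm` are set up as in the spec above:

# Accessor convention visible in the spec above.
tool = DiscourseAi::AiBot::Tools::Read.new({ topic_id: topic.id })
tool.invoke(bot_user, llm)
tool.title # => the topic title (previously description_args[:title])
tool.url   # => the topic's relative URL (previously description_args[:url])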

View File

@ -0,0 +1,39 @@
#frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::SearchSettings do
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
before { SiteSetting.ai_bot_enabled = true }
def search_settings(query)
described_class.new({ query: query })
end
describe "#process" do
it "can handle no results" do
results = search_settings("this will not exist frogs").invoke(bot_user, llm)
expect(results[:args]).to eq({ query: "this will not exist frogs" })
expect(results[:rows]).to eq([])
end
it "can return more many settings with no descriptions if there are lots of hits" do
results = search_settings("a").invoke(bot_user, llm)
expect(results[:rows].length).to be > 30
expect(results[:rows][0].length).to eq(1)
end
it "can return descriptions if there are few matches" do
results =
search_settings("this will not be found!@,default_locale,ai_bot_enabled_chat_bots").invoke(
bot_user,
llm,
)
expect(results[:rows].length).to eq(2)
expect(results[:rows][0][1]).not_to eq(nil)
end
end
end
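The result-shape convention these expectations rely on: tools echo back their arguments under `:args` and return tabular data under `:rows`. A sketch assuming `bot_user` and `llm` as defined above; the row contents shown are illustrative, not real output:

# Shape inferred from the expectations above.
result = DiscourseAi::AiBot::Tools::SearchSettings.new({ query: "ai_bot" }).invoke(bot_user, llm)
result[:args] # => { query: "ai_bot" }
result[:rows] # => e.g. [["ai_bot_enabled", "Enable the AI Bot module"]]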

View File

@ -0,0 +1,140 @@
#frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::Search do
before { SearchIndexer.enable }
after { SearchIndexer.disable }
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
let(:progress_blk) { Proc.new {} }
fab!(:admin)
fab!(:parent_category) { Fabricate(:category, name: "animals") }
fab!(:category) { Fabricate(:category, parent_category: parent_category, name: "amazing-cat") }
fab!(:tag_funny) { Fabricate(:tag, name: "funny") }
fab!(:tag_sad) { Fabricate(:tag, name: "sad") }
fab!(:tag_hidden) { Fabricate(:tag, name: "hidden") }
fab!(:staff_tag_group) do
tag_group = Fabricate.build(:tag_group, name: "Staff only", tag_names: ["hidden"])
tag_group.permissions = [
[Group::AUTO_GROUPS[:staff], TagGroupPermission.permission_types[:full]],
]
tag_group.save!
tag_group
end
fab!(:topic_with_tags) do
Fabricate(:topic, category: category, tags: [tag_funny, tag_sad, tag_hidden])
end
before { SiteSetting.ai_bot_enabled = true }
describe "#invoke" do
it "can retreive options from persona correctly" do
persona_options = { "base_query" => "#funny" }
search_post = Fabricate(:post, topic: topic_with_tags)
bot_post = Fabricate(:post)
search = described_class.new({ order: "latest" }, persona_options: persona_options)
results = search.invoke(bot_user, llm, &progress_blk)
expect(results[:rows].length).to eq(1)
search_post.topic.tags = []
search_post.topic.save!
# no longer has the tag funny
results = search.invoke(bot_user, llm, &progress_blk)
expect(results[:rows].length).to eq(0)
end
it "can handle no results" do
post1 = Fabricate(:post, topic: topic_with_tags)
search = described_class.new({ search_query: "ABDDCDCEDGDG", order: "fake" })
results = search.invoke(bot_user, llm, &progress_blk)
expect(results[:args]).to eq({ search_query: "ABDDCDCEDGDG", order: "fake" })
expect(results[:rows]).to eq([])
end
describe "semantic search" do
let(:query) { "this is an expanded search" }
after { DiscourseAi::Embeddings::SemanticSearch.clear_cache_for(query) }
it "supports semantic search when enabled" do
SiteSetting.ai_embeddings_semantic_search_enabled = true
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
hyde_embedding = [0.049382, 0.9999]
EmbeddingsGenerationStubs.discourse_service(
SiteSetting.ai_embeddings_model,
query,
hyde_embedding,
)
post1 = Fabricate(:post, topic: topic_with_tags)
search = described_class.new({ search_query: "hello world, sam", status: "public" })
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2
.any_instance
.expects(:asymmetric_topics_similarity_search)
.returns([post1.topic_id])
results =
DiscourseAi::Completions::Llm.with_prepared_responses(["<ai>#{query}</ai>"]) do
search.invoke(bot_user, llm, &progress_blk)
end
expect(results[:args]).to eq({ search_query: "hello world, sam", status: "public" })
expect(results[:rows].length).to eq(1)
end
end
it "supports subfolder properly" do
Discourse.stubs(:base_path).returns("/subfolder")
post1 = Fabricate(:post, topic: topic_with_tags)
search = described_class.new({ limit: 1, user: post1.user.username })
results = search.invoke(bot_user, llm, &progress_blk)
expect(results[:rows].to_s).to include("/subfolder" + post1.url)
end
it "returns rich topic information" do
post1 = Fabricate(:post, like_count: 1, topic: topic_with_tags)
search = described_class.new({ user: post1.user.username })
post1.topic.update!(views: 100, posts_count: 2, like_count: 10)
results = search.invoke(bot_user, llm, &progress_blk)
row = results[:rows].first
category = row[results[:column_names].index("category")]
expect(category).to eq("animals > amazing-cat")
tags = row[results[:column_names].index("tags")]
expect(tags).to eq("funny, sad")
likes = row[results[:column_names].index("likes")]
expect(likes).to eq(1)
username = row[results[:column_names].index("username")]
expect(username).to eq(post1.user.username)
likes = row[results[:column_names].index("topic_likes")]
expect(likes).to eq(10)
views = row[results[:column_names].index("topic_views")]
expect(views).to eq(100)
replies = row[results[:column_names].index("topic_replies")]
expect(replies).to eq(1)
end
end
end
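The semantic-search spec above stubs the model with `Llm.with_prepared_responses`, which yields canned completions instead of hitting an API. A minimal usage sketch; the canned completion and query are illustrative, and `bot_user` and `llm` are assumed to be defined as in the spec:

# Any completion requested inside the block consumes the prepared response.
DiscourseAi::Completions::Llm.with_prepared_responses(["<ai>expanded query</ai>"]) do
  search = DiscourseAi::AiBot::Tools::Search.new({ search_query: "hello" })
  search.invoke(bot_user, llm) { |progress| }
end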

View File

@ -1,21 +1,26 @@
# frozen_string_literal: true # frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Commands::SettingContextCommand do def has_rg?
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:command) { described_class.new(bot_user: bot_user, args: nil) }
def has_rg?
if defined?(@has_rg) if defined?(@has_rg)
@has_rg @has_rg
else else
@has_rg |= system("which rg") @has_rg |= system("which rg")
end end
end
RSpec.describe DiscourseAi::AiBot::Tools::SettingContext, if: has_rg? do
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
before { SiteSetting.ai_bot_enabled = true }
def setting_context(setting_name)
described_class.new({ setting_name: setting_name })
end end
describe "#execute" do describe "#execute" do
skip("rg is needed for these tests") if !has_rg?
it "returns the context for core setting" do it "returns the context for core setting" do
result = command.process(setting_name: "moderators_view_emails") result = setting_context("moderators_view_emails").invoke(bot_user, llm)
expect(result[:setting_name]).to eq("moderators_view_emails") expect(result[:setting_name]).to eq("moderators_view_emails")
@ -23,18 +28,17 @@ RSpec.describe DiscourseAi::AiBot::Commands::SettingContextCommand do
expect(result[:context]).to include("moderators_view_emails") expect(result[:context]).to include("moderators_view_emails")
end end
skip("rg is needed for these tests") if !has_rg?
it "returns the context for plugin setting" do it "returns the context for plugin setting" do
result = command.process(setting_name: "ai_bot_enabled") result = setting_context("ai_bot_enabled").invoke(bot_user, llm)
expect(result[:setting_name]).to eq("ai_bot_enabled") expect(result[:setting_name]).to eq("ai_bot_enabled")
expect(result[:context]).to include("ai_bot_enabled:") expect(result[:context]).to include("ai_bot_enabled:")
end end
context "when the setting does not exist" do context "when the setting does not exist" do
skip("rg is needed for these tests") if !has_rg?
it "returns an error message" do it "returns an error message" do
result = command.process(setting_name: "this_setting_does_not_exist") result = setting_context("this_setting_does_not_exist").invoke(bot_user, llm)
expect(result[:context]).to eq("This setting does not exist") expect(result[:context]).to eq("This setting does not exist")
end end
end end
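The migration also moves the `rg` availability check from per-example `skip` calls to a single `if:` guard on the `describe` block, so the whole group is excluded in one place. The same metadata-based guard in isolation:

# RSpec skips the entire group when the value passed to if: is falsey
# at load time.
RSpec.describe "ripgrep-dependent specs", if: system("which rg") do
  it "runs only when rg is on the PATH" do
    expect(`rg --version`).to include("ripgrep")
  end
end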

View File

@ -0,0 +1,46 @@
#frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::Summarize do
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
let(:progress_blk) { Proc.new {} }
before { SiteSetting.ai_bot_enabled = true }
let(:summary) { "summary stuff" }
describe "#process" do
it "can generate correct info" do
post = Fabricate(:post)
DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do
summarization =
described_class.new({ topic_id: post.topic_id, guidance: "why did it happen?" })
info = summarization.invoke(bot_user, llm, &progress_blk)
expect(info).to include("Topic summarized")
expect(summarization.custom_raw).to include(summary)
expect(summarization.chain_next_response?).to eq(false)
end
end
it "protects hidden data" do
category = Fabricate(:category)
category.set_permissions({})
category.save!
topic = Fabricate(:topic, category_id: category.id)
post = Fabricate(:post, topic: topic)
DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do
summarization =
described_class.new({ topic_id: post.topic_id, guidance: "why did it happen?" })
info = summarization.invoke(bot_user, llm, &progress_blk)
expect(info).not_to include(post.raw)
expect(summarization.custom_raw).to eq(I18n.t("discourse_ai.ai_bot.topic_not_found"))
end
end
end
end
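The `chain_next_response?` check above suggests tools signal whether the bot should run another completion after the tool finishes. The control flow below is an assumption inferred from that expectation, not confirmed implementation; `topic`, `bot_user`, and `llm` are assumed from the spec:

# Assumed semantics: false means the tool's custom_raw is rendered directly
# and no further LLM round-trip happens.
tool = DiscourseAi::AiBot::Tools::Summarize.new({ topic_id: topic.id, guidance: "tl;dr" })
result = tool.invoke(bot_user, llm) { |progress| }
if tool.chain_next_response?
  # feed `result` back to the model for another completion
else
  # post tool.custom_raw as-is
end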

View File

@ -1,12 +1,17 @@
#frozen_string_literal: true #frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Commands::TimeCommand do RSpec.describe DiscourseAi::AiBot::Tools::Time do
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
before { SiteSetting.ai_bot_enabled = true }
describe "#process" do describe "#process" do
it "can generate correct info" do it "can generate correct info" do
freeze_time freeze_time
args = { timezone: "America/Los_Angeles" } args = { timezone: "America/Los_Angeles" }
info = DiscourseAi::AiBot::Commands::TimeCommand.new(bot: nil, args: nil).process(**args) info = described_class.new(args).invoke(bot_user, llm)
expect(info).to eq({ args: args, time: Time.now.in_time_zone("America/Los_Angeles").to_s }) expect(info).to eq({ args: args, time: Time.now.in_time_zone("America/Los_Angeles").to_s })
expect(info.to_s).not_to include("not_here") expect(info.to_s).not_to include("not_here")

View File

@ -14,7 +14,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
expect(response.parsed_body["ai_personas"].length).to eq(AiPersona.count) expect(response.parsed_body["ai_personas"].length).to eq(AiPersona.count)
expect(response.parsed_body["meta"]["commands"].length).to eq( expect(response.parsed_body["meta"]["commands"].length).to eq(
DiscourseAi::AiBot::Personas::Persona.all_available_commands.length, DiscourseAi::AiBot::Personas::Persona.all_available_tools.length,
) )
end end
@ -34,7 +34,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
serializer_persona2 = response.parsed_body["ai_personas"].find { |p| p["id"] == persona2.id } serializer_persona2 = response.parsed_body["ai_personas"].find { |p| p["id"] == persona2.id }
commands = response.parsed_body["meta"]["commands"] commands = response.parsed_body["meta"]["commands"]
search_command = commands.find { |c| c["id"] == "SearchCommand" } search_command = commands.find { |c| c["id"] == "Search" }
expect(search_command["help"]).to eq(I18n.t("discourse_ai.ai_bot.command_help.search")) expect(search_command["help"]).to eq(I18n.t("discourse_ai.ai_bot.command_help.search"))
@ -71,7 +71,8 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
"Général Description", "Général Description",
) )
id = DiscourseAi::AiBot::Personas.system_personas[DiscourseAi::AiBot::Personas::General] id =
DiscourseAi::AiBot::Personas::Persona.system_personas[DiscourseAi::AiBot::Personas::General]
name = I18n.t("discourse_ai.ai_bot.personas.general.name") name = I18n.t("discourse_ai.ai_bot.personas.general.name")
description = I18n.t("discourse_ai.ai_bot.personas.general.description") description = I18n.t("discourse_ai.ai_bot.personas.general.description")
persona = response.parsed_body["ai_personas"].find { |p| p["id"] == id } persona = response.parsed_body["ai_personas"].find { |p| p["id"] == id }
@ -147,7 +148,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
context "with system personas" do context "with system personas" do
it "does not allow editing of system prompts" do it "does not allow editing of system prompts" do
put "/admin/plugins/discourse-ai/ai_personas/#{DiscourseAi::AiBot::Personas.system_personas.values.first}.json", put "/admin/plugins/discourse-ai/ai_personas/#{DiscourseAi::AiBot::Personas::Persona.system_personas.values.first}.json",
params: { params: {
ai_persona: { ai_persona: {
system_prompt: "you are not a helpful bot", system_prompt: "you are not a helpful bot",
@ -160,7 +161,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
end end
it "does not allow editing of commands" do it "does not allow editing of commands" do
put "/admin/plugins/discourse-ai/ai_personas/#{DiscourseAi::AiBot::Personas.system_personas.values.first}.json", put "/admin/plugins/discourse-ai/ai_personas/#{DiscourseAi::AiBot::Personas::Persona.system_personas.values.first}.json",
params: { params: {
ai_persona: { ai_persona: {
commands: %w[SearchCommand ImageCommand], commands: %w[SearchCommand ImageCommand],
@ -173,7 +174,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
end end
it "does not allow editing of name and description cause it is localized" do it "does not allow editing of name and description cause it is localized" do
put "/admin/plugins/discourse-ai/ai_personas/#{DiscourseAi::AiBot::Personas.system_personas.values.first}.json", put "/admin/plugins/discourse-ai/ai_personas/#{DiscourseAi::AiBot::Personas::Persona.system_personas.values.first}.json",
params: { params: {
ai_persona: { ai_persona: {
name: "bob", name: "bob",
@ -187,7 +188,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
end end
it "does allow some actions" do it "does allow some actions" do
put "/admin/plugins/discourse-ai/ai_personas/#{DiscourseAi::AiBot::Personas.system_personas.values.first}.json", put "/admin/plugins/discourse-ai/ai_personas/#{DiscourseAi::AiBot::Personas::Persona.system_personas.values.first}.json",
params: { params: {
ai_persona: { ai_persona: {
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_1]], allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_1]],
@ -225,7 +226,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
it "is not allowed to delete system personas" do it "is not allowed to delete system personas" do
expect { expect {
delete "/admin/plugins/discourse-ai/ai_personas/#{DiscourseAi::AiBot::Personas.system_personas.values.first}.json" delete "/admin/plugins/discourse-ai/ai_personas/#{DiscourseAi::AiBot::Personas::Persona.system_personas.values.first}.json"
expect(response).to have_http_status(:unprocessable_entity) expect(response).to have_http_status(:unprocessable_entity)
expect(response.parsed_body["errors"].join).not_to be_blank expect(response.parsed_body["errors"].join).not_to be_blank
# let's make sure this is translated # let's make sure this is translated
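Throughout these controller specs, system persona ids are now resolved through `Personas::Persona.system_personas`, which the earlier hunk shows mapping persona classes to their reserved ids. A short sketch of the lookup:

# Lookup shape visible in the diff above: a Hash keyed by persona class.
personas = DiscourseAi::AiBot::Personas::Persona.system_personas
general_id = personas[DiscourseAi::AiBot::Personas::General]
first_id   = personas.values.first # the id the requests above interpolate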

View File

@ -1,69 +0,0 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Inference::AnthropicCompletions do
before { SiteSetting.ai_anthropic_api_key = "abc-123" }
it "can complete a trivial prompt" do
response_text = "1. Serenity\\n2. Laughter\\n3. Adventure"
prompt = "Human: write 3 words\n\n"
user_id = 183
req_opts = { max_tokens_to_sample: 700, temperature: 0.5 }
AnthropicCompletionStubs.stub_response(prompt, response_text, req_opts: req_opts)
completions =
DiscourseAi::Inference::AnthropicCompletions.perform!(
prompt,
"claude-2",
temperature: req_opts[:temperature],
max_tokens: req_opts[:max_tokens_to_sample],
user_id: user_id,
)
expect(completions[:completion]).to eq(response_text)
expect(AiApiAuditLog.count).to eq(1)
log = AiApiAuditLog.first
request_body = { model: "claude-2", prompt: prompt }.merge(req_opts).to_json
response_body = AnthropicCompletionStubs.response(response_text).to_json
expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
expect(log.request_tokens).to eq(6)
expect(log.response_tokens).to eq(16)
expect(log.raw_request_payload).to eq(request_body)
expect(log.raw_response_payload).to eq(response_body)
end
it "supports streaming mode" do
deltas = ["Mount", "ain", " ", "Tree ", "Frog"]
prompt = "Human: write 3 words\n\n"
req_opts = { max_tokens_to_sample: 300, stream: true }
content = +""
AnthropicCompletionStubs.stub_streamed_response(prompt, deltas, req_opts: req_opts)
DiscourseAi::Inference::AnthropicCompletions.perform!(
prompt,
"claude-2",
max_tokens: req_opts[:max_tokens_to_sample],
) do |partial, cancel|
data = partial[:completion]
content << data if data
cancel.call if content.split(" ").length == 2
end
expect(content).to eq("Mountain Tree ")
expect(AiApiAuditLog.count).to eq(1)
log = AiApiAuditLog.first
request_body = { model: "claude-2", prompt: prompt }.merge(req_opts).to_json
expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
expect(log.request_tokens).to eq(6)
expect(log.response_tokens).to eq(3)
expect(log.raw_request_payload).to eq(request_body)
expect(log.raw_response_payload).to be_present
end
end

View File

@ -1,108 +0,0 @@
# frozen_string_literal: true
require "rails_helper"
module DiscourseAi::Inference
describe FunctionList do
let :function_list do
function =
Function.new(name: "get_weather", description: "Get the weather in a city (default to c)")
function.add_parameter(
name: "location",
type: "string",
description: "the city name",
required: true,
)
function.add_parameter(
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
required: false,
)
list = FunctionList.new
list << function
list
end
let :image_function_list do
function = Function.new(name: "image", description: "generates an image")
function.add_parameter(
name: "prompts",
type: "array",
item_type: "string",
required: true,
description: "the prompts",
)
list = FunctionList.new
list << function
list
end
it "can handle function call parsing" do
raw_prompt = <<~PROMPT
<function_calls>
<invoke>
<tool_name>image</tool_name>
<parameters>
<prompts>
[
"an oil painting",
"a cute fluffy orange",
"3 apple's",
"a cat"
]
</prompts>
</parameters>
</invoke>
</function_calls>
PROMPT
parsed = image_function_list.parse_prompt(raw_prompt)
expect(parsed).to eq(
[
{
name: "image",
arguments: {
prompts: ["an oil painting", "a cute fluffy orange", "3 apple's", "a cat"],
},
},
],
)
end
it "can generate a general custom system prompt" do
prompt = function_list.system_prompt
# this is fragile, by design, we need to test something here
#
expected = <<~PROMPT
<tools>
<tool_description>
<tool_name>get_weather</tool_name>
<description>Get the weather in a city (default to c)</description>
<parameters>
<parameter>
<name>location</name>
<type>string</type>
<description>the city name</description>
<required>true</required>
</parameter>
<parameter>
<name>unit</name>
<type>string</type>
<description>the unit of measurement celcius c or fahrenheit f</description>
<required>false</required>
<options>c,f</options>
</parameter>
</parameters>
</tool_description>
</tools>
PROMPT
expect(prompt).to include(expected)
end
end
end

View File

@ -1,417 +0,0 @@
# frozen_string_literal: true
require "rails_helper"
describe DiscourseAi::Inference::OpenAiCompletions do
before { SiteSetting.ai_openai_api_key = "abc-123" }
fab!(:user)
it "supports sending an organization id" do
SiteSetting.ai_openai_organization = "org_123"
stub_request(:post, "https://api.openai.com/v1/chat/completions").with(
body:
"{\"model\":\"gpt-3.5-turbo-0613\",\"messages\":[{\"role\":\"system\",\"content\":\"hello\"}]}",
headers: {
"Authorization" => "Bearer abc-123",
"Content-Type" => "application/json",
"Host" => "api.openai.com",
"User-Agent" => "Ruby",
"OpenAI-Organization" => "org_123",
},
).to_return(
status: 200,
body: { choices: [message: { content: "world" }] }.to_json,
headers: {
},
)
result =
DiscourseAi::Inference::OpenAiCompletions.perform!(
[{ role: "system", content: "hello" }],
"gpt-3.5-turbo-0613",
)
expect(result.dig(:choices, 0, :message, :content)).to eq("world")
end
context "when configured using Azure" do
it "Supports custom Azure endpoints for completions" do
gpt_url_base =
"https://company.openai.azure.com/openai/deployments/deployment/chat/completions?api-version=2023-03-15-preview"
key = "12345"
SiteSetting.ai_openai_api_key = key
[
{ setting_name: "ai_openai_gpt35_url", model: "gpt-35-turbo" },
{ setting_name: "ai_openai_gpt35_16k_url", model: "gpt-35-16k-turbo" },
{ setting_name: "ai_openai_gpt4_url", model: "gpt-4" },
{ setting_name: "ai_openai_gpt4_32k_url", model: "gpt-4-32k" },
{ setting_name: "ai_openai_gpt4_turbo_url", model: "gpt-4-1106-preview" },
].each do |config|
gpt_url = "#{gpt_url_base}/#{config[:model]}"
setting_name = config[:setting_name]
model = config[:model]
SiteSetting.public_send("#{setting_name}=".to_sym, gpt_url)
expected = {
id: "chatcmpl-7TfPzOyBGW5K6dyWp3NPU0mYLGZRQ",
object: "chat.completion",
created: 1_687_305_079,
model: model,
choices: [
{
index: 0,
finish_reason: "stop",
message: {
role: "assistant",
content: "Hi there! How can I assist you today?",
},
},
],
usage: {
completion_tokens: 10,
prompt_tokens: 9,
total_tokens: 19,
},
}
stub_request(:post, gpt_url).with(
body: "{\"model\":\"#{model}\",\"messages\":[{\"role\":\"user\",\"content\":\"hello\"}]}",
headers: {
"Api-Key" => "12345",
"Content-Type" => "application/json",
"Host" => "company.openai.azure.com",
},
).to_return(status: 200, body: expected.to_json, headers: {})
result =
DiscourseAi::Inference::OpenAiCompletions.perform!(
[role: "user", content: "hello"],
model,
)
expect(result).to eq(expected)
end
end
end
it "supports function calling" do
prompt = [role: "system", content: "you are weatherbot"]
prompt << { role: "user", content: "what is the weather in sydney?" }
functions = []
function =
DiscourseAi::Inference::Function.new(
name: "get_weather",
description: "Get the weather in a city",
)
function.add_parameter(
name: "location",
type: "string",
description: "the city name",
required: true,
)
function.add_parameter(
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
required: true,
)
functions << function
function_calls = []
current_function_call = nil
deltas = [
{ role: "assistant" },
{ function_call: { name: "get_weather", arguments: "" } },
{ function_call: { arguments: "{ \"location\": " } },
{ function_call: { arguments: "\"sydney\", \"unit\": \"c\" }" } },
]
OpenAiCompletionsInferenceStubs.stub_streamed_response(
prompt,
deltas,
model: "gpt-3.5-turbo-0613",
req_opts: {
functions: functions,
stream: true,
},
)
DiscourseAi::Inference::OpenAiCompletions.perform!(
prompt,
"gpt-3.5-turbo-0613",
functions: functions,
) do |json, cancel|
fn = json.dig(:choices, 0, :delta, :function_call)
if fn && fn[:name]
current_function_call = { name: fn[:name], arguments: +fn[:arguments].to_s.dup }
function_calls << current_function_call
elsif fn && fn[:arguments] && current_function_call
current_function_call[:arguments] << fn[:arguments]
end
end
expect(function_calls.length).to eq(1)
expect(function_calls[0][:name]).to eq("get_weather")
expect(JSON.parse(function_calls[0][:arguments])).to eq(
{ "location" => "sydney", "unit" => "c" },
)
prompt << { role: "function", name: "get_weather", content: 22.to_json }
OpenAiCompletionsInferenceStubs.stub_response(
prompt,
"The current temperature in Sydney is 22 degrees Celsius.",
model: "gpt-3.5-turbo-0613",
req_opts: {
functions: functions,
},
)
result =
DiscourseAi::Inference::OpenAiCompletions.perform!(
prompt,
"gpt-3.5-turbo-0613",
functions: functions,
)
expect(result.dig(:choices, 0, :message, :content)).to eq(
"The current temperature in Sydney is 22 degrees Celsius.",
)
end
it "supports rate limits" do
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
[
{ status: 429, body: "", headers: {} },
{ status: 429, body: "", headers: {} },
{ status: 200, body: { choices: [message: { content: "ok" }] }.to_json, headers: {} },
],
)
completions =
DiscourseAi::Inference::OpenAiCompletions.perform!(
[{ role: "user", content: "hello" }],
"gpt-3.5-turbo",
temperature: 0.5,
top_p: 0.8,
max_tokens: 700,
retries: 3,
retry_timeout: 0,
)
expect(completions.dig(:choices, 0, :message, :content)).to eq("ok")
end
it "supports will raise once rate limit is met" do
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
[
{ status: 429, body: "", headers: {} },
{ status: 429, body: "", headers: {} },
{ status: 429, body: "", headers: {} },
],
)
expect do
DiscourseAi::Inference::OpenAiCompletions.perform!(
[{ role: "user", content: "hello" }],
"gpt-3.5-turbo",
temperature: 0.5,
top_p: 0.8,
max_tokens: 700,
retries: 3,
retry_timeout: 0,
)
end.to raise_error(DiscourseAi::Inference::OpenAiCompletions::CompletionFailed)
end
it "can complete a trivial prompt" do
response_text = "1. Serenity\\n2. Laughter\\n3. Adventure"
prompt = [role: "user", content: "write 3 words"]
user_id = 183
req_opts = { temperature: 0.5, top_p: 0.8, max_tokens: 700 }
OpenAiCompletionsInferenceStubs.stub_response(prompt, response_text, req_opts: req_opts)
completions =
DiscourseAi::Inference::OpenAiCompletions.perform!(
prompt,
"gpt-3.5-turbo",
temperature: 0.5,
top_p: 0.8,
max_tokens: 700,
user_id: user_id,
)
expect(completions.dig(:choices, 0, :message, :content)).to eq(response_text)
expect(AiApiAuditLog.count).to eq(1)
log = AiApiAuditLog.first
body = { model: "gpt-3.5-turbo", messages: prompt }.merge(req_opts).to_json
request_body = OpenAiCompletionsInferenceStubs.response(response_text).to_json
expect(log.provider_id).to eq(AiApiAuditLog::Provider::OpenAI)
expect(log.request_tokens).to eq(337)
expect(log.response_tokens).to eq(162)
expect(log.raw_request_payload).to eq(body)
expect(log.raw_response_payload).to eq(request_body)
end
context "when Webmock has streaming support" do
# See: https://github.com/bblimke/webmock/issues/629
let(:mock_net_http) do
Class.new(Net::HTTP) do
def request(*)
super do |response|
response.instance_eval do
def read_body(*, &)
@body.each(&)
end
end
yield response if block_given?
response
end
end
end
end
let(:remove_original_net_http) { Net.send(:remove_const, :HTTP) }
let(:original_http) { remove_original_net_http }
let(:stub_net_http) { Net.send(:const_set, :HTTP, mock_net_http) }
let(:remove_stubbed_net_http) { Net.send(:remove_const, :HTTP) }
let(:restore_net_http) { Net.send(:const_set, :HTTP, original_http) }
before do
mock_net_http
remove_original_net_http
stub_net_http
end
after do
remove_stubbed_net_http
restore_net_http
end
it "recovers from chunked payload issues" do
raw_data = <<~TEXT
da|ta: |{"choices":[{"delta":{"content"|:"test"}}]}
data: {"choices":[{"delta":{"content":"test1"}}]}
data: {"choices":[{"delta":{"conte|nt":"test2"}}]|}
data: {"ch|oices":[{"delta|":{"content":"test3"}}]}
data: [DONE]
TEXT
chunks = raw_data.split("|")
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
status: 200,
body: chunks,
)
partials = []
DiscourseAi::Inference::OpenAiCompletions.perform!([], "gpt-3.5-turbo") do |partial, cancel|
partials << partial
end
expect(partials.length).to eq(4)
expect(partials).to eq(
[
{ choices: [{ delta: { content: "test" } }] },
{ choices: [{ delta: { content: "test1" } }] },
{ choices: [{ delta: { content: "test2" } }] },
{ choices: [{ delta: { content: "test3" } }] },
],
)
end
it "support extremely slow streaming" do
raw_data = <<~TEXT
data: {"choices":[{"delta":{"content":"test"}}]}
data: {"choices":[{"delta":{"content":"test1"}}]}
data: {"choices":[{"delta":{"content":"test2"}}]}
data: [DONE]
TEXT
chunks = raw_data.split("")
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
status: 200,
body: chunks,
)
partials = []
DiscourseAi::Inference::OpenAiCompletions.perform!([], "gpt-3.5-turbo") do |partial, cancel|
partials << partial
end
expect(partials.length).to eq(3)
expect(partials).to eq(
[
{ choices: [{ delta: { content: "test" } }] },
{ choices: [{ delta: { content: "test1" } }] },
{ choices: [{ delta: { content: "test2" } }] },
],
)
end
end
it "can operate in streaming mode" do
deltas = [
{ role: "assistant" },
{ content: "Mount" },
{ content: "ain" },
{ content: " " },
{ content: "Tree " },
{ content: "Frog" },
]
prompt = [role: "user", content: "write 3 words"]
content = +""
OpenAiCompletionsInferenceStubs.stub_streamed_response(
prompt,
deltas,
req_opts: {
stream: true,
},
)
DiscourseAi::Inference::OpenAiCompletions.perform!(prompt, "gpt-3.5-turbo") do |partial, cancel|
data = partial.dig(:choices, 0, :delta, :content)
content << data if data
cancel.call if content.split(" ").length == 2
end
expect(content).to eq("Mountain Tree ")
expect(AiApiAuditLog.count).to eq(1)
log = AiApiAuditLog.first
request_body = { model: "gpt-3.5-turbo", messages: prompt, stream: true }.to_json
expect(log.provider_id).to eq(AiApiAuditLog::Provider::OpenAI)
expect(log.request_tokens).to eq(4)
expect(log.response_tokens).to eq(3)
expect(log.raw_request_payload).to eq(request_body)
expect(log.raw_response_payload).to be_present
end
end

Some files were not shown because too many files have changed in this diff.