DEV: AI bot migration to the Llm pattern. (#343)
* DEV: AI bot migration to the Llm pattern. We added tool and conversation context support to the Llm service in discourse-ai#366, meaning we met all the conditions to migrate this module. This PR migrates to the new pattern, meaning adding a new bot now requires minimal effort as long as the service supports it. On top of this, we introduce the concept of a "Playground" to separate the PM-specific bits from the completion, allowing us to use the bot in other contexts like chat in the future. Commands are called tools, and we simplified all the placeholder logic to perform updates in a single place, making the flow more one-wayish. * Followup fixes based on testing * Cleanup unused inference code * FIX: text-based tools could be in the middle of a sentence * GPT-4-turbo support * Use new LLM API
This commit is contained in:
parent
03fc94684b
commit
f9d7d7f5f0
|
@ -12,11 +12,11 @@ module DiscourseAi
|
|||
# localized for system personas
|
||||
LocalizedAiPersonaSerializer.new(persona, root: false)
|
||||
end
|
||||
commands =
|
||||
DiscourseAi::AiBot::Personas::Persona.all_available_commands.map do |command|
|
||||
AiCommandSerializer.new(command, root: false)
|
||||
tools =
|
||||
DiscourseAi::AiBot::Personas::Persona.all_available_tools.map do |tool|
|
||||
AiToolSerializer.new(tool, root: false)
|
||||
end
|
||||
render json: { ai_personas: ai_personas, meta: { commands: commands } }
|
||||
render json: { ai_personas: ai_personas, meta: { commands: tools } }
|
||||
end
|
||||
|
||||
def show
|
||||
|
|
|
@ -8,17 +8,25 @@ module ::Jobs
|
|||
return unless bot_user = User.find_by(id: args[:bot_user_id])
|
||||
return unless post = Post.includes(:topic).find_by(id: args[:post_id])
|
||||
|
||||
kwargs = {}
|
||||
kwargs[:user] = post.user
|
||||
begin
|
||||
persona = nil
|
||||
if persona_id = post.topic.custom_fields["ai_persona_id"]
|
||||
kwargs[:persona_id] = persona_id.to_i
|
||||
else
|
||||
kwargs[:persona_name] = post.topic.custom_fields["ai_persona"]
|
||||
persona =
|
||||
DiscourseAi::AiBot::Personas::Persona.find_by(user: post.user, id: persona_id.to_i)
|
||||
raise DiscourseAi::AiBot::Bot::BOT_NOT_FOUND if persona.nil?
|
||||
end
|
||||
|
||||
begin
|
||||
bot = DiscourseAi::AiBot::Bot.as(bot_user, **kwargs)
|
||||
bot.reply_to(post)
|
||||
if !persona && persona_name = post.topic.custom_fields["ai_persona"]
|
||||
persona =
|
||||
DiscourseAi::AiBot::Personas::Persona.find_by(user: post.user, name: persona_name)
|
||||
raise DiscourseAi::AiBot::Bot::BOT_NOT_FOUND if persona.nil?
|
||||
end
|
||||
|
||||
persona ||= DiscourseAi::AiBot::Personas::General
|
||||
|
||||
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona.new)
|
||||
|
||||
DiscourseAi::AiBot::Playground.new(bot).reply_to(post)
|
||||
rescue DiscourseAi::AiBot::Bot::BOT_NOT_FOUND
|
||||
Rails.logger.warn(
|
||||
"Bot not found for post #{post.id} - perhaps persona was deleted or bot was disabled",
|
||||
|
|
|
@ -11,7 +11,7 @@ module ::Jobs
|
|||
|
||||
return unless post.topic.custom_fields[DiscourseAi::AiBot::EntryPoint::REQUIRE_TITLE_UPDATE]
|
||||
|
||||
bot.update_pm_title(post)
|
||||
DiscourseAi::AiBot::Playground.new(bot).title_playground(post)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -67,7 +67,7 @@ class AiPersona < ActiveRecord::Base
|
|||
id = self.id
|
||||
system = self.system
|
||||
|
||||
persona_class = DiscourseAi::AiBot::Personas.system_personas_by_id[self.id]
|
||||
persona_class = DiscourseAi::AiBot::Personas::Persona.system_personas_by_id[self.id]
|
||||
if persona_class
|
||||
persona_class.define_singleton_method :allowed_group_ids do
|
||||
allowed_group_ids
|
||||
|
@ -90,8 +90,10 @@ class AiPersona < ActiveRecord::Base
|
|||
|
||||
options = {}
|
||||
|
||||
commands =
|
||||
self.commands.filter_map do |element|
|
||||
tools = self.respond_to?(:commands) ? self.commands : self.tools
|
||||
|
||||
tools =
|
||||
tools.filter_map do |element|
|
||||
inner_name = element
|
||||
current_options = nil
|
||||
|
||||
|
@ -100,8 +102,12 @@ class AiPersona < ActiveRecord::Base
|
|||
current_options = element[1]
|
||||
end
|
||||
|
||||
# Won't migrate data yet. Let's rewrite to the tool name.
|
||||
inner_name = inner_name.gsub("Command", "")
|
||||
inner_name = "List#{inner_name}" if %w[Categories Tags].include?(inner_name)
|
||||
|
||||
begin
|
||||
klass = ("DiscourseAi::AiBot::Commands::#{inner_name}").constantize
|
||||
klass = ("DiscourseAi::AiBot::Tools::#{inner_name}").constantize
|
||||
options[klass] = current_options if current_options
|
||||
klass
|
||||
rescue StandardError
|
||||
|
@ -143,8 +149,8 @@ class AiPersona < ActiveRecord::Base
|
|||
super(*args, **kwargs)
|
||||
end
|
||||
|
||||
define_method :commands do
|
||||
commands
|
||||
define_method :tools do
|
||||
tools
|
||||
end
|
||||
|
||||
define_method :options do
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
class AiCommandSerializer < ApplicationSerializer
|
||||
class AiToolSerializer < ApplicationSerializer
|
||||
attributes :options, :id, :name, :help
|
||||
|
||||
def include_options?
|
||||
object.options.present?
|
||||
object.accepted_options.present?
|
||||
end
|
||||
|
||||
def id
|
||||
|
@ -21,7 +21,7 @@ class AiCommandSerializer < ApplicationSerializer
|
|||
|
||||
def options
|
||||
options = {}
|
||||
object.options.each do |option|
|
||||
object.accepted_options.each do |option|
|
||||
options[option.name] = {
|
||||
name: option.localized_name,
|
||||
description: option.localized_description,
|
|
@ -132,6 +132,8 @@ en:
|
|||
attribution: "Image by Stable Diffusion XL"
|
||||
|
||||
ai_bot:
|
||||
placeholder_reply: "I will reply shortly..."
|
||||
|
||||
personas:
|
||||
cannot_delete_system_persona: "System personas cannot be deleted, please disable it instead"
|
||||
cannot_edit_system_persona: "System personas can only be renamed, you may not edit commands or system prompt, instead disable and make a copy"
|
||||
|
@ -157,6 +159,7 @@ en:
|
|||
name: "DALL-E 3"
|
||||
description: "AI Bot specialized in generating images using DALL-E 3"
|
||||
topic_not_found: "Summary unavailable, topic not found!"
|
||||
summarizing: "Summarizing topic"
|
||||
searching: "Searching for: '%{query}'"
|
||||
command_options:
|
||||
search:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
DiscourseAi::AiBot::Personas.system_personas.each do |persona_class, id|
|
||||
DiscourseAi::AiBot::Personas::Persona.system_personas.each do |persona_class, id|
|
||||
persona = AiPersona.find_by(id: id)
|
||||
if !persona
|
||||
persona = AiPersona.new
|
||||
|
@ -32,7 +32,7 @@ DiscourseAi::AiBot::Personas.system_personas.each do |persona_class, id|
|
|||
|
||||
persona.system = true
|
||||
instance = persona_class.new
|
||||
persona.commands = instance.commands.map { |command| command.to_s.split("::").last }
|
||||
persona.commands = instance.tools.map { |tool| tool.to_s.split("::").last }
|
||||
persona.system_prompt = instance.system_prompt
|
||||
persona.save!(validate: false)
|
||||
end
|
||||
|
|
|
@ -1,74 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module AiBot
|
||||
class AnthropicBot < Bot
|
||||
def self.can_reply_as?(bot_user)
|
||||
bot_user.id == DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID
|
||||
end
|
||||
|
||||
def bot_prompt_with_topic_context(post, allow_commands:)
|
||||
super(post, allow_commands: allow_commands).join("\n\n") + "\n\nAssistant:"
|
||||
end
|
||||
|
||||
def prompt_limit(allow_commands: true)
|
||||
# no side channel for commands, so we can ignore allow commands
|
||||
50_000 # https://console.anthropic.com/docs/prompt-design#what-is-a-prompt
|
||||
end
|
||||
|
||||
def title_prompt(post)
|
||||
super(post).join("\n\n") + "\n\nAssistant:"
|
||||
end
|
||||
|
||||
def get_delta(partial, context)
|
||||
completion = partial[:completion]
|
||||
if completion&.start_with?(" ") && !context[:processed_first]
|
||||
completion = completion[1..-1]
|
||||
context[:processed_first] = true
|
||||
end
|
||||
completion
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::AnthropicTokenizer
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def build_message(poster_username, content, system: false, function: nil)
|
||||
role = poster_username == bot_user.username ? "Assistant" : "Human"
|
||||
|
||||
if system || function
|
||||
content
|
||||
else
|
||||
"#{role}: #{content}"
|
||||
end
|
||||
end
|
||||
|
||||
def model_for
|
||||
"claude-2"
|
||||
end
|
||||
|
||||
def get_updated_title(prompt)
|
||||
DiscourseAi::Inference::AnthropicCompletions.perform!(
|
||||
prompt,
|
||||
model_for,
|
||||
temperature: 0.4,
|
||||
max_tokens: 40,
|
||||
).dig(:completion)
|
||||
end
|
||||
|
||||
def submit_prompt(prompt, post: nil, prefer_low_cost: false, &blk)
|
||||
DiscourseAi::Inference::AnthropicCompletions.perform!(
|
||||
prompt,
|
||||
model_for,
|
||||
temperature: 0.4,
|
||||
max_tokens: 3000,
|
||||
post: post,
|
||||
stop_sequences: ["\n\nHuman:", "</function_calls>"],
|
||||
&blk
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -3,464 +3,152 @@
|
|||
module DiscourseAi
|
||||
module AiBot
|
||||
class Bot
|
||||
class FunctionCalls
|
||||
attr_accessor :maybe_buffer, :maybe_found, :custom
|
||||
|
||||
def initialize
|
||||
@functions = []
|
||||
@current_function = nil
|
||||
@found = false
|
||||
@cancel_completion = false
|
||||
@maybe_buffer = +""
|
||||
@maybe_found = false
|
||||
@custom = false
|
||||
end
|
||||
|
||||
def custom?
|
||||
@custom
|
||||
end
|
||||
|
||||
def found?
|
||||
!@functions.empty? || @found
|
||||
end
|
||||
|
||||
def found!
|
||||
@found = true
|
||||
end
|
||||
|
||||
def maybe_found?
|
||||
@maybe_found
|
||||
end
|
||||
|
||||
def cancel_completion?
|
||||
@cancel_completion
|
||||
end
|
||||
|
||||
def cancel_completion!
|
||||
@cancel_completion = true
|
||||
end
|
||||
|
||||
def add_function(name)
|
||||
@current_function = { name: name, arguments: +"" }
|
||||
@functions << @current_function
|
||||
end
|
||||
|
||||
def add_argument_fragment(fragment)
|
||||
@current_function[:arguments] << fragment
|
||||
end
|
||||
|
||||
def length
|
||||
@functions.length
|
||||
end
|
||||
|
||||
def each
|
||||
@functions.each { |function| yield function }
|
||||
end
|
||||
|
||||
def to_a
|
||||
@functions
|
||||
end
|
||||
end
|
||||
|
||||
attr_reader :bot_user, :persona
|
||||
|
||||
BOT_NOT_FOUND = Class.new(StandardError)
|
||||
MAX_COMPLETIONS = 5
|
||||
|
||||
def self.as(bot_user, persona_id: nil, persona_name: nil, user: nil)
|
||||
available_bots = [DiscourseAi::AiBot::OpenAiBot, DiscourseAi::AiBot::AnthropicBot]
|
||||
|
||||
bot =
|
||||
available_bots.detect(-> { raise BOT_NOT_FOUND }) do |bot_klass|
|
||||
bot_klass.can_reply_as?(bot_user)
|
||||
def self.as(bot_user, persona: DiscourseAi::AiBot::Personas::General.new)
|
||||
new(bot_user, persona)
|
||||
end
|
||||
|
||||
persona = nil
|
||||
if persona_id
|
||||
persona = DiscourseAi::AiBot::Personas.find_by(user: user, id: persona_id)
|
||||
raise BOT_NOT_FOUND if persona.nil?
|
||||
end
|
||||
|
||||
if !persona && persona_name
|
||||
persona = DiscourseAi::AiBot::Personas.find_by(user: user, name: persona_name)
|
||||
raise BOT_NOT_FOUND if persona.nil?
|
||||
end
|
||||
|
||||
bot.new(bot_user, persona: persona&.new)
|
||||
end
|
||||
|
||||
def initialize(bot_user, persona: nil)
|
||||
def initialize(bot_user, persona)
|
||||
@bot_user = bot_user
|
||||
@persona = persona || DiscourseAi::AiBot::Personas::General.new
|
||||
@persona = persona
|
||||
end
|
||||
|
||||
def update_pm_title(post)
|
||||
prompt = title_prompt(post)
|
||||
attr_reader :bot_user
|
||||
|
||||
new_title = get_updated_title(prompt).strip.split("\n").last
|
||||
|
||||
PostRevisor.new(post.topic.first_post, post.topic).revise!(
|
||||
bot_user,
|
||||
title: new_title.sub(/\A"/, "").sub(/"\Z/, ""),
|
||||
)
|
||||
post.topic.custom_fields.delete(DiscourseAi::AiBot::EntryPoint::REQUIRE_TITLE_UPDATE)
|
||||
post.topic.save_custom_fields
|
||||
end
|
||||
|
||||
def reply_to(
|
||||
post,
|
||||
total_completions: 0,
|
||||
bot_reply_post: nil,
|
||||
prefer_low_cost: false,
|
||||
standalone: false
|
||||
)
|
||||
return if total_completions > MAX_COMPLETIONS
|
||||
|
||||
# do not allow commands when we are at the end of chain (total completions == MAX_COMPLETIONS)
|
||||
allow_commands = (total_completions < MAX_COMPLETIONS)
|
||||
|
||||
prompt =
|
||||
if standalone && post.post_custom_prompt
|
||||
username, standalone_prompt = post.post_custom_prompt.custom_prompt.last
|
||||
[build_message(username, standalone_prompt)]
|
||||
else
|
||||
bot_prompt_with_topic_context(post, allow_commands: allow_commands)
|
||||
end
|
||||
|
||||
redis_stream_key = nil
|
||||
partial_reply = +""
|
||||
reply = +(bot_reply_post ? bot_reply_post.raw.dup : "")
|
||||
start = Time.now
|
||||
|
||||
setup_cancel = false
|
||||
context = {}
|
||||
functions = FunctionCalls.new
|
||||
|
||||
submit_prompt(prompt, post: post, prefer_low_cost: prefer_low_cost) do |partial, cancel|
|
||||
current_delta = get_delta(partial, context)
|
||||
partial_reply << current_delta
|
||||
|
||||
if !available_functions.empty?
|
||||
populate_functions(
|
||||
partial: partial,
|
||||
reply: partial_reply,
|
||||
functions: functions,
|
||||
current_delta: current_delta,
|
||||
done: false,
|
||||
)
|
||||
|
||||
cancel&.call if functions.cancel_completion?
|
||||
end
|
||||
|
||||
if functions.maybe_buffer.present? && !functions.maybe_found?
|
||||
reply << functions.maybe_buffer
|
||||
functions.maybe_buffer = +""
|
||||
end
|
||||
|
||||
reply << current_delta if !functions.found? && !functions.maybe_found?
|
||||
|
||||
if redis_stream_key && !Discourse.redis.get(redis_stream_key)
|
||||
cancel&.call
|
||||
|
||||
bot_reply_post.update!(raw: reply, cooked: PrettyText.cook(reply)) if bot_reply_post
|
||||
end
|
||||
|
||||
# Minor hack to skip the delay during tests.
|
||||
next if (Time.now - start < 0.5) && !Rails.env.test?
|
||||
|
||||
if bot_reply_post
|
||||
Discourse.redis.expire(redis_stream_key, 60)
|
||||
start = Time.now
|
||||
|
||||
publish_update(bot_reply_post, raw: reply.dup)
|
||||
else
|
||||
bot_reply_post =
|
||||
PostCreator.create!(
|
||||
bot_user,
|
||||
topic_id: post.topic_id,
|
||||
raw: reply,
|
||||
skip_validations: true,
|
||||
)
|
||||
end
|
||||
|
||||
if !setup_cancel && bot_reply_post
|
||||
redis_stream_key = "gpt_cancel:#{bot_reply_post.id}"
|
||||
Discourse.redis.setex(redis_stream_key, 60, 1)
|
||||
setup_cancel = true
|
||||
end
|
||||
end
|
||||
|
||||
if !available_functions.empty?
|
||||
populate_functions(
|
||||
partial: nil,
|
||||
reply: partial_reply,
|
||||
current_delta: "",
|
||||
functions: functions,
|
||||
done: true,
|
||||
)
|
||||
end
|
||||
|
||||
if functions.maybe_buffer.present?
|
||||
reply << functions.maybe_buffer
|
||||
functions.maybe_buffer = +""
|
||||
end
|
||||
|
||||
if bot_reply_post
|
||||
publish_update(bot_reply_post, done: true)
|
||||
|
||||
bot_reply_post.revise(
|
||||
bot_user,
|
||||
{ raw: reply },
|
||||
skip_validations: true,
|
||||
skip_revision: true,
|
||||
)
|
||||
|
||||
bot_reply_post.post_custom_prompt ||= post.build_post_custom_prompt(custom_prompt: [])
|
||||
prompt = post.post_custom_prompt.custom_prompt || []
|
||||
|
||||
truncated_reply = partial_reply
|
||||
|
||||
# TODO: we may want to move this code
|
||||
if functions.length > 0 && partial_reply.include?("</invoke>")
|
||||
# recover stop word potentially
|
||||
truncated_reply =
|
||||
partial_reply.split("</invoke>").first + "</invoke>\n</function_calls>"
|
||||
end
|
||||
|
||||
prompt << [truncated_reply, bot_user.username] if truncated_reply.present?
|
||||
|
||||
post.post_custom_prompt.update!(custom_prompt: prompt)
|
||||
end
|
||||
|
||||
if functions.length > 0
|
||||
chain = false
|
||||
standalone = false
|
||||
|
||||
functions.each do |function|
|
||||
name, args = function[:name], function[:arguments]
|
||||
|
||||
if command_klass = available_commands.detect { |cmd| cmd.invoked?(name) }
|
||||
command =
|
||||
command_klass.new(
|
||||
bot: self,
|
||||
args: args,
|
||||
post: bot_reply_post,
|
||||
parent_post: post,
|
||||
xml_format: !functions.custom?,
|
||||
)
|
||||
chain_intermediate, bot_reply_post = command.invoke!
|
||||
chain ||= chain_intermediate
|
||||
standalone ||= command.standalone?
|
||||
end
|
||||
end
|
||||
|
||||
if chain
|
||||
reply_to(
|
||||
bot_reply_post,
|
||||
total_completions: total_completions + 1,
|
||||
bot_reply_post: bot_reply_post,
|
||||
standalone: standalone,
|
||||
)
|
||||
end
|
||||
end
|
||||
rescue => e
|
||||
if Rails.env.development?
|
||||
p e
|
||||
puts e.backtrace
|
||||
end
|
||||
raise e if Rails.env.test?
|
||||
Discourse.warn_exception(e, message: "ai-bot: Reply failed")
|
||||
end
|
||||
|
||||
def extra_tokens_per_message
|
||||
0
|
||||
end
|
||||
|
||||
def bot_prompt_with_topic_context(post, allow_commands:)
|
||||
messages = []
|
||||
conversation = conversation_context(post)
|
||||
|
||||
rendered_system_prompt = system_prompt(post, allow_commands: allow_commands)
|
||||
total_prompt_tokens = tokenize(rendered_system_prompt).length + extra_tokens_per_message
|
||||
|
||||
prompt_limit = self.prompt_limit(allow_commands: allow_commands)
|
||||
|
||||
conversation.each do |raw, username, function|
|
||||
break if total_prompt_tokens >= prompt_limit
|
||||
|
||||
tokens = tokenize(raw.to_s + username.to_s)
|
||||
|
||||
while !raw.blank? &&
|
||||
tokens.length + total_prompt_tokens + extra_tokens_per_message > prompt_limit
|
||||
raw = raw[0..-100] || ""
|
||||
tokens = tokenize(raw.to_s + username.to_s)
|
||||
end
|
||||
|
||||
next if raw.blank?
|
||||
|
||||
total_prompt_tokens += tokens.length + extra_tokens_per_message
|
||||
messages.unshift(build_message(username, raw, function: !!function))
|
||||
end
|
||||
|
||||
messages.unshift(build_message(bot_user.username, rendered_system_prompt, system: true))
|
||||
|
||||
messages
|
||||
end
|
||||
|
||||
def prompt_limit(allow_commands: false)
|
||||
raise NotImplemented
|
||||
end
|
||||
|
||||
def title_prompt(post)
|
||||
prompt = <<~TEXT
|
||||
You are titlebot. Given a topic you will figure out a title.
|
||||
You will never respond with anything but a 7 word topic title.
|
||||
def get_updated_title(conversation_context, post_user)
|
||||
title_prompt = { insts: <<~TEXT, conversation_context: conversation_context }
|
||||
You are titlebot. Given a topic, you will figure out a title.
|
||||
You will never respond with anything but 7 word topic title.
|
||||
TEXT
|
||||
messages = [build_message(bot_user.username, prompt, system: true)]
|
||||
|
||||
messages << build_message("User", <<~TEXT)
|
||||
Suggest a 7 word title for the following topic without quoting any of it:
|
||||
title_prompt[
|
||||
:input
|
||||
] = "Based on our previous conversation, suggest a 7 word title without quoting any of it."
|
||||
|
||||
<content>
|
||||
#{post.topic.posts.map(&:raw).join("\n\n")[0..prompt_limit(allow_commands: false)]}
|
||||
</content>
|
||||
TEXT
|
||||
messages
|
||||
DiscourseAi::Completions::Llm
|
||||
.proxy(model)
|
||||
.generate(title_prompt, user: post_user)
|
||||
.strip
|
||||
.split("\n")
|
||||
.last
|
||||
end
|
||||
|
||||
def available_commands
|
||||
@persona.available_commands
|
||||
end
|
||||
def reply(context, &update_blk)
|
||||
prompt = persona.craft_prompt(context)
|
||||
|
||||
def system_prompt_style!(style)
|
||||
@style = style
|
||||
end
|
||||
total_completions = 0
|
||||
ongoing_chain = true
|
||||
low_cost = false
|
||||
raw_context = []
|
||||
|
||||
def system_prompt(post, allow_commands:)
|
||||
return "You are a helpful Bot" if @style == :simple
|
||||
while total_completions <= MAX_COMPLETIONS && ongoing_chain
|
||||
current_model = model(prefer_low_cost: low_cost)
|
||||
llm = DiscourseAi::Completions::Llm.proxy(current_model)
|
||||
tool_found = false
|
||||
|
||||
@persona.render_system_prompt(
|
||||
topic: post.topic,
|
||||
allow_commands: allow_commands,
|
||||
render_function_instructions:
|
||||
allow_commands && include_function_instructions_in_system_prompt?,
|
||||
)
|
||||
end
|
||||
llm.generate(prompt, user: context[:user]) do |partial, cancel|
|
||||
if (tool = persona.find_tool(partial))
|
||||
tool_found = true
|
||||
ongoing_chain = tool.chain_next_response?
|
||||
low_cost = tool.low_cost?
|
||||
tool_call_id = tool.tool_call_id
|
||||
invocation_result_json = invoke_tool(tool, llm, cancel, &update_blk).to_json
|
||||
|
||||
def include_function_instructions_in_system_prompt?
|
||||
true
|
||||
end
|
||||
invocation_context = {
|
||||
type: "tool",
|
||||
name: tool_call_id,
|
||||
content: invocation_result_json,
|
||||
}
|
||||
tool_context = {
|
||||
type: "tool_call",
|
||||
name: tool_call_id,
|
||||
content: { name: tool.name, arguments: tool.parameters }.to_json,
|
||||
}
|
||||
|
||||
def function_list
|
||||
@persona.function_list
|
||||
end
|
||||
prompt[:conversation_context] ||= []
|
||||
|
||||
def tokenizer
|
||||
raise NotImplemented
|
||||
end
|
||||
|
||||
def tokenize(text)
|
||||
tokenizer.tokenize(text)
|
||||
end
|
||||
|
||||
def submit_prompt(
|
||||
prompt,
|
||||
post:,
|
||||
prefer_low_cost: false,
|
||||
temperature: nil,
|
||||
max_tokens: nil,
|
||||
&blk
|
||||
)
|
||||
raise NotImplemented
|
||||
end
|
||||
|
||||
def get_delta(partial, context)
|
||||
raise NotImplemented
|
||||
end
|
||||
|
||||
def populate_functions(partial:, reply:, functions:, done:, current_delta:)
|
||||
if !done
|
||||
search_length = "<function_calls>".length
|
||||
index = -1
|
||||
while index > -search_length
|
||||
substr = reply[index..-1] || reply
|
||||
index -= 1
|
||||
|
||||
functions.maybe_found = "<function_calls>".start_with?(substr)
|
||||
break if functions.maybe_found?
|
||||
end
|
||||
|
||||
functions.maybe_buffer << current_delta if functions.maybe_found?
|
||||
functions.found! if reply.match?(/^<function_calls>/i)
|
||||
if functions.found?
|
||||
functions.maybe_buffer = functions.maybe_buffer.to_s.split("<")[0..-2].join("<")
|
||||
functions.cancel_completion! if reply.match?(%r{</function_calls>}i)
|
||||
end
|
||||
if tool.standalone?
|
||||
prompt[:conversation_context] = [invocation_context, tool_context]
|
||||
else
|
||||
functions_string = reply.scan(%r{(<function_calls>(.*?)</invoke>)}im)&.first&.first
|
||||
if functions_string
|
||||
function_list
|
||||
.parse_prompt(functions_string + "</function_calls>")
|
||||
.each do |function|
|
||||
functions.add_function(function[:name])
|
||||
functions.add_argument_fragment(function[:arguments].to_json)
|
||||
end
|
||||
end
|
||||
end
|
||||
prompt[:conversation_context] = [invocation_context, tool_context] +
|
||||
prompt[:conversation_context]
|
||||
end
|
||||
|
||||
def available_functions
|
||||
@persona.available_functions
|
||||
end
|
||||
|
||||
protected
|
||||
|
||||
def get_updated_title(prompt)
|
||||
raise NotImplemented
|
||||
end
|
||||
|
||||
def model_for(bot)
|
||||
raise NotImplemented
|
||||
end
|
||||
|
||||
def conversation_context(post)
|
||||
context =
|
||||
post
|
||||
.topic
|
||||
.posts
|
||||
.includes(:user)
|
||||
.joins("LEFT JOIN post_custom_prompts ON post_custom_prompts.post_id = posts.id")
|
||||
.where("post_number <= ?", post.post_number)
|
||||
.order("post_number desc")
|
||||
.where("post_type = ?", Post.types[:regular])
|
||||
.limit(50)
|
||||
.pluck(:raw, :username, "post_custom_prompts.custom_prompt")
|
||||
|
||||
result = []
|
||||
|
||||
first = true
|
||||
context.each do |raw, username, custom_prompt|
|
||||
if custom_prompt.present?
|
||||
if first
|
||||
custom_prompt.reverse_each { |message| result << message }
|
||||
first = false
|
||||
raw_context << [tool_context[:content], tool_call_id, "tool_call"]
|
||||
raw_context << [invocation_result_json, tool_call_id, "tool"]
|
||||
else
|
||||
result << custom_prompt.first
|
||||
end
|
||||
else
|
||||
result << [raw, username]
|
||||
update_blk.call(partial, cancel, nil)
|
||||
end
|
||||
end
|
||||
|
||||
ongoing_chain = false if !tool_found
|
||||
total_completions += 1
|
||||
|
||||
# do not allow tools when we are at the end of a chain (total_completions == MAX_COMPLETIONS)
|
||||
prompt.delete(:tools) if total_completions == MAX_COMPLETIONS
|
||||
end
|
||||
|
||||
raw_context
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
attr_reader :persona
|
||||
|
||||
def invoke_tool(tool, llm, cancel, &update_blk)
|
||||
update_blk.call("", cancel, build_placeholder(tool.summary, ""))
|
||||
|
||||
result =
|
||||
tool.invoke(bot_user, llm) do |progress|
|
||||
placeholder = build_placeholder(tool.summary, progress)
|
||||
update_blk.call("", cancel, placeholder)
|
||||
end
|
||||
|
||||
tool_details = build_placeholder(tool.summary, tool.details, custom_raw: tool.custom_raw)
|
||||
update_blk.call(tool_details, cancel, nil)
|
||||
|
||||
result
|
||||
end
|
||||
|
||||
def publish_update(bot_reply_post, payload)
|
||||
MessageBus.publish(
|
||||
"discourse-ai/ai-bot/topic/#{bot_reply_post.topic_id}",
|
||||
payload.merge(post_id: bot_reply_post.id, post_number: bot_reply_post.post_number),
|
||||
user_ids: bot_reply_post.topic.allowed_user_ids,
|
||||
)
|
||||
def model(prefer_low_cost: false)
|
||||
default_model =
|
||||
case bot_user.id
|
||||
when DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID
|
||||
"claude-2"
|
||||
when DiscourseAi::AiBot::EntryPoint::GPT4_ID
|
||||
"gpt-4"
|
||||
when DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID
|
||||
"gpt-4-turbo"
|
||||
when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
|
||||
"gpt-3.5-turbo-16k"
|
||||
else
|
||||
nil
|
||||
end
|
||||
|
||||
if %w[gpt-4 gpt-4-turbo].include?(default_model) && prefer_low_cost
|
||||
return "gpt-3.5-turbo-16k"
|
||||
end
|
||||
|
||||
default_model
|
||||
end
|
||||
|
||||
def tool_invocation?(partial)
|
||||
Nokogiri::HTML5.fragment(partial).at("invoke").present?
|
||||
end
|
||||
|
||||
def build_placeholder(summary, details, custom_raw: nil)
|
||||
placeholder = +(<<~HTML)
|
||||
<details>
|
||||
<summary>#{summary}</summary>
|
||||
<p>#{details}</p>
|
||||
</details>
|
||||
HTML
|
||||
|
||||
placeholder << custom_raw << "\n" if custom_raw
|
||||
|
||||
placeholder
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -1,45 +0,0 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
module DiscourseAi::AiBot::Commands
|
||||
class CategoriesCommand < Command
|
||||
class << self
|
||||
def name
|
||||
"categories"
|
||||
end
|
||||
|
||||
def desc
|
||||
"Will list the categories on the current discourse instance, prefer to format with # in front of the category name"
|
||||
end
|
||||
|
||||
def parameters
|
||||
[]
|
||||
end
|
||||
end
|
||||
|
||||
def result_name
|
||||
"results"
|
||||
end
|
||||
|
||||
def description_args
|
||||
{ count: @last_count || 0 }
|
||||
end
|
||||
|
||||
def process
|
||||
columns = {
|
||||
name: "Name",
|
||||
slug: "Slug",
|
||||
description: "Description",
|
||||
posts_year: "Posts Year",
|
||||
posts_month: "Posts Month",
|
||||
posts_week: "Posts Week",
|
||||
id: "id",
|
||||
parent_category_id: "parent_category_id",
|
||||
}
|
||||
|
||||
rows = Category.where(read_restricted: false).limit(100).pluck(*columns.keys)
|
||||
@last_count = rows.length
|
||||
|
||||
format_results(rows, columns.values)
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,237 +0,0 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module AiBot
|
||||
module Commands
|
||||
class Command
|
||||
CARET = "<!-- caret -->"
|
||||
PROGRESS_CARET = "<!-- progress -->"
|
||||
|
||||
class << self
|
||||
def name
|
||||
raise NotImplemented
|
||||
end
|
||||
|
||||
def invoked?(cmd_name)
|
||||
cmd_name == name
|
||||
end
|
||||
|
||||
def desc
|
||||
raise NotImplemented
|
||||
end
|
||||
|
||||
def custom_system_message
|
||||
end
|
||||
|
||||
def parameters
|
||||
raise NotImplemented
|
||||
end
|
||||
|
||||
def options
|
||||
[]
|
||||
end
|
||||
|
||||
def help
|
||||
I18n.t("discourse_ai.ai_bot.command_help.#{name}")
|
||||
end
|
||||
|
||||
def option(name, type:)
|
||||
Option.new(command: self, name: name, type: type)
|
||||
end
|
||||
end
|
||||
|
||||
attr_reader :bot_user, :bot
|
||||
|
||||
def initialize(bot:, args:, post: nil, parent_post: nil, xml_format: false)
|
||||
@bot = bot
|
||||
@bot_user = bot&.bot_user
|
||||
@args = args
|
||||
@post = post
|
||||
@parent_post = parent_post
|
||||
@xml_format = xml_format
|
||||
|
||||
@placeholder = +(<<~HTML).strip
|
||||
<details>
|
||||
<summary>#{I18n.t("discourse_ai.ai_bot.command_summary.#{self.class.name}")}</summary>
|
||||
<p>
|
||||
#{CARET}
|
||||
</p>
|
||||
</details>
|
||||
#{PROGRESS_CARET}
|
||||
HTML
|
||||
|
||||
@invoked = false
|
||||
end
|
||||
|
||||
def persona_options
|
||||
return @persona_options if @persona_options
|
||||
|
||||
@persona_options = HashWithIndifferentAccess.new
|
||||
|
||||
# during tests we may operate without a bot
|
||||
return @persona_options if !self.bot
|
||||
|
||||
self.class.options.each do |option|
|
||||
val = self.bot.persona.options.dig(self.class, option.name)
|
||||
@persona_options[option.name] = val if val
|
||||
end
|
||||
|
||||
@persona_options
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
bot.tokenizer
|
||||
end
|
||||
|
||||
def standalone?
|
||||
false
|
||||
end
|
||||
|
||||
def low_cost?
|
||||
false
|
||||
end
|
||||
|
||||
def result_name
|
||||
raise NotImplemented
|
||||
end
|
||||
|
||||
def name
|
||||
raise NotImplemented
|
||||
end
|
||||
|
||||
def process(post)
|
||||
raise NotImplemented
|
||||
end
|
||||
|
||||
def description_args
|
||||
{}
|
||||
end
|
||||
|
||||
def custom_raw
|
||||
end
|
||||
|
||||
def chain_next_response
|
||||
true
|
||||
end
|
||||
|
||||
def show_progress(text, progress_caret: false)
|
||||
return if !@post
|
||||
return if !@placeholder
|
||||
|
||||
# during tests we may have none
|
||||
caret = progress_caret ? PROGRESS_CARET : CARET
|
||||
new_placeholder = @placeholder.sub(caret, text + caret)
|
||||
raw = @post.raw.sub(@placeholder, new_placeholder)
|
||||
@placeholder = new_placeholder
|
||||
|
||||
@post.revise(bot_user, { raw: raw }, skip_validations: true, skip_revision: true)
|
||||
end
|
||||
|
||||
def localized_description
|
||||
I18n.t(
|
||||
"discourse_ai.ai_bot.command_description.#{self.class.name}",
|
||||
self.description_args,
|
||||
)
|
||||
end
|
||||
|
||||
def invoke!
|
||||
raise StandardError.new("Command can only be invoked once!") if @invoked
|
||||
|
||||
@invoked = true
|
||||
|
||||
if !@post
|
||||
@post =
|
||||
PostCreator.create!(
|
||||
bot_user,
|
||||
raw: @placeholder,
|
||||
topic_id: @parent_post.topic_id,
|
||||
skip_validations: true,
|
||||
skip_rate_limiter: true,
|
||||
)
|
||||
else
|
||||
@post.revise(
|
||||
bot_user,
|
||||
{ raw: @post.raw + "\n\n" + @placeholder + "\n\n" },
|
||||
skip_validations: true,
|
||||
skip_revision: true,
|
||||
)
|
||||
end
|
||||
|
||||
@post.post_custom_prompt ||= @post.build_post_custom_prompt(custom_prompt: [])
|
||||
prompt = @post.post_custom_prompt.custom_prompt || []
|
||||
|
||||
parsed_args = JSON.parse(@args).symbolize_keys
|
||||
|
||||
function_results = process(**parsed_args).to_json
|
||||
function_results = <<~XML if @xml_format
|
||||
<function_results>
|
||||
<result>
|
||||
<tool_name>#{self.class.name}</tool_name>
|
||||
<json>
|
||||
#{function_results}
|
||||
</json>
|
||||
</result>
|
||||
</function_results>
|
||||
XML
|
||||
prompt << [function_results, self.class.name, "function"]
|
||||
@post.post_custom_prompt.update!(custom_prompt: prompt)
|
||||
|
||||
raw = +(<<~HTML)
|
||||
<details>
|
||||
<summary>#{I18n.t("discourse_ai.ai_bot.command_summary.#{self.class.name}")}</summary>
|
||||
<p>
|
||||
#{localized_description}
|
||||
</p>
|
||||
</details>
|
||||
|
||||
HTML
|
||||
|
||||
raw << custom_raw if custom_raw.present?
|
||||
|
||||
raw = @post.raw.sub(@placeholder, raw)
|
||||
|
||||
@post.revise(bot_user, { raw: raw }, skip_validations: true, skip_revision: true)
|
||||
|
||||
if chain_next_response
|
||||
# somewhat annoying but whitespace was stripped in revise
|
||||
# so we need to save again
|
||||
@post.raw = raw
|
||||
@post.save!(validate: false)
|
||||
end
|
||||
|
||||
[chain_next_response, @post]
|
||||
end
|
||||
|
||||
def format_results(rows, column_names = nil, args: nil)
|
||||
rows = rows&.map { |row| yield row } if block_given?
|
||||
|
||||
if !column_names
|
||||
index = -1
|
||||
column_indexes = {}
|
||||
|
||||
rows =
|
||||
rows&.map do |data|
|
||||
new_row = []
|
||||
data.each do |key, value|
|
||||
found_index = column_indexes[key.to_s] ||= (index += 1)
|
||||
new_row[found_index] = value
|
||||
end
|
||||
new_row
|
||||
end
|
||||
column_names = column_indexes.keys
|
||||
end
|
||||
|
||||
# this is not the most efficient format
|
||||
# however this is needed cause GPT 3.5 / 4 was steered using JSON
|
||||
result = { column_names: column_names, rows: rows }
|
||||
result[:args] = args if args
|
||||
result
|
||||
end
|
||||
|
||||
protected
|
||||
|
||||
attr_reader :bot_user, :args
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,122 +0,0 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
module DiscourseAi::AiBot::Commands
|
||||
class DallECommand < Command
|
||||
class << self
|
||||
def name
|
||||
"dall_e"
|
||||
end
|
||||
|
||||
def desc
|
||||
"Renders images from supplied descriptions"
|
||||
end
|
||||
|
||||
def parameters
|
||||
[
|
||||
Parameter.new(
|
||||
name: "prompts",
|
||||
description:
|
||||
"The prompts used to generate or create or draw the image (5000 chars or less, be creative) up to 4 prompts",
|
||||
type: "array",
|
||||
item_type: "string",
|
||||
required: true,
|
||||
),
|
||||
]
|
||||
end
|
||||
end
|
||||
|
||||
def result_name
|
||||
"results"
|
||||
end
|
||||
|
||||
def description_args
|
||||
{ prompt: @last_prompt }
|
||||
end
|
||||
|
||||
def chain_next_response
|
||||
false
|
||||
end
|
||||
|
||||
def custom_raw
|
||||
@custom_raw
|
||||
end
|
||||
|
||||
def process(prompts:)
|
||||
# max 4 prompts
|
||||
prompts = prompts.take(4)
|
||||
|
||||
@last_prompt = prompts[0]
|
||||
|
||||
show_progress(localized_description)
|
||||
|
||||
results = nil
|
||||
|
||||
# this ensures multisite safety since background threads
|
||||
# generate the images
|
||||
api_key = SiteSetting.ai_openai_api_key
|
||||
api_url = SiteSetting.ai_openai_dall_e_3_url
|
||||
|
||||
threads = []
|
||||
prompts.each_with_index do |prompt, index|
|
||||
threads << Thread.new(prompt) do |inner_prompt|
|
||||
attempts = 0
|
||||
begin
|
||||
DiscourseAi::Inference::OpenAiImageGenerator.perform!(
|
||||
inner_prompt,
|
||||
api_key: api_key,
|
||||
api_url: api_url,
|
||||
)
|
||||
rescue => e
|
||||
attempts += 1
|
||||
sleep 2
|
||||
retry if attempts < 3
|
||||
Discourse.warn_exception(e, message: "Failed to generate image for prompt #{prompt}")
|
||||
nil
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
while true
|
||||
show_progress(".", progress_caret: true)
|
||||
break if threads.all? { |t| t.join(2) }
|
||||
end
|
||||
|
||||
results = threads.filter_map(&:value)
|
||||
|
||||
if results.blank?
|
||||
return { prompts: prompts, error: "Something went wrong, could not generate image" }
|
||||
end
|
||||
|
||||
uploads = []
|
||||
|
||||
results.each_with_index do |result, index|
|
||||
result[:data].each do |image|
|
||||
Tempfile.create("v1_txt2img_#{index}.png") do |file|
|
||||
file.binmode
|
||||
file.write(Base64.decode64(image[:b64_json]))
|
||||
file.rewind
|
||||
uploads << {
|
||||
prompt: image[:revised_prompt],
|
||||
upload: UploadCreator.new(file, "image.png").create_for(bot_user.id),
|
||||
}
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@custom_raw = <<~RAW
|
||||
|
||||
[grid]
|
||||
#{
|
||||
uploads
|
||||
.map do |item|
|
||||
"![#{item[:prompt].gsub(/\|\'\"/, "")}|512x512, 50%](#{item[:upload].short_url})"
|
||||
end
|
||||
.join(" ")
|
||||
}
|
||||
[/grid]
|
||||
RAW
|
||||
|
||||
{ prompts: uploads.map { |item| item[:prompt] } }
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,54 +0,0 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
module DiscourseAi::AiBot::Commands
|
||||
class DbSchemaCommand < Command
|
||||
class << self
|
||||
def name
|
||||
"schema"
|
||||
end
|
||||
|
||||
def desc
|
||||
"Will load schema information for specific tables in the database"
|
||||
end
|
||||
|
||||
def parameters
|
||||
[
|
||||
Parameter.new(
|
||||
name: "tables",
|
||||
description:
|
||||
"list of tables to load schema information for, comma seperated list eg: (users,posts))",
|
||||
type: "string",
|
||||
required: true,
|
||||
),
|
||||
]
|
||||
end
|
||||
end
|
||||
|
||||
def result_name
|
||||
"results"
|
||||
end
|
||||
|
||||
def description_args
|
||||
{ tables: @tables.join(", ") }
|
||||
end
|
||||
|
||||
def process(tables:)
|
||||
@tables = tables.split(",").map(&:strip)
|
||||
|
||||
table_info = {}
|
||||
DB
|
||||
.query(<<~SQL, @tables)
|
||||
select table_name, column_name, data_type from information_schema.columns
|
||||
where table_schema = 'public'
|
||||
and table_name in (?)
|
||||
order by table_name
|
||||
SQL
|
||||
.each { |row| (table_info[row.table_name] ||= []) << "#{row.column_name} #{row.data_type}" }
|
||||
|
||||
schema_info =
|
||||
table_info.map { |table_name, columns| "#{table_name}(#{columns.join(",")})" }.join("\n")
|
||||
|
||||
{ schema_info: schema_info, tables: tables }
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,82 +0,0 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
module DiscourseAi::AiBot::Commands
|
||||
class GoogleCommand < Command
|
||||
class << self
|
||||
def name
|
||||
"google"
|
||||
end
|
||||
|
||||
def desc
|
||||
"Will search using Google - global internet search (supports all Google search operators)"
|
||||
end
|
||||
|
||||
def parameters
|
||||
[
|
||||
Parameter.new(
|
||||
name: "query",
|
||||
description: "The search query",
|
||||
type: "string",
|
||||
required: true,
|
||||
),
|
||||
]
|
||||
end
|
||||
|
||||
def custom_system_message
|
||||
"You were trained on OLD data, lean on search to get up to date information from the web"
|
||||
end
|
||||
end
|
||||
|
||||
def result_name
|
||||
"results"
|
||||
end
|
||||
|
||||
def description_args
|
||||
{
|
||||
count: @last_num_results || 0,
|
||||
query: @last_query || "",
|
||||
url: "https://google.com/search?q=#{CGI.escape(@last_query || "")}",
|
||||
}
|
||||
end
|
||||
|
||||
def process(query:)
|
||||
@last_query = query
|
||||
|
||||
show_progress(localized_description)
|
||||
|
||||
api_key = SiteSetting.ai_google_custom_search_api_key
|
||||
cx = SiteSetting.ai_google_custom_search_cx
|
||||
query = CGI.escape(query)
|
||||
uri =
|
||||
URI("https://www.googleapis.com/customsearch/v1?key=#{api_key}&cx=#{cx}&q=#{query}&num=10")
|
||||
body = Net::HTTP.get(uri)
|
||||
|
||||
parse_search_json(body, query)
|
||||
end
|
||||
|
||||
def minimize_field(result, field, max_tokens: 100)
|
||||
data = result[field]
|
||||
return "" if data.blank?
|
||||
|
||||
data = ::DiscourseAi::Tokenizer::BertTokenizer.truncate(data, max_tokens).squish
|
||||
data
|
||||
end
|
||||
|
||||
def parse_search_json(json_data, query)
|
||||
parsed = JSON.parse(json_data)
|
||||
results = parsed["items"]
|
||||
|
||||
@last_num_results = parsed.dig("searchInformation", "totalResults").to_i
|
||||
|
||||
format_results(results, args: query) do |result|
|
||||
{
|
||||
title: minimize_field(result, "title"),
|
||||
link: minimize_field(result, "link"),
|
||||
snippet: minimize_field(result, "snippet", max_tokens: 120),
|
||||
displayLink: minimize_field(result, "displayLink"),
|
||||
formattedUrl: minimize_field(result, "formattedUrl"),
|
||||
}
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,135 +0,0 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
module DiscourseAi::AiBot::Commands
|
||||
class ImageCommand < Command
|
||||
class << self
|
||||
def name
|
||||
"image"
|
||||
end
|
||||
|
||||
def desc
|
||||
"Renders an image from the description (remove all connector words, keep it to 40 words or less). Despite being a text based bot you can generate images! (when user asks to draw, paint or other synonyms try this)"
|
||||
end
|
||||
|
||||
def parameters
|
||||
[
|
||||
Parameter.new(
|
||||
name: "prompts",
|
||||
description:
|
||||
"The prompts used to generate or create or draw the image (40 words or less, be creative) up to 4 prompts",
|
||||
type: "array",
|
||||
item_type: "string",
|
||||
required: true,
|
||||
),
|
||||
Parameter.new(
|
||||
name: "seeds",
|
||||
description:
|
||||
"The seed used to generate the image (optional) - can be used to retain image style on amended prompts",
|
||||
type: "array",
|
||||
item_type: "integer",
|
||||
),
|
||||
]
|
||||
end
|
||||
end
|
||||
|
||||
def result_name
|
||||
"results"
|
||||
end
|
||||
|
||||
def description_args
|
||||
{ prompt: @last_prompt }
|
||||
end
|
||||
|
||||
def chain_next_response
|
||||
false
|
||||
end
|
||||
|
||||
def custom_raw
|
||||
@custom_raw
|
||||
end
|
||||
|
||||
def process(prompts:, seeds: nil)
|
||||
# max 4 prompts
|
||||
prompts = prompts[0..3]
|
||||
seeds = seeds[0..3] if seeds
|
||||
|
||||
@last_prompt = prompts[0]
|
||||
|
||||
show_progress(localized_description)
|
||||
|
||||
results = nil
|
||||
|
||||
# this ensures multisite safety since background threads
|
||||
# generate the images
|
||||
api_key = SiteSetting.ai_stability_api_key
|
||||
engine = SiteSetting.ai_stability_engine
|
||||
api_url = SiteSetting.ai_stability_api_url
|
||||
|
||||
threads = []
|
||||
prompts.each_with_index do |prompt, index|
|
||||
seed = seeds ? seeds[index] : nil
|
||||
threads << Thread.new(seed, prompt) do |inner_seed, inner_prompt|
|
||||
attempts = 0
|
||||
begin
|
||||
DiscourseAi::Inference::StabilityGenerator.perform!(
|
||||
inner_prompt,
|
||||
engine: engine,
|
||||
api_key: api_key,
|
||||
api_url: api_url,
|
||||
image_count: 1,
|
||||
seed: inner_seed,
|
||||
)
|
||||
rescue => e
|
||||
attempts += 1
|
||||
retry if attempts < 3
|
||||
Rails.logger.warn("Failed to generate image for prompt #{prompt}: #{e}")
|
||||
nil
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
while true
|
||||
show_progress(".", progress_caret: true)
|
||||
break if threads.all? { |t| t.join(2) }
|
||||
end
|
||||
|
||||
results = threads.map(&:value).compact
|
||||
|
||||
if !results.present?
|
||||
return { prompts: prompts, error: "Something went wrong, could not generate image" }
|
||||
end
|
||||
|
||||
uploads = []
|
||||
|
||||
results.each_with_index do |result, index|
|
||||
result[:artifacts].each do |image|
|
||||
Tempfile.create("v1_txt2img_#{index}.png") do |file|
|
||||
file.binmode
|
||||
file.write(Base64.decode64(image[:base64]))
|
||||
file.rewind
|
||||
uploads << {
|
||||
prompt: prompts[index],
|
||||
upload: UploadCreator.new(file, "image.png").create_for(bot_user.id),
|
||||
seed: image[:seed],
|
||||
}
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@custom_raw = <<~RAW
|
||||
|
||||
[grid]
|
||||
#{
|
||||
uploads
|
||||
.map do |item|
|
||||
"![#{item[:prompt].gsub(/\|\'\"/, "")}|512x512, 50%](#{item[:upload].short_url})"
|
||||
end
|
||||
.join(" ")
|
||||
}
|
||||
[/grid]
|
||||
RAW
|
||||
|
||||
{ prompts: uploads.map { |item| item[:prompt] }, seeds: uploads.map { |item| item[:seed] } }
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,23 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
module DiscourseAi
|
||||
module AiBot
|
||||
module Commands
|
||||
class Option
|
||||
attr_reader :command, :name, :type
|
||||
def initialize(command:, name:, type:)
|
||||
@command = command
|
||||
@name = name.to_s
|
||||
@type = type
|
||||
end
|
||||
|
||||
def localized_name
|
||||
I18n.t("discourse_ai.ai_bot.command_options.#{command.name}.#{name}.name")
|
||||
end
|
||||
|
||||
def localized_description
|
||||
I18n.t("discourse_ai.ai_bot.command_options.#{command.name}.#{name}.description")
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,18 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
module DiscourseAi
|
||||
module AiBot
|
||||
module Commands
|
||||
class Parameter
|
||||
attr_reader :item_type, :name, :description, :type, :enum, :required
|
||||
def initialize(name:, description:, type:, enum: nil, required: false, item_type: nil)
|
||||
@name = name
|
||||
@description = description
|
||||
@type = type
|
||||
@enum = enum
|
||||
@required = required
|
||||
@item_type = item_type
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,77 +0,0 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
module DiscourseAi::AiBot::Commands
|
||||
class ReadCommand < Command
|
||||
class << self
|
||||
def name
|
||||
"read"
|
||||
end
|
||||
|
||||
def desc
|
||||
"Will read a topic or a post on this Discourse instance"
|
||||
end
|
||||
|
||||
def parameters
|
||||
[
|
||||
Parameter.new(
|
||||
name: "topic_id",
|
||||
description: "the id of the topic to read",
|
||||
type: "integer",
|
||||
required: true,
|
||||
),
|
||||
Parameter.new(
|
||||
name: "post_number",
|
||||
description: "the post number to read",
|
||||
type: "integer",
|
||||
required: false,
|
||||
),
|
||||
]
|
||||
end
|
||||
end
|
||||
|
||||
def description_args
|
||||
{ title: @title, url: @url }
|
||||
end
|
||||
|
||||
def process(topic_id:, post_number: nil)
|
||||
not_found = { topic_id: topic_id, description: "Topic not found" }
|
||||
|
||||
@title = ""
|
||||
|
||||
topic_id = topic_id.to_i
|
||||
|
||||
topic = Topic.find_by(id: topic_id)
|
||||
return not_found if !topic || !Guardian.new.can_see?(topic)
|
||||
|
||||
@title = topic.title
|
||||
|
||||
posts = Post.secured(Guardian.new).where(topic_id: topic_id).order(:post_number).limit(40)
|
||||
@url = topic.relative_url(post_number)
|
||||
|
||||
posts = posts.where("post_number = ?", post_number) if post_number
|
||||
|
||||
content = +<<~TEXT.strip
|
||||
title: #{topic.title}
|
||||
TEXT
|
||||
|
||||
category_names = [topic.category&.parent_category&.name, topic.category&.name].compact.join(
|
||||
" ",
|
||||
)
|
||||
content << "\ncategories: #{category_names}" if category_names.present?
|
||||
|
||||
if topic.tags.length > 0
|
||||
tags = DiscourseTagging.filter_visible(topic.tags, Guardian.new)
|
||||
content << "\ntags: #{tags.map(&:name).join(", ")}\n\n" if tags.length > 0
|
||||
end
|
||||
|
||||
posts.each { |post| content << "\n\n#{post.username} said:\n\n#{post.raw}" }
|
||||
|
||||
# TODO: 16k or 100k models can handle a lot more tokens
|
||||
content = tokenizer.truncate(content, 1500).squish
|
||||
|
||||
result = { topic_id: topic_id, content: content, complete: true }
|
||||
result[:post_number] = post_number if post_number
|
||||
result
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,232 +0,0 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
module DiscourseAi::AiBot::Commands
|
||||
class SearchCommand < Command
|
||||
class << self
|
||||
def name
|
||||
"search"
|
||||
end
|
||||
|
||||
def desc
|
||||
"Will search topics in the current discourse instance, when rendering always prefer to link to the topics you find"
|
||||
end
|
||||
|
||||
def options
|
||||
[option(:base_query, type: :string), option(:max_results, type: :integer)]
|
||||
end
|
||||
|
||||
def parameters
|
||||
[
|
||||
Parameter.new(
|
||||
name: "search_query",
|
||||
description:
|
||||
"Specific keywords to search for, space seperated (correct bad spelling, remove connector words)",
|
||||
type: "string",
|
||||
),
|
||||
Parameter.new(
|
||||
name: "user",
|
||||
description:
|
||||
"Filter search results to this username (only include if user explicitly asks to filter by user)",
|
||||
type: "string",
|
||||
),
|
||||
Parameter.new(
|
||||
name: "order",
|
||||
description: "search result order",
|
||||
type: "string",
|
||||
enum: %w[latest latest_topic oldest views likes],
|
||||
),
|
||||
Parameter.new(
|
||||
name: "limit",
|
||||
description:
|
||||
"Number of results to return. Defaults to maximum number of results. Only set if absolutely necessary",
|
||||
type: "integer",
|
||||
),
|
||||
Parameter.new(
|
||||
name: "max_posts",
|
||||
description:
|
||||
"maximum number of posts on the topics (topics where lots of people posted)",
|
||||
type: "integer",
|
||||
),
|
||||
Parameter.new(
|
||||
name: "tags",
|
||||
description:
|
||||
"list of tags to search for. Use + to join with OR, use , to join with AND",
|
||||
type: "string",
|
||||
),
|
||||
Parameter.new(
|
||||
name: "category",
|
||||
description: "category name to filter to",
|
||||
type: "string",
|
||||
),
|
||||
Parameter.new(
|
||||
name: "before",
|
||||
description: "only topics created before a specific date YYYY-MM-DD",
|
||||
type: "string",
|
||||
),
|
||||
Parameter.new(
|
||||
name: "after",
|
||||
description: "only topics created after a specific date YYYY-MM-DD",
|
||||
type: "string",
|
||||
),
|
||||
Parameter.new(
|
||||
name: "status",
|
||||
description: "search for topics in a particular state",
|
||||
type: "string",
|
||||
enum: %w[open closed archived noreplies single_user],
|
||||
),
|
||||
]
|
||||
end
|
||||
|
||||
def custom_system_message
|
||||
<<~TEXT
|
||||
You were trained on OLD data, lean on search to get up to date information about this forum
|
||||
When searching try to SIMPLIFY search terms
|
||||
Discourse search joins all terms with AND. Reduce and simplify terms to find more results.
|
||||
TEXT
|
||||
end
|
||||
end
|
||||
|
||||
def result_name
|
||||
"results"
|
||||
end
|
||||
|
||||
def description_args
|
||||
{
|
||||
count: @last_num_results || 0,
|
||||
query: @last_query || "",
|
||||
url: "#{Discourse.base_path}/search?q=#{CGI.escape(@last_query || "")}",
|
||||
}
|
||||
end
|
||||
|
||||
MIN_SEMANTIC_RESULTS = 5
|
||||
|
||||
def max_semantic_results
|
||||
max_results / 4
|
||||
end
|
||||
|
||||
def max_results
|
||||
return 20 if !bot
|
||||
|
||||
max_results = persona_options[:max_results].to_i
|
||||
return [max_results, 100].min if max_results > 0
|
||||
|
||||
if bot.prompt_limit(allow_commands: false) > 30_000
|
||||
60
|
||||
elsif bot.prompt_limit(allow_commands: false) > 10_000
|
||||
40
|
||||
else
|
||||
20
|
||||
end
|
||||
end
|
||||
|
||||
def process(**search_args)
|
||||
limit = nil
|
||||
|
||||
search_string =
|
||||
search_args
|
||||
.map do |key, value|
|
||||
if key == :search_query
|
||||
value
|
||||
elsif key == :limit
|
||||
limit = value.to_i
|
||||
nil
|
||||
else
|
||||
"#{key}:#{value}"
|
||||
end
|
||||
end
|
||||
.compact
|
||||
.join(" ")
|
||||
|
||||
@last_query = search_string
|
||||
|
||||
show_progress(I18n.t("discourse_ai.ai_bot.searching", query: search_string))
|
||||
|
||||
if persona_options[:base_query].present?
|
||||
search_string = "#{search_string} #{persona_options[:base_query]}"
|
||||
end
|
||||
|
||||
results =
|
||||
Search.execute(
|
||||
search_string.to_s + " status:public",
|
||||
search_type: :full_page,
|
||||
guardian: Guardian.new(),
|
||||
)
|
||||
|
||||
# let's be frugal with tokens, 50 results is too much and stuff gets cut off
|
||||
limit ||= max_results
|
||||
limit = max_results if limit > max_results
|
||||
|
||||
should_try_semantic_search = SiteSetting.ai_embeddings_semantic_search_enabled
|
||||
should_try_semantic_search &&= (limit == max_results)
|
||||
should_try_semantic_search &&= (search_args[:search_query].present?)
|
||||
|
||||
limit = limit - max_semantic_results if should_try_semantic_search
|
||||
|
||||
posts = results&.posts || []
|
||||
posts = posts[0..limit - 1]
|
||||
|
||||
if should_try_semantic_search
|
||||
semantic_search = DiscourseAi::Embeddings::SemanticSearch.new(Guardian.new())
|
||||
topic_ids = Set.new(posts.map(&:topic_id))
|
||||
|
||||
search = Search.new(search_string, guardian: Guardian.new)
|
||||
|
||||
results = nil
|
||||
begin
|
||||
results = semantic_search.search_for_topics(search.term)
|
||||
rescue => e
|
||||
Discourse.warn_exception(e, message: "Semantic search failed")
|
||||
end
|
||||
|
||||
if results
|
||||
results = search.apply_filters(results)
|
||||
|
||||
results.each do |post|
|
||||
next if topic_ids.include?(post.topic_id)
|
||||
|
||||
topic_ids << post.topic_id
|
||||
posts << post
|
||||
|
||||
break if posts.length >= max_results
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@last_num_results = posts.length
|
||||
# this is the general pattern from core
|
||||
# if there are millions of hidden tags it may fail
|
||||
hidden_tags = nil
|
||||
|
||||
if posts.blank?
|
||||
{ args: search_args, rows: [], instruction: "nothing was found, expand your search" }
|
||||
else
|
||||
format_results(posts, args: search_args) do |post|
|
||||
category_names = [
|
||||
post.topic.category&.parent_category&.name,
|
||||
post.topic.category&.name,
|
||||
].compact.join(" > ")
|
||||
row = {
|
||||
title: post.topic.title,
|
||||
url: Discourse.base_path + post.url,
|
||||
username: post.user&.username,
|
||||
excerpt: post.excerpt,
|
||||
created: post.created_at,
|
||||
category: category_names,
|
||||
likes: post.like_count,
|
||||
topic_views: post.topic.views,
|
||||
topic_likes: post.topic.like_count,
|
||||
topic_replies: post.topic.posts_count - 1,
|
||||
}
|
||||
|
||||
if SiteSetting.tagging_enabled
|
||||
hidden_tags ||= DiscourseTagging.hidden_tag_names
|
||||
# using map over pluck to avoid n+1 (assuming caller preloading)
|
||||
tags = post.topic.tags.map(&:name) - hidden_tags
|
||||
row[:tags] = tags.join(", ") if tags.present?
|
||||
end
|
||||
row
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,85 +0,0 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
module DiscourseAi::AiBot::Commands
|
||||
class SearchSettingsCommand < Command
|
||||
class << self
|
||||
def name
|
||||
"search_settings"
|
||||
end
|
||||
|
||||
def desc
|
||||
"Will search through site settings and return top 20 results"
|
||||
end
|
||||
|
||||
def parameters
|
||||
[
|
||||
Parameter.new(
|
||||
name: "query",
|
||||
description:
|
||||
"comma delimited list of settings to search for (e.g. 'setting_1,setting_2')",
|
||||
type: "string",
|
||||
required: true,
|
||||
),
|
||||
]
|
||||
end
|
||||
end
|
||||
|
||||
def result_name
|
||||
"results"
|
||||
end
|
||||
|
||||
def description_args
|
||||
{ count: @last_num_results || 0, query: @last_query || "" }
|
||||
end
|
||||
|
||||
INCLUDE_DESCRIPTIONS_MAX_LENGTH = 10
|
||||
MAX_RESULTS = 200
|
||||
|
||||
def process(query:)
|
||||
@last_query = query
|
||||
@last_num_results = 0
|
||||
|
||||
terms = query.split(",").map(&:strip).map(&:downcase).reject(&:blank?)
|
||||
|
||||
found =
|
||||
SiteSetting.all_settings.filter do |setting|
|
||||
name = setting[:setting].to_s.downcase
|
||||
description = setting[:description].to_s.downcase
|
||||
plugin = setting[:plugin].to_s.downcase
|
||||
|
||||
search_string = "#{name} #{description} #{plugin}"
|
||||
|
||||
terms.any? { |term| search_string.include?(term) }
|
||||
end
|
||||
|
||||
if found.blank?
|
||||
{
|
||||
args: {
|
||||
query: query,
|
||||
},
|
||||
rows: [],
|
||||
instruction: "no settings matched #{query}, expand your search",
|
||||
}
|
||||
else
|
||||
include_descriptions = false
|
||||
|
||||
if found.length > MAX_RESULTS
|
||||
found = found[0..MAX_RESULTS]
|
||||
elsif found.length < INCLUDE_DESCRIPTIONS_MAX_LENGTH
|
||||
include_descriptions = true
|
||||
end
|
||||
|
||||
@last_num_results = found.length
|
||||
|
||||
format_results(found, args: { query: query }) do |setting|
|
||||
result = { name: setting[:setting] }
|
||||
if include_descriptions
|
||||
result[:description] = setting[:description]
|
||||
result[:plugin] = setting[:plugin]
|
||||
end
|
||||
result
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,154 +0,0 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
module DiscourseAi::AiBot::Commands
|
||||
MAX_CONTEXT_TOKENS = 2000
|
||||
|
||||
class SettingContextCommand < Command
|
||||
def self.rg_installed?
|
||||
if defined?(@rg_installed)
|
||||
@rg_installed
|
||||
else
|
||||
@rg_installed =
|
||||
begin
|
||||
Discourse::Utils.execute_command("which", "rg")
|
||||
true
|
||||
rescue Discourse::Utils::CommandError
|
||||
false
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
class << self
|
||||
def name
|
||||
"setting_context"
|
||||
end
|
||||
|
||||
def desc
|
||||
"Will provide you with full context regarding a particular site setting in Discourse"
|
||||
end
|
||||
|
||||
def parameters
|
||||
[
|
||||
Parameter.new(
|
||||
name: "setting_name",
|
||||
description: "The name of the site setting we need context for",
|
||||
type: "string",
|
||||
required: true,
|
||||
),
|
||||
]
|
||||
end
|
||||
end
|
||||
|
||||
def result_name
|
||||
"context"
|
||||
end
|
||||
|
||||
def description_args
|
||||
{ setting_name: @setting_name }
|
||||
end
|
||||
|
||||
CODE_FILE_EXTENSIONS = "rb,js,gjs,hbs"
|
||||
|
||||
def process(setting_name:)
|
||||
if !self.class.rg_installed?
|
||||
return(
|
||||
{
|
||||
setting_name: setting_name,
|
||||
context: "This command requires the rg command line tool to be installed on the server",
|
||||
}
|
||||
)
|
||||
end
|
||||
|
||||
@setting_name = setting_name
|
||||
if !SiteSetting.has_setting?(setting_name)
|
||||
{ setting_name: setting_name, context: "This setting does not exist" }
|
||||
else
|
||||
description = SiteSetting.description(setting_name)
|
||||
result = +"# #{setting_name}\n#{description}\n\n"
|
||||
|
||||
setting_info =
|
||||
find_setting_info(setting_name, [Rails.root.join("config", "site_settings.yml").to_s])
|
||||
if !setting_info
|
||||
setting_info =
|
||||
find_setting_info(setting_name, Dir[Rails.root.join("plugins/**/settings.yml")])
|
||||
end
|
||||
|
||||
result << setting_info
|
||||
result << "\n\n"
|
||||
|
||||
%w[lib app plugins].each do |dir|
|
||||
path = Rails.root.join(dir).to_s
|
||||
result << Discourse::Utils.execute_command(
|
||||
"rg",
|
||||
setting_name,
|
||||
path,
|
||||
"-g",
|
||||
"!**/spec/**",
|
||||
"-g",
|
||||
"!**/dist/**",
|
||||
"-g",
|
||||
"*.{#{CODE_FILE_EXTENSIONS}}",
|
||||
"-C",
|
||||
"10",
|
||||
"--color",
|
||||
"never",
|
||||
"--heading",
|
||||
"--no-ignore",
|
||||
chdir: path,
|
||||
success_status_codes: [0, 1],
|
||||
)
|
||||
end
|
||||
|
||||
result.gsub!(/^#{Regexp.escape(Rails.root.to_s)}/, "")
|
||||
|
||||
result = tokenizer.truncate(result, MAX_CONTEXT_TOKENS)
|
||||
|
||||
{ setting_name: setting_name, context: result }
|
||||
end
|
||||
end
|
||||
|
||||
def find_setting_info(name, paths)
|
||||
path, result = nil
|
||||
|
||||
paths.each do |search_path|
|
||||
result =
|
||||
Discourse::Utils.execute_command(
|
||||
"rg",
|
||||
name,
|
||||
search_path,
|
||||
"-g",
|
||||
"*.{#{CODE_FILE_EXTENSIONS}}",
|
||||
"-A",
|
||||
"10",
|
||||
"--color",
|
||||
"never",
|
||||
"--heading",
|
||||
success_status_codes: [0, 1],
|
||||
)
|
||||
if !result.blank?
|
||||
path = search_path
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
if result.blank?
|
||||
nil
|
||||
else
|
||||
rows = result.split("\n")
|
||||
leading_spaces = rows[0].match(/^\s*/)[0].length
|
||||
|
||||
filtered = []
|
||||
|
||||
rows.each do |row|
|
||||
if !filtered.blank?
|
||||
break if row.match(/^\s*/)[0].length <= leading_spaces
|
||||
end
|
||||
filtered << row
|
||||
end
|
||||
|
||||
filtered.unshift("#{path}")
|
||||
filtered.join("\n")
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,184 +0,0 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
module DiscourseAi::AiBot::Commands
|
||||
class SummarizeCommand < Command
|
||||
class << self
|
||||
def name
|
||||
"summarize"
|
||||
end
|
||||
|
||||
def desc
|
||||
"Will summarize a topic attempting to answer question in guidance"
|
||||
end
|
||||
|
||||
def parameters
|
||||
[
|
||||
Parameter.new(
|
||||
name: "topic_id",
|
||||
description: "The discourse topic id to summarize",
|
||||
type: "integer",
|
||||
required: true,
|
||||
),
|
||||
Parameter.new(
|
||||
name: "guidance",
|
||||
description: "Special guidance on how to summarize the topic",
|
||||
type: "string",
|
||||
),
|
||||
]
|
||||
end
|
||||
end
|
||||
|
||||
def result_name
|
||||
"summary"
|
||||
end
|
||||
|
||||
def standalone?
|
||||
true
|
||||
end
|
||||
|
||||
def low_cost?
|
||||
true
|
||||
end
|
||||
|
||||
def description_args
|
||||
{ url: "#{Discourse.base_path}/t/-/#{@last_topic_id}", title: @last_topic_title || "" }
|
||||
end
|
||||
|
||||
def process(topic_id:, guidance: nil)
|
||||
@last_topic_id = topic_id
|
||||
|
||||
topic_id = topic_id.to_i
|
||||
topic = nil
|
||||
if topic_id > 0
|
||||
topic = Topic.find_by(id: topic_id)
|
||||
topic = nil if !topic || !Guardian.new.can_see?(topic)
|
||||
end
|
||||
|
||||
@last_summary = nil
|
||||
|
||||
if topic
|
||||
@last_topic_title = topic.title
|
||||
|
||||
posts =
|
||||
Post
|
||||
.where(topic_id: topic.id)
|
||||
.where("post_type in (?)", [Post.types[:regular], Post.types[:small_action]])
|
||||
.where("not hidden")
|
||||
.order(:post_number)
|
||||
|
||||
columns = ["posts.id", :post_number, :raw, :username]
|
||||
|
||||
current_post_numbers = posts.limit(5).pluck(:post_number)
|
||||
current_post_numbers += posts.reorder("posts.score desc").limit(50).pluck(:post_number)
|
||||
current_post_numbers += posts.reorder("post_number desc").limit(5).pluck(:post_number)
|
||||
|
||||
data =
|
||||
Post
|
||||
.where(topic_id: topic.id)
|
||||
.joins(:user)
|
||||
.where("post_number in (?)", current_post_numbers)
|
||||
.order(:post_number)
|
||||
.pluck(*columns)
|
||||
|
||||
@last_summary = summarize(data, guidance, topic)
|
||||
end
|
||||
|
||||
if !@last_summary
|
||||
"Say: No topic found!"
|
||||
else
|
||||
"Topic summarized"
|
||||
end
|
||||
end
|
||||
|
||||
def custom_raw
|
||||
@last_summary || I18n.t("discourse_ai.ai_bot.topic_not_found")
|
||||
end
|
||||
|
||||
def chain_next_response
|
||||
false
|
||||
end
|
||||
|
||||
def summarize(data, guidance, topic)
|
||||
text = +""
|
||||
data.each do |id, post_number, raw, username|
|
||||
text << "(#{post_number} #{username} said: #{raw}"
|
||||
end
|
||||
|
||||
summaries = []
|
||||
current_section = +""
|
||||
split = []
|
||||
|
||||
text
|
||||
.split(/\s+/)
|
||||
.each_slice(20) do |slice|
|
||||
current_section << " "
|
||||
current_section << slice.join(" ")
|
||||
|
||||
# somehow any more will get closer to limits
|
||||
if bot.tokenize(current_section).length > 2500
|
||||
split << current_section
|
||||
current_section = +""
|
||||
end
|
||||
end
|
||||
|
||||
split << current_section if current_section.present?
|
||||
|
||||
split = split[0..3] + split[-3..-1] if split.length > 5
|
||||
|
||||
split.each do |section|
|
||||
# TODO progress meter
|
||||
summary =
|
||||
generate_gpt_summary(
|
||||
section,
|
||||
topic: topic,
|
||||
context: "Guidance: #{guidance}\nYou are summarizing the topic: #{topic.title}",
|
||||
)
|
||||
summaries << summary
|
||||
end
|
||||
|
||||
if summaries.length > 1
|
||||
messages = []
|
||||
messages << { role: "system", content: "You are a helpful bot" }
|
||||
messages << {
|
||||
role: "user",
|
||||
content:
|
||||
"concatenated the disjoint summaries, creating a cohesive narrative:\n#{summaries.join("\n")}}",
|
||||
}
|
||||
bot.submit_prompt(messages, temperature: 0.6, max_tokens: 500, prefer_low_cost: true).dig(
|
||||
:choices,
|
||||
0,
|
||||
:message,
|
||||
:content,
|
||||
)
|
||||
else
|
||||
summaries.first
|
||||
end
|
||||
end
|
||||
|
||||
def generate_gpt_summary(text, topic:, context: nil, length: nil)
|
||||
length ||= 400
|
||||
|
||||
prompt = <<~TEXT
|
||||
#{context}
|
||||
Summarize the following in #{length} words:
|
||||
|
||||
#{text}
|
||||
TEXT
|
||||
|
||||
system_prompt = <<~TEXT
|
||||
You are a summarization bot.
|
||||
You effectively summarise any text.
|
||||
You condense it into a shorter version.
|
||||
You understand and generate Discourse forum markdown.
|
||||
Try generating links as well the format is #{topic.url}/POST_NUMBER. eg: [ref](#{topic.url}/77)
|
||||
TEXT
|
||||
|
||||
messages = [{ role: "system", content: system_prompt }]
|
||||
messages << { role: "user", content: prompt }
|
||||
|
||||
result =
|
||||
bot.submit_prompt(messages, temperature: 0.6, max_tokens: length, prefer_low_cost: true)
|
||||
result.dig(:choices, 0, :message, :content)
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,42 +0,0 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
module DiscourseAi::AiBot::Commands
|
||||
class TagsCommand < Command
|
||||
class << self
|
||||
def name
|
||||
"tags"
|
||||
end
|
||||
|
||||
def desc
|
||||
"Will list the 100 most popular tags on the current discourse instance"
|
||||
end
|
||||
|
||||
def parameters
|
||||
[]
|
||||
end
|
||||
end
|
||||
|
||||
def result_name
|
||||
"results"
|
||||
end
|
||||
|
||||
def description_args
|
||||
{ count: @last_count || 0 }
|
||||
end
|
||||
|
||||
def process
|
||||
column_names = { name: "Name", public_topic_count: "Topic Count" }
|
||||
|
||||
tags =
|
||||
Tag
|
||||
.where("public_topic_count > 0")
|
||||
.order(public_topic_count: :desc)
|
||||
.limit(100)
|
||||
.pluck(*column_names.keys)
|
||||
|
||||
@last_count = tags.length
|
||||
|
||||
format_results(tags, column_names.values)
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,49 +0,0 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
module DiscourseAi::AiBot::Commands
|
||||
class TimeCommand < Command
|
||||
class << self
|
||||
def name
|
||||
"time"
|
||||
end
|
||||
|
||||
def desc
|
||||
"Will generate the time in a timezone"
|
||||
end
|
||||
|
||||
def parameters
|
||||
[
|
||||
Parameter.new(
|
||||
name: "timezone",
|
||||
description: "ALWAYS supply a Ruby compatible timezone",
|
||||
type: "string",
|
||||
required: true,
|
||||
),
|
||||
]
|
||||
end
|
||||
end
|
||||
|
||||
def result_name
|
||||
"time"
|
||||
end
|
||||
|
||||
def description_args
|
||||
{ timezone: @last_timezone, time: @last_time }
|
||||
end
|
||||
|
||||
def process(timezone:)
|
||||
time =
|
||||
begin
|
||||
Time.now.in_time_zone(timezone)
|
||||
rescue StandardError
|
||||
nil
|
||||
end
|
||||
time = Time.now if !time
|
||||
|
||||
@last_timezone = timezone
|
||||
@last_time = time.to_s
|
||||
|
||||
{ args: { timezone: timezone }, time: time.to_s }
|
||||
end
|
||||
end
|
||||
end
|
|
@ -50,7 +50,7 @@ module DiscourseAi
|
|||
scope.user.in_any_groups?(SiteSetting.ai_bot_allowed_groups_map)
|
||||
end,
|
||||
) do
|
||||
DiscourseAi::AiBot::Personas
|
||||
DiscourseAi::AiBot::Personas::Persona
|
||||
.all(user: scope.user)
|
||||
.map do |persona|
|
||||
{ id: persona.id, name: persona.name, description: persona.description }
|
||||
|
@ -92,32 +92,19 @@ module DiscourseAi
|
|||
include_condition: -> { SiteSetting.ai_bot_enabled && object.topic.private_message? },
|
||||
) do
|
||||
id = topic.custom_fields["ai_persona_id"]
|
||||
name = DiscourseAi::AiBot::Personas.find_by(user: scope.user, id: id.to_i)&.name if id
|
||||
name =
|
||||
DiscourseAi::AiBot::Personas::Persona.find_by(user: scope.user, id: id.to_i)&.name if id
|
||||
name || topic.custom_fields["ai_persona"]
|
||||
end
|
||||
|
||||
plugin.on(:post_created) do |post|
|
||||
bot_ids = BOTS.map(&:first)
|
||||
|
||||
if post.post_type == Post.types[:regular] && post.topic.private_message? &&
|
||||
!bot_ids.include?(post.user_id)
|
||||
if (SiteSetting.ai_bot_allowed_groups_map & post.user.group_ids).present?
|
||||
bot_id = post.topic.topic_allowed_users.where(user_id: bot_ids).first&.user_id
|
||||
|
||||
if bot_id
|
||||
if post.post_number == 1
|
||||
post.topic.custom_fields[REQUIRE_TITLE_UPDATE] = true
|
||||
post.topic.save_custom_fields
|
||||
end
|
||||
::Jobs.enqueue(:create_ai_reply, post_id: post.id, bot_user_id: bot_id)
|
||||
::Jobs.enqueue_in(
|
||||
5.minutes,
|
||||
:update_ai_bot_pm_title,
|
||||
post_id: post.id,
|
||||
bot_user_id: bot_id,
|
||||
)
|
||||
end
|
||||
end
|
||||
# Don't schedule a reply for a bot reply.
|
||||
if !bot_ids.include?(post.user_id)
|
||||
bot_user = post.topic.topic_allowed_users.where(user_id: bot_ids).first&.user
|
||||
bot = DiscourseAi::AiBot::Bot.as(bot_user)
|
||||
DiscourseAi::AiBot::Playground.new(bot).update_playground_with(post)
|
||||
end
|
||||
end
|
||||
|
||||
|
|
|
@ -1,157 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module AiBot
|
||||
class OpenAiBot < Bot
|
||||
def self.can_reply_as?(bot_user)
|
||||
open_ai_bot_ids = [
|
||||
DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID,
|
||||
DiscourseAi::AiBot::EntryPoint::GPT4_ID,
|
||||
DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID,
|
||||
]
|
||||
|
||||
open_ai_bot_ids.include?(bot_user.id)
|
||||
end
|
||||
|
||||
def prompt_limit(allow_commands:)
|
||||
# provide a buffer of 120 tokens - our function counting is not
|
||||
# 100% accurate and getting numbers to align exactly is very hard
|
||||
buffer = reply_params[:max_tokens] + 50
|
||||
|
||||
if allow_commands
|
||||
# note this is about 100 tokens over, OpenAI have a more optimal representation
|
||||
@function_size ||= tokenize(available_functions.to_json.to_s).length
|
||||
buffer += @function_size
|
||||
end
|
||||
|
||||
if bot_user.id == DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID
|
||||
150_000 - buffer
|
||||
elsif bot_user.id == DiscourseAi::AiBot::EntryPoint::GPT4_ID
|
||||
8192 - buffer
|
||||
else
|
||||
16_384 - buffer
|
||||
end
|
||||
end
|
||||
|
||||
def reply_params
|
||||
# technically we could allow GPT-3.5 16k more tokens
|
||||
# but lets just keep it here for now
|
||||
{ temperature: 0.4, top_p: 0.9, max_tokens: 2500 }
|
||||
end
|
||||
|
||||
def extra_tokens_per_message
|
||||
# open ai defines about 4 tokens per message of overhead
|
||||
4
|
||||
end
|
||||
|
||||
def submit_prompt(
|
||||
prompt,
|
||||
prefer_low_cost: false,
|
||||
post: nil,
|
||||
temperature: nil,
|
||||
top_p: nil,
|
||||
max_tokens: nil,
|
||||
&blk
|
||||
)
|
||||
params =
|
||||
reply_params.merge(
|
||||
temperature: temperature,
|
||||
top_p: top_p,
|
||||
max_tokens: max_tokens,
|
||||
) { |key, old_value, new_value| new_value.nil? ? old_value : new_value }
|
||||
|
||||
model = model_for(low_cost: prefer_low_cost)
|
||||
|
||||
params[:functions] = available_functions if available_functions.present?
|
||||
|
||||
DiscourseAi::Inference::OpenAiCompletions.perform!(
|
||||
prompt,
|
||||
model,
|
||||
**params,
|
||||
post: post,
|
||||
&blk
|
||||
)
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||
end
|
||||
|
||||
def model_for(low_cost: false)
|
||||
if low_cost || bot_user.id == DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
|
||||
"gpt-3.5-turbo-16k"
|
||||
elsif bot_user.id == DiscourseAi::AiBot::EntryPoint::GPT4_ID
|
||||
"gpt-4"
|
||||
else
|
||||
# not quite released yet, once released we should replace with
|
||||
# gpt-4-turbo
|
||||
"gpt-4-1106-preview"
|
||||
end
|
||||
end
|
||||
|
||||
def clean_username(username)
|
||||
if username.match?(/\0[a-zA-Z0-9_-]{1,64}\z/)
|
||||
username
|
||||
else
|
||||
# not the best in the world, but this is what we have to work with
|
||||
# if sites enable unicode usernames this can get messy
|
||||
username.gsub(/[^a-zA-Z0-9_-]/, "_")[0..63]
|
||||
end
|
||||
end
|
||||
|
||||
def include_function_instructions_in_system_prompt?
|
||||
# open ai uses a bespoke system for function calls
|
||||
false
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def populate_functions(partial:, reply:, functions:, done:, current_delta:)
|
||||
return if !partial
|
||||
fn = partial.dig(:choices, 0, :delta, :function_call)
|
||||
if fn
|
||||
functions.add_function(fn[:name]) if fn[:name].present?
|
||||
functions.add_argument_fragment(fn[:arguments]) if !fn[:arguments].nil?
|
||||
functions.custom = true
|
||||
end
|
||||
end
|
||||
|
||||
def build_message(poster_username, content, function: false, system: false)
|
||||
is_bot = poster_username == bot_user.username
|
||||
|
||||
if function
|
||||
role = "function"
|
||||
elsif system
|
||||
role = "system"
|
||||
else
|
||||
role = is_bot ? "assistant" : "user"
|
||||
end
|
||||
|
||||
result = { role: role, content: content }
|
||||
|
||||
if function
|
||||
result[:name] = poster_username
|
||||
elsif !system && poster_username != bot_user.username && poster_username.present?
|
||||
# Open AI restrict name to 64 chars and only A-Za-z._ (work around)
|
||||
result[:name] = clean_username(poster_username)
|
||||
end
|
||||
|
||||
result
|
||||
end
|
||||
|
||||
def get_delta(partial, _context)
|
||||
partial.dig(:choices, 0, :delta, :content).to_s
|
||||
end
|
||||
|
||||
def get_updated_title(prompt)
|
||||
DiscourseAi::Inference::OpenAiCompletions.perform!(
|
||||
prompt,
|
||||
model_for,
|
||||
temperature: 0.7,
|
||||
top_p: 0.9,
|
||||
max_tokens: 40,
|
||||
).dig(:choices, 0, :message, :content)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,46 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module AiBot
|
||||
module Personas
|
||||
def self.system_personas
|
||||
@system_personas ||= {
|
||||
Personas::General => -1,
|
||||
Personas::SqlHelper => -2,
|
||||
Personas::Artist => -3,
|
||||
Personas::SettingsExplorer => -4,
|
||||
Personas::Researcher => -5,
|
||||
Personas::Creative => -6,
|
||||
Personas::DallE3 => -7,
|
||||
}
|
||||
end
|
||||
|
||||
def self.system_personas_by_id
|
||||
@system_personas_by_id ||= system_personas.invert
|
||||
end
|
||||
|
||||
def self.all(user:)
|
||||
# this needs to be dynamic cause site settings may change
|
||||
all_available_commands = Persona.all_available_commands
|
||||
|
||||
AiPersona.all_personas.filter do |persona|
|
||||
next false if !user.in_any_groups?(persona.allowed_group_ids)
|
||||
|
||||
if persona.system
|
||||
instance = persona.new
|
||||
(
|
||||
instance.required_commands == [] ||
|
||||
(instance.required_commands - all_available_commands).empty?
|
||||
)
|
||||
else
|
||||
true
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
def self.find_by(id: nil, name: nil, user:)
|
||||
all(user: user).find { |persona| persona.id == id || persona.name == name }
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -4,12 +4,12 @@ module DiscourseAi
|
|||
module AiBot
|
||||
module Personas
|
||||
class Artist < Persona
|
||||
def commands
|
||||
[Commands::ImageCommand]
|
||||
def tools
|
||||
[Tools::Image]
|
||||
end
|
||||
|
||||
def required_commands
|
||||
[Commands::ImageCommand]
|
||||
def required_tools
|
||||
[Tools::Image]
|
||||
end
|
||||
|
||||
def system_prompt
|
||||
|
|
|
@ -4,7 +4,7 @@ module DiscourseAi
|
|||
module AiBot
|
||||
module Personas
|
||||
class Creative < Persona
|
||||
def commands
|
||||
def tools
|
||||
[]
|
||||
end
|
||||
|
||||
|
|
|
@ -4,12 +4,12 @@ module DiscourseAi
|
|||
module AiBot
|
||||
module Personas
|
||||
class DallE3 < Persona
|
||||
def commands
|
||||
[Commands::DallECommand]
|
||||
def tools
|
||||
[Tools::DallE]
|
||||
end
|
||||
|
||||
def required_commands
|
||||
[Commands::DallECommand]
|
||||
def required_tools
|
||||
[Tools::DallE]
|
||||
end
|
||||
|
||||
def system_prompt
|
||||
|
|
|
@ -4,15 +4,15 @@ module DiscourseAi
|
|||
module AiBot
|
||||
module Personas
|
||||
class General < Persona
|
||||
def commands
|
||||
def tools
|
||||
[
|
||||
Commands::SearchCommand,
|
||||
Commands::GoogleCommand,
|
||||
Commands::ImageCommand,
|
||||
Commands::ReadCommand,
|
||||
Commands::ImageCommand,
|
||||
Commands::CategoriesCommand,
|
||||
Commands::TagsCommand,
|
||||
Tools::Search,
|
||||
Tools::Google,
|
||||
Tools::Image,
|
||||
Tools::Read,
|
||||
Tools::Image,
|
||||
Tools::ListCategories,
|
||||
Tools::ListTags,
|
||||
]
|
||||
end
|
||||
|
||||
|
|
|
@ -4,19 +4,84 @@ module DiscourseAi
|
|||
module AiBot
|
||||
module Personas
|
||||
class Persona
|
||||
def self.name
|
||||
class << self
|
||||
def system_personas
|
||||
@system_personas ||= {
|
||||
Personas::General => -1,
|
||||
Personas::SqlHelper => -2,
|
||||
Personas::Artist => -3,
|
||||
Personas::SettingsExplorer => -4,
|
||||
Personas::Researcher => -5,
|
||||
Personas::Creative => -6,
|
||||
Personas::DallE3 => -7,
|
||||
}
|
||||
end
|
||||
|
||||
def system_personas_by_id
|
||||
@system_personas_by_id ||= system_personas.invert
|
||||
end
|
||||
|
||||
def all(user:)
|
||||
# listing tools has to be dynamic cause site settings may change
|
||||
|
||||
AiPersona.all_personas.filter do |persona|
|
||||
next false if !user.in_any_groups?(persona.allowed_group_ids)
|
||||
|
||||
if persona.system
|
||||
instance = persona.new
|
||||
(
|
||||
instance.required_tools == [] ||
|
||||
(instance.required_tools - all_available_tools).empty?
|
||||
)
|
||||
else
|
||||
true
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
def find_by(id: nil, name: nil, user:)
|
||||
all(user: user).find { |persona| persona.id == id || persona.name == name }
|
||||
end
|
||||
|
||||
def name
|
||||
I18n.t("discourse_ai.ai_bot.personas.#{to_s.demodulize.underscore}.name")
|
||||
end
|
||||
|
||||
def self.description
|
||||
def description
|
||||
I18n.t("discourse_ai.ai_bot.personas.#{to_s.demodulize.underscore}.description")
|
||||
end
|
||||
|
||||
def commands
|
||||
def all_available_tools
|
||||
tools = [
|
||||
Tools::ListCategories,
|
||||
Tools::Time,
|
||||
Tools::Search,
|
||||
Tools::Summarize,
|
||||
Tools::Read,
|
||||
Tools::DbSchema,
|
||||
Tools::SearchSettings,
|
||||
Tools::Summarize,
|
||||
Tools::SettingContext,
|
||||
]
|
||||
|
||||
tools << Tools::ListTags if SiteSetting.tagging_enabled
|
||||
tools << Tools::Image if SiteSetting.ai_stability_api_key.present?
|
||||
|
||||
tools << Tools::DallE if SiteSetting.ai_openai_api_key.present?
|
||||
if SiteSetting.ai_google_custom_search_api_key.present? &&
|
||||
SiteSetting.ai_google_custom_search_cx.present?
|
||||
tools << Tools::Google
|
||||
end
|
||||
|
||||
tools
|
||||
end
|
||||
end
|
||||
|
||||
def tools
|
||||
[]
|
||||
end
|
||||
|
||||
def required_commands
|
||||
def required_tools
|
||||
[]
|
||||
end
|
||||
|
||||
|
@ -24,104 +89,55 @@ module DiscourseAi
|
|||
{}
|
||||
end
|
||||
|
||||
def render_commands(render_function_instructions:)
|
||||
return +"" if available_commands.empty?
|
||||
|
||||
result = +""
|
||||
if render_function_instructions
|
||||
result << "\n"
|
||||
result << function_list.system_prompt
|
||||
result << "\n"
|
||||
end
|
||||
result << available_commands.map(&:custom_system_message).compact.join("\n")
|
||||
result
|
||||
def available_tools
|
||||
self.class.all_available_tools.filter { |tool| tools.include?(tool) }
|
||||
end
|
||||
|
||||
def render_system_prompt(
|
||||
topic: nil,
|
||||
render_function_instructions: true,
|
||||
allow_commands: true
|
||||
)
|
||||
substitutions = {
|
||||
site_url: Discourse.base_url,
|
||||
site_title: SiteSetting.title,
|
||||
site_description: SiteSetting.site_description,
|
||||
time: Time.zone.now,
|
||||
}
|
||||
|
||||
substitutions[:participants] = topic.allowed_users.map(&:username).join(", ") if topic
|
||||
|
||||
prompt =
|
||||
def craft_prompt(context)
|
||||
system_insts =
|
||||
system_prompt.gsub(/\{(\w+)\}/) do |match|
|
||||
found = substitutions[match[1..-2].to_sym]
|
||||
found = context[match[1..-2].to_sym]
|
||||
found.nil? ? match : found.to_s
|
||||
end
|
||||
|
||||
if allow_commands
|
||||
prompt += render_commands(render_function_instructions: render_function_instructions)
|
||||
end
|
||||
insts = <<~TEXT
|
||||
#{system_insts}
|
||||
#{available_tools.map(&:custom_system_message).compact_blank.join("\n")}
|
||||
TEXT
|
||||
|
||||
prompt
|
||||
end
|
||||
|
||||
def available_commands
|
||||
return @available_commands if @available_commands
|
||||
@available_commands = all_available_commands.filter { |cmd| commands.include?(cmd) }
|
||||
end
|
||||
|
||||
def available_functions
|
||||
# note if defined? can be a problem in test
|
||||
# this can never be nil so it is safe
|
||||
return @available_functions if @available_functions
|
||||
|
||||
functions = []
|
||||
|
||||
functions =
|
||||
available_commands.map do |command|
|
||||
function =
|
||||
DiscourseAi::Inference::Function.new(name: command.name, description: command.desc)
|
||||
command.parameters.each { |parameter| function.add_parameter(parameter) }
|
||||
function
|
||||
end
|
||||
|
||||
@available_functions = functions
|
||||
end
|
||||
|
||||
def function_list
|
||||
return @function_list if @function_list
|
||||
|
||||
@function_list = DiscourseAi::Inference::FunctionList.new
|
||||
available_functions.each { |function| @function_list << function }
|
||||
@function_list
|
||||
end
|
||||
|
||||
def self.all_available_commands
|
||||
all_commands = [
|
||||
Commands::CategoriesCommand,
|
||||
Commands::TimeCommand,
|
||||
Commands::SearchCommand,
|
||||
Commands::SummarizeCommand,
|
||||
Commands::ReadCommand,
|
||||
Commands::DbSchemaCommand,
|
||||
Commands::SearchSettingsCommand,
|
||||
Commands::SummarizeCommand,
|
||||
Commands::SettingContextCommand,
|
||||
{ insts: insts }.tap do |prompt|
|
||||
prompt[:tools] = available_tools.map(&:signature) if available_tools
|
||||
prompt[:conversation_context] = context[:conversation_context] if context[
|
||||
:conversation_context
|
||||
]
|
||||
|
||||
all_commands << Commands::TagsCommand if SiteSetting.tagging_enabled
|
||||
all_commands << Commands::ImageCommand if SiteSetting.ai_stability_api_key.present?
|
||||
|
||||
all_commands << Commands::DallECommand if SiteSetting.ai_openai_api_key.present?
|
||||
if SiteSetting.ai_google_custom_search_api_key.present? &&
|
||||
SiteSetting.ai_google_custom_search_cx.present?
|
||||
all_commands << Commands::GoogleCommand
|
||||
end
|
||||
end
|
||||
|
||||
all_commands
|
||||
def find_tool(partial)
|
||||
parsed_function = Nokogiri::HTML5.fragment(partial)
|
||||
function_id = parsed_function.at("tool_id")&.text
|
||||
function_name = parsed_function.at("tool_name")&.text
|
||||
return false if function_name.nil?
|
||||
|
||||
tool_klass = available_tools.find { |c| c.signature.dig(:name) == function_name }
|
||||
return false if tool_klass.nil?
|
||||
|
||||
arguments =
|
||||
tool_klass.signature[:parameters]
|
||||
.to_a
|
||||
.reduce({}) do |memo, p|
|
||||
argument = parsed_function.at(p[:name])&.text
|
||||
next(memo) unless argument
|
||||
|
||||
memo[p[:name].to_sym] = argument
|
||||
memo
|
||||
end
|
||||
|
||||
def all_available_commands
|
||||
@cmds ||= self.class.all_available_commands
|
||||
tool_klass.new(
|
||||
arguments,
|
||||
tool_call_id: function_id,
|
||||
persona_options: options[tool_klass].to_h,
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -4,12 +4,12 @@ module DiscourseAi
|
|||
module AiBot
|
||||
module Personas
|
||||
class Researcher < Persona
|
||||
def commands
|
||||
[Commands::GoogleCommand]
|
||||
def tools
|
||||
[Tools::Google]
|
||||
end
|
||||
|
||||
def required_commands
|
||||
[Commands::GoogleCommand]
|
||||
def required_tools
|
||||
[Tools::Google]
|
||||
end
|
||||
|
||||
def system_prompt
|
||||
|
|
|
@ -4,15 +4,8 @@ module DiscourseAi
|
|||
module AiBot
|
||||
module Personas
|
||||
class SettingsExplorer < Persona
|
||||
def commands
|
||||
all_available_commands
|
||||
end
|
||||
|
||||
def all_available_commands
|
||||
[
|
||||
DiscourseAi::AiBot::Commands::SettingContextCommand,
|
||||
DiscourseAi::AiBot::Commands::SearchSettingsCommand,
|
||||
]
|
||||
def tools
|
||||
[Tools::SettingContext, Tools::SearchSettings]
|
||||
end
|
||||
|
||||
def system_prompt
|
||||
|
|
|
@ -27,12 +27,8 @@ module DiscourseAi
|
|||
@schema = schema
|
||||
end
|
||||
|
||||
def commands
|
||||
all_available_commands
|
||||
end
|
||||
|
||||
def all_available_commands
|
||||
[DiscourseAi::AiBot::Commands::DbSchemaCommand]
|
||||
def tools
|
||||
[Tools::DbSchema]
|
||||
end
|
||||
|
||||
def system_prompt
|
||||
|
|
|
@ -0,0 +1,228 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module AiBot
|
||||
class Playground
|
||||
# An abstraction to manage the bot and topic interactions.
|
||||
# The bot will take care of completions while this class updates the topic title
|
||||
# and stream replies.
|
||||
|
||||
REQUIRE_TITLE_UPDATE = "discourse-ai-title-update"
|
||||
|
||||
def initialize(bot)
|
||||
@bot = bot
|
||||
end
|
||||
|
||||
def update_playground_with(post)
|
||||
if can_attach?(post) && bot.bot_user
|
||||
schedule_playground_titling(post, bot.bot_user)
|
||||
schedule_bot_reply(post, bot.bot_user)
|
||||
end
|
||||
end
|
||||
|
||||
def conversation_context(post)
|
||||
# Pay attention to the `post_number <= ?` here.
|
||||
# We want to inject the last post as context because they are translated differently.
|
||||
context =
|
||||
post
|
||||
.topic
|
||||
.posts
|
||||
.includes(:user)
|
||||
.joins("LEFT JOIN post_custom_prompts ON post_custom_prompts.post_id = posts.id")
|
||||
.where("post_number <= ?", post.post_number)
|
||||
.order("post_number desc")
|
||||
.where("post_type = ?", Post.types[:regular])
|
||||
.limit(50)
|
||||
.pluck(:raw, :username, "post_custom_prompts.custom_prompt")
|
||||
|
||||
result = []
|
||||
|
||||
first = true
|
||||
context.each do |raw, username, custom_prompt|
|
||||
custom_prompt_translation =
|
||||
Proc.new do |message|
|
||||
# We can't keep backwards-compatibility for stored functions.
|
||||
# Tool syntax requires a tool_call_id which we don't have.
|
||||
if message[2] != "function"
|
||||
custom_context = {
|
||||
content: message[0],
|
||||
type: message[2].present? ? message[2] : "assistant",
|
||||
}
|
||||
|
||||
custom_context[:name] = message[1] if custom_context[:type] != "assistant"
|
||||
|
||||
result << custom_context
|
||||
end
|
||||
end
|
||||
|
||||
if custom_prompt.present?
|
||||
if first
|
||||
custom_prompt.reverse_each(&custom_prompt_translation)
|
||||
first = false
|
||||
else
|
||||
tool_call_and_tool = custom_prompt.first(2)
|
||||
tool_call_and_tool.reverse_each(&custom_prompt_translation)
|
||||
end
|
||||
else
|
||||
context = {
|
||||
content: raw,
|
||||
type: (available_bot_usernames.include?(username) ? "assistant" : "user"),
|
||||
}
|
||||
|
||||
context[:name] = username if context[:type] == "user"
|
||||
|
||||
result << context
|
||||
end
|
||||
end
|
||||
|
||||
result
|
||||
end
|
||||
|
||||
def title_playground(post)
|
||||
context = conversation_context(post)
|
||||
|
||||
bot
|
||||
.get_updated_title(context, post.user)
|
||||
.tap do |new_title|
|
||||
PostRevisor.new(post.topic.first_post, post.topic).revise!(
|
||||
bot.bot_user,
|
||||
title: new_title.sub(/\A"/, "").sub(/"\Z/, ""),
|
||||
)
|
||||
post.topic.custom_fields.delete(DiscourseAi::AiBot::EntryPoint::REQUIRE_TITLE_UPDATE)
|
||||
post.topic.save_custom_fields
|
||||
end
|
||||
end
|
||||
|
||||
def reply_to(post)
|
||||
reply = +""
|
||||
start = Time.now
|
||||
|
||||
context = {
|
||||
site_url: Discourse.base_url,
|
||||
site_title: SiteSetting.title,
|
||||
site_description: SiteSetting.site_description,
|
||||
time: Time.zone.now,
|
||||
participants: post.topic.allowed_users.map(&:username).join(", "),
|
||||
conversation_context: conversation_context(post),
|
||||
user: post.user,
|
||||
}
|
||||
|
||||
reply_post =
|
||||
PostCreator.create!(
|
||||
bot.bot_user,
|
||||
topic_id: post.topic_id,
|
||||
raw: I18n.t("discourse_ai.ai_bot.placeholder_reply"),
|
||||
skip_validations: true,
|
||||
)
|
||||
|
||||
redis_stream_key = "gpt_cancel:#{reply_post.id}"
|
||||
Discourse.redis.setex(redis_stream_key, 60, 1)
|
||||
|
||||
new_custom_prompts =
|
||||
bot.reply(context) do |partial, cancel, placeholder|
|
||||
reply << partial
|
||||
raw = reply.dup
|
||||
raw << "\n\n" << placeholder if placeholder.present?
|
||||
|
||||
if !Discourse.redis.get(redis_stream_key)
|
||||
cancel&.call
|
||||
|
||||
reply_post.update!(raw: reply, cooked: PrettyText.cook(reply))
|
||||
end
|
||||
|
||||
# Minor hack to skip the delay during tests.
|
||||
if placeholder.blank?
|
||||
next if (Time.now - start < 0.5) && !Rails.env.test?
|
||||
start = Time.now
|
||||
end
|
||||
|
||||
Discourse.redis.expire(redis_stream_key, 60)
|
||||
|
||||
publish_update(reply_post, raw: raw)
|
||||
end
|
||||
|
||||
return if reply.blank?
|
||||
|
||||
reply_post.tap do |bot_reply|
|
||||
publish_update(bot_reply, done: true)
|
||||
|
||||
bot_reply.revise(
|
||||
bot.bot_user,
|
||||
{ raw: reply },
|
||||
skip_validations: true,
|
||||
skip_revision: true,
|
||||
)
|
||||
|
||||
bot_reply.post_custom_prompt ||= bot_reply.build_post_custom_prompt(custom_prompt: [])
|
||||
prompt = bot_reply.post_custom_prompt.custom_prompt || []
|
||||
prompt.concat(new_custom_prompts)
|
||||
prompt << [reply, bot.bot_user.username]
|
||||
bot_reply.post_custom_prompt.update!(custom_prompt: prompt)
|
||||
end
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
attr_reader :bot
|
||||
|
||||
def can_attach?(post)
|
||||
return false if bot.bot_user.nil?
|
||||
return false if post.post_type != Post.types[:regular]
|
||||
return false unless post.topic.private_message?
|
||||
return false if (SiteSetting.ai_bot_allowed_groups_map & post.user.group_ids).blank?
|
||||
|
||||
true
|
||||
end
|
||||
|
||||
def schedule_playground_titling(post, bot_user)
|
||||
if post.post_number == 1
|
||||
post.topic.custom_fields[REQUIRE_TITLE_UPDATE] = true
|
||||
post.topic.save_custom_fields
|
||||
end
|
||||
|
||||
::Jobs.enqueue_in(
|
||||
5.minutes,
|
||||
:update_ai_bot_pm_title,
|
||||
post_id: post.id,
|
||||
bot_user_id: bot_user.id,
|
||||
)
|
||||
end
|
||||
|
||||
def schedule_bot_reply(post, bot_user)
|
||||
::Jobs.enqueue(:create_ai_reply, post_id: post.id, bot_user_id: bot_user.id)
|
||||
end
|
||||
|
||||
def context(topic)
|
||||
{
|
||||
site_url: Discourse.base_url,
|
||||
site_title: SiteSetting.title,
|
||||
site_description: SiteSetting.site_description,
|
||||
time: Time.zone.now,
|
||||
participants: topic.allowed_users.map(&:username).join(", "),
|
||||
}
|
||||
end
|
||||
|
||||
def publish_update(bot_reply_post, payload)
|
||||
MessageBus.publish(
|
||||
"discourse-ai/ai-bot/topic/#{bot_reply_post.topic_id}",
|
||||
payload.merge(post_id: bot_reply_post.id, post_number: bot_reply_post.post_number),
|
||||
user_ids: bot_reply_post.topic.allowed_user_ids,
|
||||
)
|
||||
end
|
||||
|
||||
def available_bot_usernames
|
||||
@bot_usernames ||= DiscourseAi::AiBot::EntryPoint::BOTS.map(&:second)
|
||||
end
|
||||
|
||||
def clean_username(username)
|
||||
if username.match?(/\0[a-zA-Z0-9_-]{1,64}\z/)
|
||||
username
|
||||
else
|
||||
# not the best in the world, but this is what we have to work with
|
||||
# if sites enable unicode usernames this can get messy
|
||||
username.gsub(/[^a-zA-Z0-9_-]/, "_")[0..63]
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,125 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module AiBot
|
||||
module Tools
|
||||
class DallE < Tool
|
||||
def self.signature
|
||||
{
|
||||
name: name,
|
||||
description: "Renders images from supplied descriptions",
|
||||
parameters: [
|
||||
{
|
||||
name: "prompts",
|
||||
description:
|
||||
"The prompts used to generate or create or draw the image (5000 chars or less, be creative) up to 4 prompts",
|
||||
type: "array",
|
||||
item_type: "string",
|
||||
required: true,
|
||||
},
|
||||
],
|
||||
}
|
||||
end
|
||||
|
||||
def self.name
|
||||
"dall_e"
|
||||
end
|
||||
|
||||
def prompts
|
||||
parameters[:prompts]
|
||||
end
|
||||
|
||||
def chain_next_response?
|
||||
false
|
||||
end
|
||||
|
||||
def invoke(bot_user, _llm)
|
||||
# max 4 prompts
|
||||
max_prompts = prompts.take(4)
|
||||
progress = +""
|
||||
|
||||
yield(progress)
|
||||
|
||||
results = nil
|
||||
|
||||
# this ensures multisite safety since background threads
|
||||
# generate the images
|
||||
api_key = SiteSetting.ai_openai_api_key
|
||||
api_url = SiteSetting.ai_openai_dall_e_3_url
|
||||
|
||||
threads = []
|
||||
max_prompts.each_with_index do |prompt, index|
|
||||
threads << Thread.new(prompt) do |inner_prompt|
|
||||
attempts = 0
|
||||
begin
|
||||
DiscourseAi::Inference::OpenAiImageGenerator.perform!(
|
||||
inner_prompt,
|
||||
api_key: api_key,
|
||||
api_url: api_url,
|
||||
)
|
||||
rescue => e
|
||||
attempts += 1
|
||||
sleep 2
|
||||
retry if attempts < 3
|
||||
Discourse.warn_exception(
|
||||
e,
|
||||
message: "Failed to generate image for prompt #{prompt}",
|
||||
)
|
||||
nil
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
while true
|
||||
progress << "."
|
||||
yield(progress)
|
||||
break if threads.all? { |t| t.join(2) }
|
||||
end
|
||||
|
||||
results = threads.filter_map(&:value)
|
||||
|
||||
if results.blank?
|
||||
return { prompts: max_prompts, error: "Something went wrong, could not generate image" }
|
||||
end
|
||||
|
||||
uploads = []
|
||||
|
||||
results.each_with_index do |result, index|
|
||||
result[:data].each do |image|
|
||||
Tempfile.create("v1_txt2img_#{index}.png") do |file|
|
||||
file.binmode
|
||||
file.write(Base64.decode64(image[:b64_json]))
|
||||
file.rewind
|
||||
uploads << {
|
||||
prompt: image[:revised_prompt],
|
||||
upload: UploadCreator.new(file, "image.png").create_for(bot_user.id),
|
||||
}
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
self.custom_raw = <<~RAW
|
||||
|
||||
[grid]
|
||||
#{
|
||||
uploads
|
||||
.map do |item|
|
||||
"![#{item[:prompt].gsub(/\|\'\"/, "")}|512x512, 50%](#{item[:upload].short_url})"
|
||||
end
|
||||
.join(" ")
|
||||
}
|
||||
[/grid]
|
||||
RAW
|
||||
|
||||
{ prompts: uploads.map { |item| item[:prompt] } }
|
||||
end
|
||||
|
||||
protected
|
||||
|
||||
def description_args
|
||||
{ prompt: prompts.first }
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,62 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module AiBot
|
||||
module Tools
|
||||
class DbSchema < Tool
|
||||
def self.signature
|
||||
{
|
||||
name: name,
|
||||
description: "Will load schema information for specific tables in the database",
|
||||
parameters: [
|
||||
{
|
||||
name: "tables",
|
||||
description:
|
||||
"list of tables to load schema information for, comma seperated list eg: (users,posts))",
|
||||
type: "string",
|
||||
required: true,
|
||||
},
|
||||
],
|
||||
}
|
||||
end
|
||||
|
||||
def self.name
|
||||
"schema"
|
||||
end
|
||||
|
||||
def tables
|
||||
parameters[:tables]
|
||||
end
|
||||
|
||||
def invoke(_bot_user, _llm)
|
||||
tables_arr = tables.split(",").map(&:strip)
|
||||
|
||||
table_info = {}
|
||||
DB
|
||||
.query(<<~SQL, tables_arr)
|
||||
select table_name, column_name, data_type from information_schema.columns
|
||||
where table_schema = 'public'
|
||||
and table_name in (?)
|
||||
order by table_name
|
||||
SQL
|
||||
.each do |row|
|
||||
(table_info[row.table_name] ||= []) << "#{row.column_name} #{row.data_type}"
|
||||
end
|
||||
|
||||
schema_info =
|
||||
table_info
|
||||
.map { |table_name, columns| "#{table_name}(#{columns.join(",")})" }
|
||||
.join("\n")
|
||||
|
||||
{ schema_info: schema_info, tables: tables }
|
||||
end
|
||||
|
||||
protected
|
||||
|
||||
def description_args
|
||||
{ tables: tables.join(", ") }
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,85 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module AiBot
|
||||
module Tools
|
||||
class Google < Tool
|
||||
def self.signature
|
||||
{
|
||||
name: name,
|
||||
description:
|
||||
"Will search using Google - global internet search (supports all Google search operators)",
|
||||
parameters: [
|
||||
{ name: "query", description: "The search query", type: "string", required: true },
|
||||
],
|
||||
}
|
||||
end
|
||||
|
||||
def self.custom_system_message
|
||||
"You were trained on OLD data, lean on search to get up to date information from the web"
|
||||
end
|
||||
|
||||
def self.name
|
||||
"google"
|
||||
end
|
||||
|
||||
def query
|
||||
parameters[:query].to_s
|
||||
end
|
||||
|
||||
def invoke(bot_user, llm)
|
||||
yield("") # Triggers placeholder update
|
||||
|
||||
api_key = SiteSetting.ai_google_custom_search_api_key
|
||||
cx = SiteSetting.ai_google_custom_search_cx
|
||||
escaped_query = CGI.escape(query)
|
||||
uri =
|
||||
URI(
|
||||
"https://www.googleapis.com/customsearch/v1?key=#{api_key}&cx=#{cx}&q=#{escaped_query}&num=10",
|
||||
)
|
||||
body = Net::HTTP.get(uri)
|
||||
|
||||
parse_search_json(body, escaped_query, llm)
|
||||
end
|
||||
|
||||
attr_reader :results_count
|
||||
|
||||
protected
|
||||
|
||||
def description_args
|
||||
{
|
||||
count: results_count || 0,
|
||||
query: query,
|
||||
url: "https://google.com/search?q=#{CGI.escape(query)}",
|
||||
}
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def minimize_field(result, field, llm, max_tokens: 100)
|
||||
data = result[field]
|
||||
return "" if data.blank?
|
||||
|
||||
llm.tokenizer.truncate(data, max_tokens).squish
|
||||
end
|
||||
|
||||
def parse_search_json(json_data, escaped_query, llm)
|
||||
parsed = JSON.parse(json_data)
|
||||
results = parsed["items"]
|
||||
|
||||
@results_count = parsed.dig("searchInformation", "totalResults").to_i
|
||||
|
||||
format_results(results, args: escaped_query) do |result|
|
||||
{
|
||||
title: minimize_field(result, "title", llm),
|
||||
link: minimize_field(result, "link", llm),
|
||||
snippet: minimize_field(result, "snippet", llm, max_tokens: 120),
|
||||
displayLink: minimize_field(result, "displayLink", llm),
|
||||
formattedUrl: minimize_field(result, "formattedUrl", llm),
|
||||
}
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,144 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module AiBot
|
||||
module Tools
|
||||
class Image < Tool
|
||||
def self.signature
|
||||
{
|
||||
name: name,
|
||||
description:
|
||||
"Renders an image from the description (remove all connector words, keep it to 40 words or less). Despite being a text based bot you can generate images! (when user asks to draw, paint or other synonyms try this)",
|
||||
parameters: [
|
||||
{
|
||||
name: "prompts",
|
||||
description:
|
||||
"The prompts used to generate or create or draw the image (40 words or less, be creative) up to 4 prompts",
|
||||
type: "array",
|
||||
item_type: "string",
|
||||
required: true,
|
||||
},
|
||||
{
|
||||
name: "seeds",
|
||||
description:
|
||||
"The seed used to generate the image (optional) - can be used to retain image style on amended prompts",
|
||||
type: "array",
|
||||
item_type: "integer",
|
||||
required: true,
|
||||
},
|
||||
],
|
||||
}
|
||||
end
|
||||
|
||||
def self.name
|
||||
"image"
|
||||
end
|
||||
|
||||
def prompts
|
||||
JSON.parse(parameters[:prompts].to_s)
|
||||
end
|
||||
|
||||
def seeds
|
||||
parameters[:seeds]
|
||||
end
|
||||
|
||||
def chain_next_response?
|
||||
false
|
||||
end
|
||||
|
||||
def invoke(bot_user, _llm)
|
||||
# max 4 prompts
|
||||
selected_prompts = prompts.take(4)
|
||||
seeds = seeds.take(4) if seeds
|
||||
|
||||
progress = +""
|
||||
yield(progress)
|
||||
|
||||
results = nil
|
||||
|
||||
# this ensures multisite safety since background threads
|
||||
# generate the images
|
||||
api_key = SiteSetting.ai_stability_api_key
|
||||
engine = SiteSetting.ai_stability_engine
|
||||
api_url = SiteSetting.ai_stability_api_url
|
||||
|
||||
threads = []
|
||||
selected_prompts.each_with_index do |prompt, index|
|
||||
seed = seeds ? seeds[index] : nil
|
||||
threads << Thread.new(seed, prompt) do |inner_seed, inner_prompt|
|
||||
attempts = 0
|
||||
begin
|
||||
DiscourseAi::Inference::StabilityGenerator.perform!(
|
||||
inner_prompt,
|
||||
engine: engine,
|
||||
api_key: api_key,
|
||||
api_url: api_url,
|
||||
image_count: 1,
|
||||
seed: inner_seed,
|
||||
)
|
||||
rescue => e
|
||||
attempts += 1
|
||||
retry if attempts < 3
|
||||
Rails.logger.warn("Failed to generate image for prompt #{prompt}: #{e}")
|
||||
nil
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
while true
|
||||
progress << "."
|
||||
yield(progress)
|
||||
break if threads.all? { |t| t.join(2) }
|
||||
end
|
||||
|
||||
results = threads.map(&:value).compact
|
||||
|
||||
if !results.present?
|
||||
return { prompts: prompts, error: "Something went wrong, could not generate image" }
|
||||
end
|
||||
|
||||
uploads = []
|
||||
|
||||
results.each_with_index do |result, index|
|
||||
result[:artifacts].each do |image|
|
||||
Tempfile.create("v1_txt2img_#{index}.png") do |file|
|
||||
file.binmode
|
||||
file.write(Base64.decode64(image[:base64]))
|
||||
file.rewind
|
||||
uploads << {
|
||||
prompt: prompts[index],
|
||||
upload: UploadCreator.new(file, "image.png").create_for(bot_user.id),
|
||||
seed: image[:seed],
|
||||
}
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@custom_raw = <<~RAW
|
||||
|
||||
[grid]
|
||||
#{
|
||||
uploads
|
||||
.map do |item|
|
||||
"![#{item[:prompt].gsub(/\|\'\"/, "")}|512x512, 50%](#{item[:upload].short_url})"
|
||||
end
|
||||
.join(" ")
|
||||
}
|
||||
[/grid]
|
||||
RAW
|
||||
|
||||
{
|
||||
prompts: uploads.map { |item| item[:prompt] },
|
||||
seeds: uploads.map { |item| item[:seed] },
|
||||
}
|
||||
end
|
||||
|
||||
protected
|
||||
|
||||
def description_args
|
||||
{ prompt: prompts.first }
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,46 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module AiBot
|
||||
module Tools
|
||||
class ListCategories < Tool
|
||||
def self.signature
|
||||
{
|
||||
name: name,
|
||||
description:
|
||||
"Will list the categories on the current discourse instance, prefer to format with # in front of the category name",
|
||||
}
|
||||
end
|
||||
|
||||
def self.name
|
||||
"categories"
|
||||
end
|
||||
|
||||
def invoke(_bot_user, _llm)
|
||||
columns = {
|
||||
name: "Name",
|
||||
slug: "Slug",
|
||||
description: "Description",
|
||||
posts_year: "Posts Year",
|
||||
posts_month: "Posts Month",
|
||||
posts_week: "Posts Week",
|
||||
id: "id",
|
||||
parent_category_id: "parent_category_id",
|
||||
}
|
||||
|
||||
rows = Category.where(read_restricted: false).limit(100).pluck(*columns.keys)
|
||||
|
||||
@last_count = rows.length
|
||||
|
||||
{ rows: rows, column_names: columns.values }
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def description_args
|
||||
{ count: @last_count || 0 }
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,41 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module AiBot
|
||||
module Tools
|
||||
class ListTags < Tool
|
||||
def self.signature
|
||||
{
|
||||
name: name,
|
||||
description: "Will list the 100 most popular tags on the current discourse instance",
|
||||
}
|
||||
end
|
||||
|
||||
def self.name
|
||||
"tags"
|
||||
end
|
||||
|
||||
def invoke(_bot_user, _llm)
|
||||
column_names = { name: "Name", public_topic_count: "Topic Count" }
|
||||
|
||||
tags =
|
||||
Tag
|
||||
.where("public_topic_count > 0")
|
||||
.order(public_topic_count: :desc)
|
||||
.limit(100)
|
||||
.pluck(*column_names.keys)
|
||||
|
||||
@last_count = tags.length
|
||||
|
||||
format_results(tags, column_names.values)
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def description_args
|
||||
{ count: @last_count || 0 }
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,25 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module AiBot
|
||||
module Tools
|
||||
class Option
|
||||
attr_reader :tool, :name, :type
|
||||
|
||||
def initialize(tool:, name:, type:)
|
||||
@tool = tool
|
||||
@name = name.to_s
|
||||
@type = type
|
||||
end
|
||||
|
||||
def localized_name
|
||||
I18n.t("discourse_ai.ai_bot.command_options.#{tool.signature[:name]}.#{name}.name")
|
||||
end
|
||||
|
||||
def localized_description
|
||||
I18n.t("discourse_ai.ai_bot.command_options.#{tool.signature[:name]}.#{name}.description")
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,90 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module AiBot
|
||||
module Tools
|
||||
class Read < Tool
|
||||
def self.signature
|
||||
{
|
||||
name: name,
|
||||
description: "Will read a topic or a post on this Discourse instance",
|
||||
parameters: [
|
||||
{
|
||||
name: "topic_id",
|
||||
description: "the id of the topic to read",
|
||||
type: "integer",
|
||||
required: true,
|
||||
},
|
||||
{
|
||||
name: "post_number",
|
||||
description: "the post number to read",
|
||||
type: "integer",
|
||||
required: true,
|
||||
},
|
||||
],
|
||||
}
|
||||
end
|
||||
|
||||
def self.name
|
||||
"read"
|
||||
end
|
||||
|
||||
attr_reader :title, :url
|
||||
|
||||
def topic_id
|
||||
parameters[:topic_id]
|
||||
end
|
||||
|
||||
def post_number
|
||||
parameters[:post_number]
|
||||
end
|
||||
|
||||
def invoke(_bot_user, llm)
|
||||
not_found = { topic_id: topic_id, description: "Topic not found" }
|
||||
|
||||
@title = ""
|
||||
|
||||
topic = Topic.find_by(id: topic_id.to_i)
|
||||
return not_found if !topic || !Guardian.new.can_see?(topic)
|
||||
|
||||
@title = topic.title
|
||||
|
||||
posts = Post.secured(Guardian.new).where(topic_id: topic_id).order(:post_number).limit(40)
|
||||
@url = topic.relative_url(post_number)
|
||||
|
||||
posts = posts.where("post_number = ?", post_number) if post_number
|
||||
|
||||
content = +<<~TEXT.strip
|
||||
title: #{topic.title}
|
||||
TEXT
|
||||
|
||||
category_names = [
|
||||
topic.category&.parent_category&.name,
|
||||
topic.category&.name,
|
||||
].compact.join(" ")
|
||||
content << "\ncategories: #{category_names}" if category_names.present?
|
||||
|
||||
if topic.tags.length > 0
|
||||
tags = DiscourseTagging.filter_visible(topic.tags, Guardian.new)
|
||||
content << "\ntags: #{tags.map(&:name).join(", ")}\n\n" if tags.length > 0
|
||||
end
|
||||
|
||||
posts.each { |post| content << "\n\n#{post.username} said:\n\n#{post.raw}" }
|
||||
|
||||
# TODO: 16k or 100k models can handle a lot more tokens
|
||||
content = llm.tokenizer.truncate(content, 1500).squish
|
||||
|
||||
result = { topic_id: topic_id, content: content, complete: true }
|
||||
result[:post_number] = post_number if post_number
|
||||
result
|
||||
end
|
||||
|
||||
protected
|
||||
|
||||
def description_args
|
||||
{ title: title, url: url }
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,223 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module AiBot
|
||||
module Tools
|
||||
class Search < Tool
|
||||
MIN_SEMANTIC_RESULTS = 5
|
||||
|
||||
class << self
|
||||
def signature
|
||||
{
|
||||
name: name,
|
||||
description:
|
||||
"Will search topics in the current discourse instance, when rendering always prefer to link to the topics you find",
|
||||
parameters: [
|
||||
{
|
||||
name: "search_query",
|
||||
description:
|
||||
"Specific keywords to search for, space seperated (correct bad spelling, remove connector words)",
|
||||
type: "string",
|
||||
},
|
||||
{
|
||||
name: "user",
|
||||
description:
|
||||
"Filter search results to this username (only include if user explicitly asks to filter by user)",
|
||||
type: "string",
|
||||
},
|
||||
{
|
||||
name: "order",
|
||||
description: "search result order",
|
||||
type: "string",
|
||||
enum: %w[latest latest_topic oldest views likes],
|
||||
},
|
||||
{
|
||||
name: "limit",
|
||||
description:
|
||||
"limit number of results returned (generally prefer to just keep to default)",
|
||||
type: "integer",
|
||||
},
|
||||
{
|
||||
name: "max_posts",
|
||||
description:
|
||||
"maximum number of posts on the topics (topics where lots of people posted)",
|
||||
type: "integer",
|
||||
},
|
||||
{
|
||||
name: "tags",
|
||||
description:
|
||||
"list of tags to search for. Use + to join with OR, use , to join with AND",
|
||||
type: "string",
|
||||
},
|
||||
{ name: "category", description: "category name to filter to", type: "string" },
|
||||
{
|
||||
name: "before",
|
||||
description: "only topics created before a specific date YYYY-MM-DD",
|
||||
type: "string",
|
||||
},
|
||||
{
|
||||
name: "after",
|
||||
description: "only topics created after a specific date YYYY-MM-DD",
|
||||
type: "string",
|
||||
},
|
||||
{
|
||||
name: "status",
|
||||
description: "search for topics in a particular state",
|
||||
type: "string",
|
||||
enum: %w[open closed archived noreplies single_user],
|
||||
},
|
||||
],
|
||||
}
|
||||
end
|
||||
|
||||
def name
|
||||
"search"
|
||||
end
|
||||
|
||||
def custom_system_message
|
||||
<<~TEXT
|
||||
You were trained on OLD data, lean on search to get up to date information about this forum
|
||||
When searching try to SIMPLIFY search terms
|
||||
Discourse search joins all terms with AND. Reduce and simplify terms to find more results.
|
||||
TEXT
|
||||
end
|
||||
|
||||
def accepted_options
|
||||
[option(:base_query, type: :string), option(:max_results, type: :integer)]
|
||||
end
|
||||
end
|
||||
|
||||
def search_args
|
||||
parameters.slice(:user, :order, :max_posts, :tags, :before, :after, :status)
|
||||
end
|
||||
|
||||
def invoke(bot_user, llm)
|
||||
search_string =
|
||||
search_args.reduce(+parameters[:search_query].to_s) do |memo, (key, value)|
|
||||
return memo if value.blank?
|
||||
memo << " " << "#{key}:#{value}"
|
||||
end
|
||||
|
||||
@last_query = search_string
|
||||
|
||||
yield(I18n.t("discourse_ai.ai_bot.searching", query: search_string))
|
||||
|
||||
if options[:base_query].present?
|
||||
search_string = "#{search_string} #{options[:base_query]}"
|
||||
end
|
||||
|
||||
results =
|
||||
::Search.execute(
|
||||
search_string.to_s + " status:public",
|
||||
search_type: :full_page,
|
||||
guardian: Guardian.new(),
|
||||
)
|
||||
|
||||
# let's be frugal with tokens, 50 results is too much and stuff gets cut off
|
||||
max_results = calculate_max_results(llm)
|
||||
results_limit = parameters[:limit] || max_results
|
||||
results_limit = max_results if parameters[:limit].to_i > max_results
|
||||
|
||||
should_try_semantic_search =
|
||||
SiteSetting.ai_embeddings_semantic_search_enabled && results_limit == max_results &&
|
||||
parameters[:search_query].present?
|
||||
|
||||
max_semantic_results = max_results / 4
|
||||
results_limit = results_limit - max_semantic_results if should_try_semantic_search
|
||||
|
||||
posts = results&.posts || []
|
||||
posts = posts[0..results_limit.to_i - 1]
|
||||
|
||||
if should_try_semantic_search
|
||||
semantic_search = DiscourseAi::Embeddings::SemanticSearch.new(Guardian.new())
|
||||
topic_ids = Set.new(posts.map(&:topic_id))
|
||||
|
||||
search = ::Search.new(search_string, guardian: Guardian.new)
|
||||
|
||||
results = nil
|
||||
begin
|
||||
results = semantic_search.search_for_topics(search.term)
|
||||
rescue => e
|
||||
Discourse.warn_exception(e, message: "Semantic search failed")
|
||||
end
|
||||
|
||||
if results
|
||||
results = search.apply_filters(results)
|
||||
|
||||
results.each do |post|
|
||||
next if topic_ids.include?(post.topic_id)
|
||||
|
||||
topic_ids << post.topic_id
|
||||
posts << post
|
||||
|
||||
break if posts.length >= max_results
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@last_num_results = posts.length
|
||||
# this is the general pattern from core
|
||||
# if there are millions of hidden tags it may fail
|
||||
hidden_tags = nil
|
||||
|
||||
if posts.blank?
|
||||
{ args: parameters, rows: [], instruction: "nothing was found, expand your search" }
|
||||
else
|
||||
format_results(posts, args: parameters) do |post|
|
||||
category_names = [
|
||||
post.topic.category&.parent_category&.name,
|
||||
post.topic.category&.name,
|
||||
].compact.join(" > ")
|
||||
row = {
|
||||
title: post.topic.title,
|
||||
url: Discourse.base_path + post.url,
|
||||
username: post.user&.username,
|
||||
excerpt: post.excerpt,
|
||||
created: post.created_at,
|
||||
category: category_names,
|
||||
likes: post.like_count,
|
||||
topic_views: post.topic.views,
|
||||
topic_likes: post.topic.like_count,
|
||||
topic_replies: post.topic.posts_count - 1,
|
||||
}
|
||||
|
||||
if SiteSetting.tagging_enabled
|
||||
hidden_tags ||= DiscourseTagging.hidden_tag_names
|
||||
# using map over pluck to avoid n+1 (assuming caller preloading)
|
||||
tags = post.topic.tags.map(&:name) - hidden_tags
|
||||
row[:tags] = tags.join(", ") if tags.present?
|
||||
end
|
||||
|
||||
row
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
protected
|
||||
|
||||
def description_args
|
||||
{
|
||||
count: @last_num_results || 0,
|
||||
query: @last_query || "",
|
||||
url: "#{Discourse.base_path}/search?q=#{CGI.escape(@last_query || "")}",
|
||||
}
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def calculate_max_results(llm)
|
||||
max_results = options[:max_results].to_i
|
||||
return [max_results, 100].min if max_results > 0
|
||||
|
||||
if llm.max_prompt_tokens > 30_000
|
||||
60
|
||||
elsif llm.max_prompt_tokens > 10_000
|
||||
40
|
||||
else
|
||||
20
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,88 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module AiBot
|
||||
module Tools
|
||||
class SearchSettings < Tool
|
||||
INCLUDE_DESCRIPTIONS_MAX_LENGTH = 10
|
||||
MAX_RESULTS = 200
|
||||
|
||||
def self.signature
|
||||
{
|
||||
name: name,
|
||||
description: "Will search through site settings and return top 20 results",
|
||||
parameters: [
|
||||
{
|
||||
name: "query",
|
||||
description:
|
||||
"comma delimited list of settings to search for (e.g. 'setting_1,setting_2')",
|
||||
type: "string",
|
||||
required: true,
|
||||
},
|
||||
],
|
||||
}
|
||||
end
|
||||
|
||||
def self.name
|
||||
"search_settings"
|
||||
end
|
||||
|
||||
def query
|
||||
parameters[:query].to_s
|
||||
end
|
||||
|
||||
def invoke(_bot_user, _llm)
|
||||
@last_num_results = 0
|
||||
|
||||
terms = query.split(",").map(&:strip).map(&:downcase).reject(&:blank?)
|
||||
|
||||
found =
|
||||
SiteSetting.all_settings.filter do |setting|
|
||||
name = setting[:setting].to_s.downcase
|
||||
description = setting[:description].to_s.downcase
|
||||
plugin = setting[:plugin].to_s.downcase
|
||||
|
||||
search_string = "#{name} #{description} #{plugin}"
|
||||
|
||||
terms.any? { |term| search_string.include?(term) }
|
||||
end
|
||||
|
||||
if found.blank?
|
||||
{
|
||||
args: {
|
||||
query: query,
|
||||
},
|
||||
rows: [],
|
||||
instruction: "no settings matched #{query}, expand your search",
|
||||
}
|
||||
else
|
||||
include_descriptions = false
|
||||
|
||||
if found.length > MAX_RESULTS
|
||||
found = found[0..MAX_RESULTS]
|
||||
elsif found.length < INCLUDE_DESCRIPTIONS_MAX_LENGTH
|
||||
include_descriptions = true
|
||||
end
|
||||
|
||||
@last_num_results = found.length
|
||||
|
||||
format_results(found, args: { query: query }) do |setting|
|
||||
result = { name: setting[:setting] }
|
||||
if include_descriptions
|
||||
result[:description] = setting[:description]
|
||||
result[:plugin] = setting[:plugin]
|
||||
end
|
||||
result
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
protected
|
||||
|
||||
def description_args
|
||||
{ count: @last_num_results || 0, query: parameters[:query].to_s }
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,160 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module AiBot
|
||||
module Tools
|
||||
class SettingContext < Tool
|
||||
MAX_CONTEXT_TOKENS = 2000
|
||||
CODE_FILE_EXTENSIONS = "rb,js,gjs,hbs"
|
||||
|
||||
class << self
|
||||
def rg_installed?
|
||||
if defined?(@rg_installed)
|
||||
@rg_installed
|
||||
else
|
||||
@rg_installed =
|
||||
begin
|
||||
Discourse::Utils.execute_command("which", "rg")
|
||||
true
|
||||
rescue Discourse::Utils::CommandError
|
||||
false
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
def signature
|
||||
{
|
||||
name: name,
|
||||
description:
|
||||
"Will provide you with full context regarding a particular site setting in Discourse",
|
||||
parameters: [
|
||||
{
|
||||
name: "setting_name",
|
||||
description: "The name of the site setting we need context for",
|
||||
type: "string",
|
||||
required: true,
|
||||
},
|
||||
],
|
||||
}
|
||||
end
|
||||
|
||||
def name
|
||||
"setting_context"
|
||||
end
|
||||
end
|
||||
|
||||
def setting_name
|
||||
parameters[:setting_name]
|
||||
end
|
||||
|
||||
def invoke(_bot_user, llm)
|
||||
if !self.class.rg_installed?
|
||||
return(
|
||||
{
|
||||
setting_name: setting_name,
|
||||
context:
|
||||
"This command requires the rg command line tool to be installed on the server",
|
||||
}
|
||||
)
|
||||
end
|
||||
|
||||
if !SiteSetting.has_setting?(setting_name)
|
||||
{ setting_name: setting_name, context: "This setting does not exist" }
|
||||
else
|
||||
description = SiteSetting.description(setting_name)
|
||||
result = +"# #{setting_name}\n#{description}\n\n"
|
||||
|
||||
setting_info =
|
||||
find_setting_info(setting_name, [Rails.root.join("config", "site_settings.yml").to_s])
|
||||
if !setting_info
|
||||
setting_info =
|
||||
find_setting_info(setting_name, Dir[Rails.root.join("plugins/**/settings.yml")])
|
||||
end
|
||||
|
||||
result << setting_info
|
||||
result << "\n\n"
|
||||
|
||||
%w[lib app plugins].each do |dir|
|
||||
path = Rails.root.join(dir).to_s
|
||||
result << Discourse::Utils.execute_command(
|
||||
"rg",
|
||||
setting_name,
|
||||
path,
|
||||
"-g",
|
||||
"!**/spec/**",
|
||||
"-g",
|
||||
"!**/dist/**",
|
||||
"-g",
|
||||
"*.{#{CODE_FILE_EXTENSIONS}}",
|
||||
"-C",
|
||||
"10",
|
||||
"--color",
|
||||
"never",
|
||||
"--heading",
|
||||
"--no-ignore",
|
||||
chdir: path,
|
||||
success_status_codes: [0, 1],
|
||||
)
|
||||
end
|
||||
|
||||
result.gsub!(/^#{Regexp.escape(Rails.root.to_s)}/, "")
|
||||
|
||||
result = llm.tokenizer.truncate(result, MAX_CONTEXT_TOKENS)
|
||||
|
||||
{ setting_name: setting_name, context: result }
|
||||
end
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def find_setting_info(name, paths)
|
||||
path, result = nil
|
||||
|
||||
paths.each do |search_path|
|
||||
result =
|
||||
Discourse::Utils.execute_command(
|
||||
"rg",
|
||||
name,
|
||||
search_path,
|
||||
"-g",
|
||||
"*.{#{CODE_FILE_EXTENSIONS}}",
|
||||
"-A",
|
||||
"10",
|
||||
"--color",
|
||||
"never",
|
||||
"--heading",
|
||||
success_status_codes: [0, 1],
|
||||
)
|
||||
if !result.blank?
|
||||
path = search_path
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
if result.blank?
|
||||
nil
|
||||
else
|
||||
rows = result.split("\n")
|
||||
leading_spaces = rows[0].match(/^\s*/)[0].length
|
||||
|
||||
filtered = []
|
||||
|
||||
rows.each do |row|
|
||||
if !filtered.blank?
|
||||
break if row.match(/^\s*/)[0].length <= leading_spaces
|
||||
end
|
||||
filtered << row
|
||||
end
|
||||
|
||||
filtered.unshift("#{path}")
|
||||
filtered.join("\n")
|
||||
end
|
||||
end
|
||||
|
||||
def description_args
|
||||
parameters.slice(:setting_name)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,183 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module AiBot
|
||||
module Tools
|
||||
class Summarize < Tool
|
||||
def self.signature
|
||||
{
|
||||
name: name,
|
||||
description: "Will summarize a topic attempting to answer question in guidance",
|
||||
parameters: [
|
||||
{
|
||||
name: "topic_id",
|
||||
description: "The discourse topic id to summarize",
|
||||
type: "integer",
|
||||
required: true,
|
||||
},
|
||||
{
|
||||
name: "guidance",
|
||||
description: "Special guidance on how to summarize the topic",
|
||||
type: "string",
|
||||
},
|
||||
],
|
||||
}
|
||||
end
|
||||
|
||||
def self.name
|
||||
"summary"
|
||||
end
|
||||
|
||||
def topic_id
|
||||
parameters[:topic_id].to_i
|
||||
end
|
||||
|
||||
def guidance
|
||||
parameters[:guidance]
|
||||
end
|
||||
|
||||
def chain_next_response?
|
||||
false
|
||||
end
|
||||
|
||||
def standalone?
|
||||
true
|
||||
end
|
||||
|
||||
def low_cost?
|
||||
true
|
||||
end
|
||||
|
||||
def custom_raw
|
||||
@last_summary || I18n.t("discourse_ai.ai_bot.topic_not_found")
|
||||
end
|
||||
|
||||
def invoke(bot_user, llm, &progress_blk)
|
||||
topic = nil
|
||||
if topic_id > 0
|
||||
topic = Topic.find_by(id: topic_id)
|
||||
topic = nil if !topic || !Guardian.new.can_see?(topic)
|
||||
end
|
||||
|
||||
@last_summary = nil
|
||||
|
||||
if topic
|
||||
@last_topic_title = topic.title
|
||||
|
||||
posts =
|
||||
Post
|
||||
.where(topic_id: topic.id)
|
||||
.where("post_type in (?)", [Post.types[:regular], Post.types[:small_action]])
|
||||
.where("not hidden")
|
||||
.order(:post_number)
|
||||
|
||||
columns = ["posts.id", :post_number, :raw, :username]
|
||||
|
||||
current_post_numbers = posts.limit(5).pluck(:post_number)
|
||||
current_post_numbers += posts.reorder("posts.score desc").limit(50).pluck(:post_number)
|
||||
current_post_numbers += posts.reorder("post_number desc").limit(5).pluck(:post_number)
|
||||
|
||||
data =
|
||||
Post
|
||||
.where(topic_id: topic.id)
|
||||
.joins(:user)
|
||||
.where("post_number in (?)", current_post_numbers)
|
||||
.order(:post_number)
|
||||
.pluck(*columns)
|
||||
|
||||
@last_summary = summarize(data, topic, guidance, bot_user, llm, &progress_blk)
|
||||
end
|
||||
|
||||
if !@last_summary
|
||||
"Say: No topic found!"
|
||||
else
|
||||
"Topic summarized"
|
||||
end
|
||||
end
|
||||
|
||||
protected
|
||||
|
||||
def description_args
|
||||
{ url: "#{Discourse.base_path}/t/-/#{@last_topic_id}", title: @last_topic_title || "" }
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def summarize(data, topic, guidance, bot_user, llm, &progress_blk)
|
||||
text = +""
|
||||
data.each do |id, post_number, raw, username|
|
||||
text << "(#{post_number} #{username} said: #{raw}"
|
||||
end
|
||||
|
||||
summaries = []
|
||||
current_section = +""
|
||||
split = []
|
||||
|
||||
text
|
||||
.split(/\s+/)
|
||||
.each_slice(20) do |slice|
|
||||
current_section << " "
|
||||
current_section << slice.join(" ")
|
||||
|
||||
# somehow any more will get closer to limits
|
||||
if llm.tokenizer.tokenize(current_section).length > 2500
|
||||
split << current_section
|
||||
current_section = +""
|
||||
end
|
||||
end
|
||||
|
||||
split << current_section if current_section.present?
|
||||
|
||||
split = split[0..3] + split[-3..-1] if split.length > 5
|
||||
|
||||
progress = +I18n.t("discourse_ai.ai_bot.summarizing")
|
||||
progress_blk.call(progress)
|
||||
|
||||
split.each do |section|
|
||||
progress << "."
|
||||
progress_blk.call(progress)
|
||||
|
||||
prompt = section_prompt(topic, section, guidance)
|
||||
|
||||
summary = llm.generate(prompt, temperature: 0.6, max_tokens: 400, user: bot_user)
|
||||
|
||||
summaries << summary
|
||||
end
|
||||
|
||||
if summaries.length > 1
|
||||
progress << "."
|
||||
progress_blk.call(progress)
|
||||
|
||||
contatenation_prompt = {
|
||||
insts: "You are a helpful bot",
|
||||
input:
|
||||
"concatenated the disjoint summaries, creating a cohesive narrative:\n#{summaries.join("\n")}}",
|
||||
}
|
||||
|
||||
llm.generate(contatenation_prompt, temperature: 0.6, max_tokens: 500, user: bot_user)
|
||||
else
|
||||
summaries.first
|
||||
end
|
||||
end
|
||||
|
||||
def section_prompt(topic, text, guidance)
|
||||
insts = <<~TEXT
|
||||
You are a summarization bot.
|
||||
You effectively summarise any text.
|
||||
You condense it into a shorter version.
|
||||
You understand and generate Discourse forum markdown.
|
||||
Try generating links as well the format is #{topic.url}/POST_NUMBER. eg: [ref](#{topic.url}/77)
|
||||
TEXT
|
||||
|
||||
{ insts: insts, input: <<~TEXT }
|
||||
Guidance: #{guidance}
|
||||
You are summarizing the topic: #{topic.title}
|
||||
Summarize the following in 400 words:
|
||||
|
||||
#{text}
|
||||
TEXT
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,52 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module AiBot
|
||||
module Tools
|
||||
class Time < Tool
|
||||
def self.signature
|
||||
{
|
||||
name: name,
|
||||
description: "Will generate the time in a timezone",
|
||||
parameters: [
|
||||
{
|
||||
name: "timezone",
|
||||
description: "ALWAYS supply a Ruby compatible timezone",
|
||||
type: "string",
|
||||
required: true,
|
||||
},
|
||||
],
|
||||
}
|
||||
end
|
||||
|
||||
def self.name
|
||||
"time"
|
||||
end
|
||||
|
||||
def timezone
|
||||
parameters[:timezone].to_s
|
||||
end
|
||||
|
||||
def invoke(_bot_user, _llm)
|
||||
time =
|
||||
begin
|
||||
::Time.now.in_time_zone(timezone)
|
||||
rescue StandardError
|
||||
nil
|
||||
end
|
||||
time = ::Time.now if !time
|
||||
|
||||
@last_time = time.to_s
|
||||
|
||||
{ args: { timezone: timezone }, time: time.to_s }
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def description_args
|
||||
{ timezone: timezone, time: @last_time }
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,124 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module AiBot
|
||||
module Tools
|
||||
class Tool
|
||||
class << self
|
||||
def signature
|
||||
raise NotImplemented
|
||||
end
|
||||
|
||||
def name
|
||||
raise NotImplemented
|
||||
end
|
||||
|
||||
def accepted_options
|
||||
[]
|
||||
end
|
||||
|
||||
def option(name, type:)
|
||||
Option.new(tool: self, name: name, type: type)
|
||||
end
|
||||
|
||||
def help
|
||||
I18n.t("discourse_ai.ai_bot.command_help.#{signature[:name]}")
|
||||
end
|
||||
|
||||
def custom_system_message
|
||||
nil
|
||||
end
|
||||
end
|
||||
|
||||
attr_accessor :custom_raw
|
||||
|
||||
def initialize(parameters, tool_call_id: "", persona_options: {})
|
||||
@parameters = parameters
|
||||
@tool_call_id = tool_call_id
|
||||
@persona_options = persona_options
|
||||
end
|
||||
|
||||
attr_reader :parameters, :tool_call_id
|
||||
|
||||
def name
|
||||
self.class.name
|
||||
end
|
||||
|
||||
def summary
|
||||
I18n.t("discourse_ai.ai_bot.command_summary.#{name}")
|
||||
end
|
||||
|
||||
def details
|
||||
I18n.t("discourse_ai.ai_bot.command_description.#{name}", description_args)
|
||||
end
|
||||
|
||||
def help
|
||||
I18n.t("discourse_ai.ai_bot.command_help.#{name}")
|
||||
end
|
||||
|
||||
def options
|
||||
self
|
||||
.class
|
||||
.accepted_options
|
||||
.reduce(HashWithIndifferentAccess.new) do |memo, option|
|
||||
val = @persona_options[option.name]
|
||||
memo[option.name] = val if val
|
||||
memo
|
||||
end
|
||||
end
|
||||
|
||||
def chain_next_response?
|
||||
true
|
||||
end
|
||||
|
||||
def standalone?
|
||||
false
|
||||
end
|
||||
|
||||
def low_cost?
|
||||
false
|
||||
end
|
||||
|
||||
protected
|
||||
|
||||
def accepted_options
|
||||
[]
|
||||
end
|
||||
|
||||
def option(name, type:)
|
||||
Option.new(tool: self, name: name, type: type)
|
||||
end
|
||||
|
||||
def description_args
|
||||
{}
|
||||
end
|
||||
|
||||
def format_results(rows, column_names = nil, args: nil)
|
||||
rows = rows&.map { |row| yield row } if block_given?
|
||||
|
||||
if !column_names
|
||||
index = -1
|
||||
column_indexes = {}
|
||||
|
||||
rows =
|
||||
rows&.map do |data|
|
||||
new_row = []
|
||||
data.each do |key, value|
|
||||
found_index = column_indexes[key.to_s] ||= (index += 1)
|
||||
new_row[found_index] = value
|
||||
end
|
||||
new_row
|
||||
end
|
||||
column_names = column_indexes.keys
|
||||
end
|
||||
|
||||
# this is not the most efficient format
|
||||
# however this is needed cause GPT 3.5 / 4 was steered using JSON
|
||||
result = { column_names: column_names, rows: rows }
|
||||
result[:args] = args if args
|
||||
result
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -46,10 +46,9 @@ module DiscourseAi
|
|||
prompt[:tools].map do |t|
|
||||
tool = t.dup
|
||||
|
||||
if tool[:parameters]
|
||||
tool[:parameters] = t[:parameters].reduce(
|
||||
{ type: "object", properties: {}, required: [] },
|
||||
) do |memo, p|
|
||||
tool[:parameters] = t[:parameters]
|
||||
.to_a
|
||||
.reduce({ type: "object", properties: {}, required: [] }) do |memo, p|
|
||||
name = p[:name]
|
||||
memo[:required] << name if p[:required]
|
||||
|
||||
|
@ -58,7 +57,6 @@ module DiscourseAi
|
|||
memo[:properties][name][:items] = { type: p[:item_type] } if p[:item_type]
|
||||
memo
|
||||
end
|
||||
end
|
||||
|
||||
{ type: "function", function: tool }
|
||||
end
|
||||
|
@ -71,9 +69,12 @@ module DiscourseAi
|
|||
|
||||
trimmed_context.reverse.map do |context|
|
||||
if context[:type] == "tool_call"
|
||||
function = JSON.parse(context[:content], symbolize_names: true)
|
||||
function[:arguments] = function[:arguments].to_json
|
||||
|
||||
{
|
||||
role: "assistant",
|
||||
tool_calls: [{ type: "function", function: context[:content], id: context[:name] }],
|
||||
tool_calls: [{ type: "function", function: function, id: context[:name] }],
|
||||
}
|
||||
else
|
||||
translated = context.slice(:content)
|
||||
|
|
|
@ -39,12 +39,12 @@ module DiscourseAi
|
|||
def conversation_context
|
||||
return "" if prompt[:conversation_context].blank?
|
||||
|
||||
trimmed_context = trim_context(prompt[:conversation_context])
|
||||
clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" }
|
||||
trimmed_context = trim_context(clean_context)
|
||||
|
||||
trimmed_context
|
||||
.reverse
|
||||
.reduce(+"") do |memo, context|
|
||||
next(memo) if context[:type] == "tool_call"
|
||||
memo << (context[:type] == "user" ? "Human:" : "Assistant:")
|
||||
|
||||
if context[:type] == "tool"
|
||||
|
|
|
@ -97,6 +97,13 @@ module DiscourseAi
|
|||
|
||||
message_tokens = calculate_message_token(dupped_context)
|
||||
|
||||
# Don't trim tool call metadata.
|
||||
if context[:type] == "tool_call"
|
||||
current_token_count += calculate_message_token(context) + per_message_overhead
|
||||
memo << context
|
||||
next(memo)
|
||||
end
|
||||
|
||||
# Trimming content to make sure we respect token limit.
|
||||
while dupped_context[:content].present? &&
|
||||
message_tokens + current_token_count + per_message_overhead > prompt_limit
|
||||
|
|
|
@ -39,12 +39,13 @@ module DiscourseAi
|
|||
def conversation_context
|
||||
return "" if prompt[:conversation_context].blank?
|
||||
|
||||
trimmed_context = trim_context(prompt[:conversation_context])
|
||||
clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" }
|
||||
|
||||
trimmed_context = trim_context(clean_context)
|
||||
|
||||
trimmed_context
|
||||
.reverse
|
||||
.reduce(+"") do |memo, context|
|
||||
next(memo) if context[:type] == "tool_call"
|
||||
if context[:type] == "tool"
|
||||
memo << <<~TEXT
|
||||
[INST]
|
||||
|
|
|
@ -39,12 +39,12 @@ module DiscourseAi
|
|||
def conversation_context
|
||||
return "" if prompt[:conversation_context].blank?
|
||||
|
||||
trimmed_context = trim_context(prompt[:conversation_context])
|
||||
clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" }
|
||||
trimmed_context = trim_context(clean_context)
|
||||
|
||||
trimmed_context
|
||||
.reverse
|
||||
.reduce(+"") do |memo, context|
|
||||
next(memo) if context[:type] == "tool_call"
|
||||
memo << "[INST] " if context[:type] == "user"
|
||||
|
||||
if context[:type] == "tool"
|
||||
|
|
|
@ -36,12 +36,12 @@ module DiscourseAi
|
|||
def conversation_context
|
||||
return "" if prompt[:conversation_context].blank?
|
||||
|
||||
trimmed_context = trim_context(prompt[:conversation_context])
|
||||
clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" }
|
||||
trimmed_context = trim_context(clean_context)
|
||||
|
||||
trimmed_context
|
||||
.reverse
|
||||
.reduce(+"") do |memo, context|
|
||||
next(memo) if context[:type] == "tool_call"
|
||||
memo << (context[:type] == "user" ? "### User:" : "### Assistant:")
|
||||
|
||||
if context[:type] == "tool"
|
||||
|
|
|
@ -23,7 +23,7 @@ module DiscourseAi
|
|||
def default_options
|
||||
{
|
||||
model: model,
|
||||
max_tokens_to_sample: 2_000,
|
||||
max_tokens_to_sample: 3_000,
|
||||
stop_sequences: ["\n\nHuman:", "</function_calls>"],
|
||||
}
|
||||
end
|
||||
|
|
|
@ -26,11 +26,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def default_options
|
||||
{
|
||||
model: model,
|
||||
max_tokens_to_sample: 2_000,
|
||||
stop_sequences: ["\n\nHuman:", "</function_calls>"],
|
||||
}
|
||||
{ max_tokens_to_sample: 3_000, stop_sequences: ["\n\nHuman:", "</function_calls>"] }
|
||||
end
|
||||
|
||||
def provider_id
|
||||
|
|
|
@ -87,6 +87,8 @@ module DiscourseAi
|
|||
return response_data
|
||||
end
|
||||
|
||||
has_tool = false
|
||||
|
||||
begin
|
||||
cancelled = false
|
||||
cancel = lambda { cancelled = true }
|
||||
|
@ -129,17 +131,19 @@ module DiscourseAi
|
|||
partial = extract_completion_from(raw_partial)
|
||||
next if response_data.empty? && partial.blank?
|
||||
next if partial.nil?
|
||||
partials_raw << partial.to_s
|
||||
|
||||
# Skip yield for tools. We'll buffer and yield later.
|
||||
if has_tool?(partials_raw)
|
||||
# Stop streaming the response as soon as you find a tool.
|
||||
# We'll buffer and yield it later.
|
||||
has_tool = true if has_tool?(partials_raw)
|
||||
|
||||
if has_tool
|
||||
function_buffer = add_to_buffer(function_buffer, partials_raw, partial)
|
||||
else
|
||||
response_data << partial
|
||||
|
||||
yield partial, cancel if partial
|
||||
end
|
||||
|
||||
partials_raw << partial.to_s
|
||||
rescue JSON::ParserError
|
||||
leftover = redo_chunk
|
||||
json_error = true
|
||||
|
@ -158,7 +162,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
# Once we have the full response, try to return the tool as a XML doc.
|
||||
if has_tool?(partials_raw)
|
||||
if has_tool
|
||||
if function_buffer.at("tool_name").text.present?
|
||||
invocation = +function_buffer.at("function_calls").to_s
|
||||
invocation << "\n"
|
||||
|
@ -264,7 +268,7 @@ module DiscourseAi
|
|||
|
||||
read_function = Nokogiri::HTML5.fragment(raw_data)
|
||||
|
||||
if tool_name = read_function.at("tool_name").text
|
||||
if tool_name = read_function.at("tool_name")&.text
|
||||
function_buffer.at("tool_name").inner_html = tool_name
|
||||
function_buffer.at("tool_id").inner_html = tool_name
|
||||
end
|
||||
|
@ -272,7 +276,8 @@ module DiscourseAi
|
|||
_read_parameters =
|
||||
read_function
|
||||
.at("parameters")
|
||||
.elements
|
||||
&.elements
|
||||
.to_a
|
||||
.each do |elem|
|
||||
if paramenter = function_buffer.at(elem.name)&.text
|
||||
function_buffer.at(elem.name).inner_html = paramenter
|
||||
|
|
|
@ -20,7 +20,10 @@ module DiscourseAi
|
|||
def self.with_prepared_responses(responses)
|
||||
@canned_response = DiscourseAi::Completions::Endpoints::CannedResponse.new(responses)
|
||||
|
||||
yield(@canned_response).tap { @canned_response = nil }
|
||||
yield(@canned_response)
|
||||
ensure
|
||||
# Don't leak prepared response if there's an exception.
|
||||
@canned_response = nil
|
||||
end
|
||||
|
||||
def self.proxy(model_name)
|
||||
|
@ -119,9 +122,15 @@ module DiscourseAi
|
|||
gateway.perform_completion!(dialect, user, model_params, &partial_read_blk)
|
||||
end
|
||||
|
||||
def max_prompt_tokens
|
||||
dialect_klass.new({}, model_name).max_prompt_tokens
|
||||
end
|
||||
|
||||
attr_reader :model_name
|
||||
|
||||
private
|
||||
|
||||
attr_reader :dialect_klass, :gateway, :model_name
|
||||
attr_reader :dialect_klass, :gateway
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -1,168 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require "base64"
|
||||
require "json"
|
||||
require "aws-eventstream"
|
||||
require "aws-sigv4"
|
||||
|
||||
module ::DiscourseAi
|
||||
module Inference
|
||||
class AmazonBedrockInference
|
||||
CompletionFailed = Class.new(StandardError)
|
||||
TIMEOUT = 60
|
||||
|
||||
def self.perform!(
|
||||
prompt,
|
||||
model = "anthropic.claude-v2",
|
||||
temperature: nil,
|
||||
top_p: nil,
|
||||
top_k: nil,
|
||||
max_tokens: 20_000,
|
||||
user_id: nil,
|
||||
stop_sequences: nil,
|
||||
tokenizer: Tokenizer::AnthropicTokenizer
|
||||
)
|
||||
raise CompletionFailed if model.blank?
|
||||
raise CompletionFailed if !SiteSetting.ai_bedrock_access_key_id.present?
|
||||
raise CompletionFailed if !SiteSetting.ai_bedrock_secret_access_key.present?
|
||||
raise CompletionFailed if !SiteSetting.ai_bedrock_region.present?
|
||||
|
||||
signer =
|
||||
Aws::Sigv4::Signer.new(
|
||||
access_key_id: SiteSetting.ai_bedrock_access_key_id,
|
||||
region: SiteSetting.ai_bedrock_region,
|
||||
secret_access_key: SiteSetting.ai_bedrock_secret_access_key,
|
||||
service: "bedrock",
|
||||
)
|
||||
|
||||
log = nil
|
||||
response_data = +""
|
||||
response_raw = +""
|
||||
|
||||
url_api =
|
||||
"https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/#{model}/"
|
||||
if block_given?
|
||||
url_api = url_api + "invoke-with-response-stream"
|
||||
else
|
||||
url_api = url_api + "invoke"
|
||||
end
|
||||
|
||||
url = URI(url_api)
|
||||
headers = { "content-type" => "application/json", "Accept" => "*/*" }
|
||||
|
||||
payload = { prompt: prompt }
|
||||
|
||||
payload[:top_p] = top_p if top_p
|
||||
payload[:top_k] = top_k if top_k
|
||||
payload[:max_tokens_to_sample] = max_tokens || 2000
|
||||
payload[:temperature] = temperature if temperature
|
||||
payload[:stop_sequences] = stop_sequences if stop_sequences
|
||||
|
||||
Net::HTTP.start(
|
||||
url.host,
|
||||
url.port,
|
||||
use_ssl: true,
|
||||
read_timeout: TIMEOUT,
|
||||
open_timeout: TIMEOUT,
|
||||
write_timeout: TIMEOUT,
|
||||
) do |http|
|
||||
request = Net::HTTP::Post.new(url)
|
||||
request_body = payload.to_json
|
||||
request.body = request_body
|
||||
|
||||
signed_request =
|
||||
signer.sign_request(
|
||||
req: request,
|
||||
http_method: request.method,
|
||||
url: url,
|
||||
body: request.body,
|
||||
)
|
||||
|
||||
request.initialize_http_header(headers.merge!(signed_request.headers))
|
||||
|
||||
http.request(request) do |response|
|
||||
if response.code.to_i != 200
|
||||
Rails.logger.error(
|
||||
"BedRockInference: status: #{response.code.to_i} - body: #{response.body}",
|
||||
)
|
||||
raise CompletionFailed
|
||||
end
|
||||
|
||||
log =
|
||||
AiApiAuditLog.create!(
|
||||
provider_id: AiApiAuditLog::Provider::Anthropic,
|
||||
raw_request_payload: request_body,
|
||||
user_id: user_id,
|
||||
)
|
||||
|
||||
if !block_given?
|
||||
response_body = response.read_body
|
||||
parsed_response = JSON.parse(response_body, symbolize_names: true)
|
||||
|
||||
log.update!(
|
||||
raw_response_payload: response_body,
|
||||
request_tokens: tokenizer.size(prompt),
|
||||
response_tokens: tokenizer.size(parsed_response[:completion]),
|
||||
)
|
||||
return parsed_response
|
||||
end
|
||||
|
||||
begin
|
||||
cancelled = false
|
||||
cancel = lambda { cancelled = true }
|
||||
decoder = Aws::EventStream::Decoder.new
|
||||
|
||||
response.read_body do |chunk|
|
||||
if cancelled
|
||||
http.finish
|
||||
return
|
||||
end
|
||||
|
||||
response_raw << chunk
|
||||
|
||||
begin
|
||||
message = decoder.decode_chunk(chunk)
|
||||
|
||||
partial =
|
||||
message
|
||||
.first
|
||||
.payload
|
||||
.string
|
||||
.then { JSON.parse(_1) }
|
||||
.dig("bytes")
|
||||
.then { Base64.decode64(_1) }
|
||||
.then { JSON.parse(_1, symbolize_names: true) }
|
||||
|
||||
next if !partial
|
||||
|
||||
response_data << partial[:completion].to_s
|
||||
|
||||
yield partial, cancel if partial[:completion]
|
||||
rescue JSON::ParserError,
|
||||
Aws::EventStream::Errors::MessageChecksumError,
|
||||
Aws::EventStream::Errors::PreludeChecksumError => e
|
||||
Rails.logger.error("BedrockInference: #{e}")
|
||||
end
|
||||
rescue IOError
|
||||
raise if !cancelled
|
||||
end
|
||||
end
|
||||
|
||||
return response_data
|
||||
end
|
||||
ensure
|
||||
if block_given?
|
||||
log.update!(
|
||||
raw_response_payload: response_data,
|
||||
request_tokens: tokenizer.size(prompt),
|
||||
response_tokens: tokenizer.size(response_data),
|
||||
)
|
||||
end
|
||||
if Rails.env.development? && log
|
||||
puts "BedrockInference: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}"
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,158 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module ::DiscourseAi
|
||||
module Inference
|
||||
class AnthropicCompletions
|
||||
CompletionFailed = Class.new(StandardError)
|
||||
TIMEOUT = 60
|
||||
|
||||
def self.perform!(
|
||||
prompt,
|
||||
model = "claude-2",
|
||||
temperature: nil,
|
||||
top_p: nil,
|
||||
max_tokens: nil,
|
||||
user_id: nil,
|
||||
stop_sequences: nil,
|
||||
post: nil,
|
||||
&blk
|
||||
)
|
||||
# HACK to get around the fact that they have different APIs
|
||||
# we will introduce a proper LLM abstraction layer to handle this shenanigas later this year
|
||||
if model == "claude-2" && SiteSetting.ai_bedrock_access_key_id.present? &&
|
||||
SiteSetting.ai_bedrock_secret_access_key.present? &&
|
||||
SiteSetting.ai_bedrock_region.present?
|
||||
return(
|
||||
AmazonBedrockInference.perform!(
|
||||
prompt,
|
||||
temperature: temperature,
|
||||
top_p: top_p,
|
||||
max_tokens: max_tokens,
|
||||
user_id: user_id,
|
||||
stop_sequences: stop_sequences,
|
||||
&blk
|
||||
)
|
||||
)
|
||||
end
|
||||
|
||||
log = nil
|
||||
response_data = +""
|
||||
response_raw = +""
|
||||
|
||||
url = URI("https://api.anthropic.com/v1/complete")
|
||||
headers = {
|
||||
"anthropic-version" => "2023-06-01",
|
||||
"x-api-key" => SiteSetting.ai_anthropic_api_key,
|
||||
"content-type" => "application/json",
|
||||
}
|
||||
|
||||
payload = { model: model, prompt: prompt }
|
||||
|
||||
payload[:top_p] = top_p if top_p
|
||||
payload[:max_tokens_to_sample] = max_tokens || 2000
|
||||
payload[:temperature] = temperature if temperature
|
||||
payload[:stream] = true if block_given?
|
||||
payload[:stop_sequences] = stop_sequences if stop_sequences
|
||||
|
||||
Net::HTTP.start(
|
||||
url.host,
|
||||
url.port,
|
||||
use_ssl: true,
|
||||
read_timeout: TIMEOUT,
|
||||
open_timeout: TIMEOUT,
|
||||
write_timeout: TIMEOUT,
|
||||
) do |http|
|
||||
request = Net::HTTP::Post.new(url, headers)
|
||||
request_body = payload.to_json
|
||||
request.body = request_body
|
||||
|
||||
http.request(request) do |response|
|
||||
if response.code.to_i != 200
|
||||
Rails.logger.error(
|
||||
"AnthropicCompletions: status: #{response.code.to_i} - body: #{response.body}",
|
||||
)
|
||||
raise CompletionFailed
|
||||
end
|
||||
|
||||
log =
|
||||
AiApiAuditLog.create!(
|
||||
provider_id: AiApiAuditLog::Provider::Anthropic,
|
||||
raw_request_payload: request_body,
|
||||
user_id: user_id,
|
||||
post_id: post&.id,
|
||||
topic_id: post&.topic_id,
|
||||
)
|
||||
|
||||
if !block_given?
|
||||
response_body = response.read_body
|
||||
parsed_response = JSON.parse(response_body, symbolize_names: true)
|
||||
|
||||
log.update!(
|
||||
raw_response_payload: response_body,
|
||||
request_tokens: DiscourseAi::Tokenizer::AnthropicTokenizer.size(prompt),
|
||||
response_tokens:
|
||||
DiscourseAi::Tokenizer::AnthropicTokenizer.size(parsed_response[:completion]),
|
||||
)
|
||||
return parsed_response
|
||||
end
|
||||
|
||||
begin
|
||||
cancelled = false
|
||||
cancel = lambda { cancelled = true }
|
||||
|
||||
response.read_body do |chunk|
|
||||
if cancelled
|
||||
http.finish
|
||||
return
|
||||
end
|
||||
|
||||
response_raw << chunk
|
||||
|
||||
chunk
|
||||
.split("\n")
|
||||
.each do |line|
|
||||
data = line.split("data: ", 2)[1]
|
||||
next if !data
|
||||
|
||||
if !cancelled
|
||||
begin
|
||||
partial = JSON.parse(data, symbolize_names: true)
|
||||
response_data << partial[:completion].to_s
|
||||
|
||||
# ping has no data... do not yeild it
|
||||
yield partial, cancel if partial[:completion]
|
||||
rescue JSON::ParserError
|
||||
nil
|
||||
# TODO leftover chunk carry over to next
|
||||
end
|
||||
end
|
||||
end
|
||||
rescue IOError
|
||||
raise if !cancelled
|
||||
end
|
||||
end
|
||||
|
||||
return response_data
|
||||
end
|
||||
ensure
|
||||
if block_given?
|
||||
log.update!(
|
||||
raw_response_payload: response_raw,
|
||||
request_tokens: DiscourseAi::Tokenizer::AnthropicTokenizer.size(prompt),
|
||||
response_tokens: DiscourseAi::Tokenizer::AnthropicTokenizer.size(response_data),
|
||||
)
|
||||
end
|
||||
if Rails.env.development? && log
|
||||
puts "AnthropicCompletions: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}"
|
||||
end
|
||||
end
|
||||
|
||||
def self.try_parse(data)
|
||||
JSON.parse(data, symbolize_names: true)
|
||||
rescue JSON::ParserError
|
||||
nil
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,69 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module ::DiscourseAi
|
||||
module Inference
|
||||
class Function
|
||||
attr_reader :name, :description, :parameters, :type
|
||||
|
||||
def initialize(name:, description:, type: nil)
|
||||
@name = name
|
||||
@description = description
|
||||
@type = type || "object"
|
||||
@parameters = []
|
||||
end
|
||||
|
||||
def add_parameter(parameter = nil, **kwargs)
|
||||
if parameter
|
||||
add_parameter_kwargs(
|
||||
name: parameter.name,
|
||||
type: parameter.type,
|
||||
description: parameter.description,
|
||||
required: parameter.required,
|
||||
enum: parameter.enum,
|
||||
item_type: parameter.item_type,
|
||||
)
|
||||
else
|
||||
add_parameter_kwargs(**kwargs)
|
||||
end
|
||||
end
|
||||
|
||||
def add_parameter_kwargs(
|
||||
name:,
|
||||
type:,
|
||||
description:,
|
||||
enum: nil,
|
||||
required: false,
|
||||
item_type: nil
|
||||
)
|
||||
param = { name: name, type: type, description: description, enum: enum, required: required }
|
||||
param[:enum] = enum if enum
|
||||
param[:item_type] = item_type if item_type
|
||||
|
||||
@parameters << param
|
||||
end
|
||||
|
||||
def to_json(*args)
|
||||
as_json.to_json(*args)
|
||||
end
|
||||
|
||||
def as_json
|
||||
required_params = []
|
||||
|
||||
properties = {}
|
||||
parameters.each do |parameter|
|
||||
definition = { type: parameter[:type], description: parameter[:description] }
|
||||
definition[:enum] = parameter[:enum] if parameter[:enum]
|
||||
definition[:items] = { type: parameter[:item_type] } if parameter[:item_type]
|
||||
required_params << parameter[:name] if parameter[:required]
|
||||
properties[parameter[:name]] = definition
|
||||
end
|
||||
|
||||
params = { type: @type, properties: properties }
|
||||
|
||||
params[:required] = required_params if required_params.present?
|
||||
|
||||
{ name: name, description: description, parameters: params }
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,119 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module ::DiscourseAi
|
||||
module Inference
|
||||
class FunctionList
|
||||
def initialize
|
||||
@functions = []
|
||||
end
|
||||
|
||||
def <<(function)
|
||||
@functions << function
|
||||
end
|
||||
|
||||
def parse_prompt(prompt)
|
||||
xml = prompt.sub(%r{<function_calls>(.*)</function_calls>}m, '\1')
|
||||
if xml.present?
|
||||
parsed = []
|
||||
Nokogiri
|
||||
.XML(xml)
|
||||
.xpath("//invoke")
|
||||
.each do |invoke_node|
|
||||
function = { name: invoke_node.xpath("//tool_name").text, arguments: {} }
|
||||
parsed << function
|
||||
invoke_node
|
||||
.xpath("//parameters")
|
||||
.children
|
||||
.each do |parameters_node|
|
||||
if parameters_node.is_a?(Nokogiri::XML::Element) && name = parameters_node.name
|
||||
function[:arguments][name.to_sym] = parameters_node.text
|
||||
end
|
||||
end
|
||||
end
|
||||
coerce_arguments!(parsed)
|
||||
end
|
||||
end
|
||||
|
||||
def coerce_arguments!(parsed)
|
||||
parsed.each do |function_call|
|
||||
arguments = function_call[:arguments]
|
||||
|
||||
function = @functions.find { |f| f.name == function_call[:name] }
|
||||
next if !function
|
||||
|
||||
arguments.each do |name, value|
|
||||
parameter = function.parameters.find { |p| p[:name].to_s == name.to_s }
|
||||
if !parameter
|
||||
arguments.delete(name)
|
||||
next
|
||||
end
|
||||
|
||||
type = parameter[:type]
|
||||
if type == "array"
|
||||
begin
|
||||
arguments[name] = JSON.parse(value)
|
||||
rescue JSON::ParserError
|
||||
# maybe LLM chose a different shape for the array
|
||||
arguments[name] = value.to_s.split("\n").map(&:strip).reject(&:blank?)
|
||||
end
|
||||
elsif type == "integer"
|
||||
arguments[name] = value.to_i
|
||||
elsif type == "float"
|
||||
arguments[name] = value.to_f
|
||||
end
|
||||
end
|
||||
end
|
||||
parsed
|
||||
end
|
||||
|
||||
def system_prompt
|
||||
tools = +""
|
||||
|
||||
@functions.each do |function|
|
||||
parameters = +""
|
||||
if function.parameters.present?
|
||||
parameters << "\n"
|
||||
function.parameters.each do |parameter|
|
||||
parameters << <<~PARAMETER
|
||||
<parameter>
|
||||
<name>#{parameter[:name]}</name>
|
||||
<type>#{parameter[:type]}</type>
|
||||
<description>#{parameter[:description]}</description>
|
||||
<required>#{parameter[:required]}</required>
|
||||
PARAMETER
|
||||
parameters << "<options>#{parameter[:enum].join(",")}</options>\n" if parameter[:enum]
|
||||
parameters << "</parameter>\n"
|
||||
end
|
||||
end
|
||||
|
||||
tools << <<~TOOLS
|
||||
<tool_description>
|
||||
<tool_name>#{function.name}</tool_name>
|
||||
<description>#{function.description}</description>
|
||||
<parameters>#{parameters}</parameters>
|
||||
</tool_description>
|
||||
TOOLS
|
||||
end
|
||||
|
||||
<<~PROMPT
|
||||
In this environment you have access to a set of tools you can use to answer the user's question.
|
||||
You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>$TOOL_NAME</tool_name>
|
||||
<parameters>
|
||||
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
|
||||
...
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
|
||||
Here are the tools available:
|
||||
|
||||
<tools>
|
||||
#{tools}</tools>
|
||||
PROMPT
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,146 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module ::DiscourseAi
|
||||
module Inference
|
||||
class HuggingFaceTextGeneration
|
||||
CompletionFailed = Class.new(StandardError)
|
||||
TIMEOUT = 120
|
||||
|
||||
def self.perform!(
|
||||
prompt,
|
||||
model,
|
||||
temperature: 0.7,
|
||||
top_p: nil,
|
||||
top_k: nil,
|
||||
typical_p: nil,
|
||||
max_tokens: 2000,
|
||||
repetition_penalty: 1.1,
|
||||
user_id: nil,
|
||||
tokenizer: DiscourseAi::Tokenizer::Llama2Tokenizer,
|
||||
token_limit: nil
|
||||
)
|
||||
raise CompletionFailed if model.blank?
|
||||
|
||||
url = URI(SiteSetting.ai_hugging_face_api_url)
|
||||
headers = { "Content-Type" => "application/json" }
|
||||
|
||||
if SiteSetting.ai_hugging_face_api_key.present?
|
||||
headers["Authorization"] = "Bearer #{SiteSetting.ai_hugging_face_api_key}"
|
||||
end
|
||||
|
||||
token_limit = token_limit || SiteSetting.ai_hugging_face_token_limit
|
||||
|
||||
parameters = {}
|
||||
payload = { inputs: prompt, parameters: parameters }
|
||||
prompt_size = tokenizer.size(prompt)
|
||||
|
||||
parameters[:top_p] = top_p if top_p
|
||||
parameters[:top_k] = top_k if top_k
|
||||
parameters[:typical_p] = typical_p if typical_p
|
||||
parameters[:max_new_tokens] = token_limit - prompt_size
|
||||
parameters[:temperature] = temperature if temperature
|
||||
parameters[:repetition_penalty] = repetition_penalty if repetition_penalty
|
||||
parameters[:return_full_text] = false
|
||||
|
||||
payload[:stream] = true if block_given?
|
||||
|
||||
Net::HTTP.start(
|
||||
url.host,
|
||||
url.port,
|
||||
use_ssl: url.scheme == "https",
|
||||
read_timeout: TIMEOUT,
|
||||
open_timeout: TIMEOUT,
|
||||
write_timeout: TIMEOUT,
|
||||
) do |http|
|
||||
request = Net::HTTP::Post.new(url, headers)
|
||||
request_body = payload.to_json
|
||||
request.body = request_body
|
||||
|
||||
http.request(request) do |response|
|
||||
if response.code.to_i != 200
|
||||
Rails.logger.error(
|
||||
"HuggingFaceTextGeneration: status: #{response.code.to_i} - body: #{response.body}",
|
||||
)
|
||||
raise CompletionFailed
|
||||
end
|
||||
|
||||
log =
|
||||
AiApiAuditLog.create!(
|
||||
provider_id: AiApiAuditLog::Provider::HuggingFaceTextGeneration,
|
||||
raw_request_payload: request_body,
|
||||
user_id: user_id,
|
||||
)
|
||||
|
||||
if !block_given?
|
||||
response_body = response.read_body
|
||||
parsed_response = JSON.parse(response_body, symbolize_names: true)
|
||||
|
||||
log.update!(
|
||||
raw_response_payload: response_body,
|
||||
request_tokens: tokenizer.size(prompt),
|
||||
response_tokens: tokenizer.size(parsed_response.first[:generated_text]),
|
||||
)
|
||||
return parsed_response
|
||||
end
|
||||
|
||||
response_data = +""
|
||||
|
||||
begin
|
||||
cancelled = false
|
||||
cancel = lambda { cancelled = true }
|
||||
response_raw = +""
|
||||
|
||||
response.read_body do |chunk|
|
||||
if cancelled
|
||||
http.finish
|
||||
return
|
||||
end
|
||||
|
||||
response_raw << chunk
|
||||
|
||||
chunk
|
||||
.split("\n")
|
||||
.each do |line|
|
||||
data = line.split("data:", 2)[1]
|
||||
next if !data || data.squish == "[DONE]"
|
||||
|
||||
if !cancelled
|
||||
begin
|
||||
# partial contains the entire payload till now
|
||||
partial = JSON.parse(data, symbolize_names: true)
|
||||
|
||||
# this is the last chunk and contains the full response
|
||||
next if partial[:token][:special] == true
|
||||
|
||||
response_data << partial[:token][:text].to_s
|
||||
|
||||
yield partial, cancel
|
||||
rescue JSON::ParserError
|
||||
nil
|
||||
end
|
||||
end
|
||||
end
|
||||
rescue IOError
|
||||
raise if !cancelled
|
||||
ensure
|
||||
log.update!(
|
||||
raw_response_payload: response_raw,
|
||||
request_tokens: tokenizer.size(prompt),
|
||||
response_tokens: tokenizer.size(response_data),
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
return response_data
|
||||
end
|
||||
end
|
||||
|
||||
def self.try_parse(data)
|
||||
JSON.parse(data, symbolize_names: true)
|
||||
rescue JSON::ParserError
|
||||
nil
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,194 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module ::DiscourseAi
|
||||
module Inference
|
||||
class OpenAiCompletions
|
||||
TIMEOUT = 60
|
||||
DEFAULT_RETRIES = 3
|
||||
DEFAULT_RETRY_TIMEOUT_SECONDS = 3
|
||||
RETRY_TIMEOUT_BACKOFF_MULTIPLIER = 3
|
||||
|
||||
CompletionFailed = Class.new(StandardError)
|
||||
|
||||
def self.perform!(
|
||||
messages,
|
||||
model,
|
||||
temperature: nil,
|
||||
top_p: nil,
|
||||
max_tokens: nil,
|
||||
functions: nil,
|
||||
user_id: nil,
|
||||
retries: DEFAULT_RETRIES,
|
||||
retry_timeout: DEFAULT_RETRY_TIMEOUT_SECONDS,
|
||||
post: nil,
|
||||
&blk
|
||||
)
|
||||
log = nil
|
||||
response_data = +""
|
||||
response_raw = +""
|
||||
|
||||
url =
|
||||
if model.include?("gpt-4")
|
||||
if model.include?("turbo") || model.include?("1106-preview")
|
||||
URI(SiteSetting.ai_openai_gpt4_turbo_url)
|
||||
elsif model.include?("32k")
|
||||
URI(SiteSetting.ai_openai_gpt4_32k_url)
|
||||
else
|
||||
URI(SiteSetting.ai_openai_gpt4_url)
|
||||
end
|
||||
else
|
||||
if model.include?("16k")
|
||||
URI(SiteSetting.ai_openai_gpt35_16k_url)
|
||||
else
|
||||
URI(SiteSetting.ai_openai_gpt35_url)
|
||||
end
|
||||
end
|
||||
headers = { "Content-Type" => "application/json" }
|
||||
|
||||
if url.host.include?("azure")
|
||||
headers["api-key"] = SiteSetting.ai_openai_api_key
|
||||
else
|
||||
headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
|
||||
end
|
||||
|
||||
if SiteSetting.ai_openai_organization.present?
|
||||
headers["OpenAI-Organization"] = SiteSetting.ai_openai_organization
|
||||
end
|
||||
|
||||
payload = { model: model, messages: messages }
|
||||
|
||||
payload[:temperature] = temperature if temperature
|
||||
payload[:top_p] = top_p if top_p
|
||||
payload[:max_tokens] = max_tokens if max_tokens
|
||||
payload[:functions] = functions if functions
|
||||
payload[:stream] = true if block_given?
|
||||
|
||||
Net::HTTP.start(
|
||||
url.host,
|
||||
url.port,
|
||||
use_ssl: true,
|
||||
read_timeout: TIMEOUT,
|
||||
open_timeout: TIMEOUT,
|
||||
write_timeout: TIMEOUT,
|
||||
) do |http|
|
||||
request = Net::HTTP::Post.new(url, headers)
|
||||
request_body = payload.to_json
|
||||
request.body = request_body
|
||||
|
||||
http.request(request) do |response|
|
||||
if retries > 0 && response.code.to_i == 429
|
||||
sleep(retry_timeout)
|
||||
retries -= 1
|
||||
retry_timeout *= RETRY_TIMEOUT_BACKOFF_MULTIPLIER
|
||||
return(
|
||||
perform!(
|
||||
messages,
|
||||
model,
|
||||
temperature: temperature,
|
||||
top_p: top_p,
|
||||
max_tokens: max_tokens,
|
||||
functions: functions,
|
||||
user_id: user_id,
|
||||
retries: retries,
|
||||
retry_timeout: retry_timeout,
|
||||
&blk
|
||||
)
|
||||
)
|
||||
elsif response.code.to_i != 200
|
||||
Rails.logger.error(
|
||||
"OpenAiCompletions: status: #{response.code.to_i} - body: #{response.body}",
|
||||
)
|
||||
raise CompletionFailed, "status: #{response.code.to_i} - body: #{response.body}"
|
||||
end
|
||||
|
||||
log =
|
||||
AiApiAuditLog.create!(
|
||||
provider_id: AiApiAuditLog::Provider::OpenAI,
|
||||
raw_request_payload: request_body,
|
||||
user_id: user_id,
|
||||
post_id: post&.id,
|
||||
topic_id: post&.topic_id,
|
||||
)
|
||||
|
||||
if !blk
|
||||
response_body = response.read_body
|
||||
parsed_response = JSON.parse(response_body, symbolize_names: true)
|
||||
|
||||
log.update!(
|
||||
raw_response_payload: response_body,
|
||||
request_tokens: parsed_response.dig(:usage, :prompt_tokens),
|
||||
response_tokens: parsed_response.dig(:usage, :completion_tokens),
|
||||
)
|
||||
return parsed_response
|
||||
end
|
||||
|
||||
begin
|
||||
cancelled = false
|
||||
cancel = lambda { cancelled = true }
|
||||
|
||||
leftover = ""
|
||||
|
||||
response.read_body do |chunk|
|
||||
if cancelled
|
||||
http.finish
|
||||
break
|
||||
end
|
||||
|
||||
response_raw << chunk
|
||||
|
||||
if (leftover + chunk).length < "data: [DONE]".length
|
||||
leftover += chunk
|
||||
next
|
||||
end
|
||||
|
||||
(leftover + chunk)
|
||||
.split("\n")
|
||||
.each do |line|
|
||||
data = line.split("data: ", 2)[1]
|
||||
next if !data || data == "[DONE]"
|
||||
next if cancelled
|
||||
|
||||
partial = nil
|
||||
begin
|
||||
partial = JSON.parse(data, symbolize_names: true)
|
||||
leftover = ""
|
||||
rescue JSON::ParserError
|
||||
leftover = line
|
||||
end
|
||||
|
||||
if partial
|
||||
response_data << partial.dig(:choices, 0, :delta, :content).to_s
|
||||
response_data << partial.dig(:choices, 0, :delta, :function_call).to_s
|
||||
|
||||
blk.call(partial, cancel)
|
||||
end
|
||||
end
|
||||
rescue IOError
|
||||
raise if !cancelled
|
||||
end
|
||||
end
|
||||
|
||||
return response_data
|
||||
end
|
||||
end
|
||||
ensure
|
||||
if log && block_given?
|
||||
request_tokens = DiscourseAi::Tokenizer::OpenAiTokenizer.size(extract_prompt(messages))
|
||||
response_tokens = DiscourseAi::Tokenizer::OpenAiTokenizer.size(response_data)
|
||||
log.update!(
|
||||
raw_response_payload: response_raw,
|
||||
request_tokens: request_tokens,
|
||||
response_tokens: response_tokens,
|
||||
)
|
||||
end
|
||||
if log && Rails.env.development?
|
||||
puts "OpenAiCompletions: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}"
|
||||
end
|
||||
end
|
||||
|
||||
def self.extract_prompt(messages)
|
||||
messages.map { |message| message[:content] || message["content"] || "" }.join("\n")
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -58,9 +58,7 @@ after_initialize do
|
|||
end
|
||||
|
||||
if Rails.env.test?
|
||||
require_relative "spec/support/openai_completions_inference_stubs"
|
||||
require_relative "spec/support/anthropic_completion_stubs"
|
||||
require_relative "spec/support/stable_diffusion_stubs"
|
||||
require_relative "spec/support/embeddings_generation_stubs"
|
||||
require_relative "spec/support/stable_diffusion_stubs"
|
||||
end
|
||||
end
|
||||
|
|
|
@ -64,7 +64,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
|||
.to_return(status: 200, body: chunks)
|
||||
end
|
||||
|
||||
let(:tool_deltas) { ["<function", <<~REPLY] }
|
||||
let(:tool_deltas) { ["Let me use a tool for that<function", <<~REPLY] }
|
||||
_calls>
|
||||
<invoke>
|
||||
<tool_name>get_weather</tool_name>
|
||||
|
|
|
@ -1,183 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module ::DiscourseAi
|
||||
module AiBot
|
||||
describe AnthropicBot do
|
||||
def bot_user
|
||||
User.find(EntryPoint::CLAUDE_V2_ID)
|
||||
end
|
||||
|
||||
before do
|
||||
SiteSetting.ai_bot_enabled_chat_bots = "claude-2"
|
||||
SiteSetting.ai_bot_enabled = true
|
||||
end
|
||||
|
||||
let(:bot) { described_class.new(bot_user) }
|
||||
fab!(:post)
|
||||
|
||||
describe "system message" do
|
||||
it "includes the full command framework" do
|
||||
prompt = bot.system_prompt(post, allow_commands: true)
|
||||
|
||||
expect(prompt).to include("read")
|
||||
expect(prompt).to include("search_query")
|
||||
end
|
||||
end
|
||||
|
||||
it "does not include half parsed function calls in reply" do
|
||||
completion1 = "<function"
|
||||
completion2 = <<~REPLY
|
||||
_calls>
|
||||
<invoke>
|
||||
<tool_name>search</tool_name>
|
||||
<parameters>
|
||||
<search_query>hello world</search_query>
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
junk
|
||||
REPLY
|
||||
|
||||
completion1 = { completion: completion1 }.to_json
|
||||
completion2 = { completion: completion2 }.to_json
|
||||
|
||||
completion3 = { completion: "<func" }.to_json
|
||||
|
||||
request_number = 0
|
||||
|
||||
last_body = nil
|
||||
|
||||
stub_request(:post, "https://api.anthropic.com/v1/complete").with(
|
||||
body:
|
||||
lambda do |body|
|
||||
last_body = body
|
||||
request_number == 2
|
||||
end,
|
||||
).to_return(status: 200, body: lambda { |request| +"data: #{completion3}" })
|
||||
|
||||
stub_request(:post, "https://api.anthropic.com/v1/complete").with(
|
||||
body:
|
||||
lambda do |body|
|
||||
request_number += 1
|
||||
request_number == 1
|
||||
end,
|
||||
).to_return(
|
||||
status: 200,
|
||||
body: lambda { |request| +"data: #{completion1}\ndata: #{completion2}" },
|
||||
)
|
||||
|
||||
bot.reply_to(post)
|
||||
|
||||
post.topic.reload
|
||||
|
||||
raw = post.topic.ordered_posts.last.raw
|
||||
|
||||
prompt = JSON.parse(last_body)["prompt"]
|
||||
|
||||
# function call is bundled into Assitant prompt
|
||||
expect(prompt.split("Human:").length).to eq(2)
|
||||
|
||||
# this should be stripped
|
||||
expect(prompt).not_to include("junk")
|
||||
|
||||
expect(raw).to end_with("<func")
|
||||
|
||||
# leading <function_call> should be stripped
|
||||
expect(raw).to start_with("\n\n<details")
|
||||
end
|
||||
|
||||
it "does not include Assistant: in front of the system prompt" do
|
||||
prompt = nil
|
||||
|
||||
stub_request(:post, "https://api.anthropic.com/v1/complete").with(
|
||||
body:
|
||||
lambda do |body|
|
||||
json = JSON.parse(body)
|
||||
prompt = json["prompt"]
|
||||
true
|
||||
end,
|
||||
).to_return(
|
||||
status: 200,
|
||||
body: lambda { |request| +"data: " << { completion: "Hello World" }.to_json },
|
||||
)
|
||||
|
||||
bot.reply_to(post)
|
||||
|
||||
expect(prompt).not_to be_nil
|
||||
expect(prompt).not_to start_with("Assistant:")
|
||||
end
|
||||
|
||||
describe "parsing a reply prompt" do
|
||||
it "can correctly predict that a completion needs to be cancelled" do
|
||||
functions = DiscourseAi::AiBot::Bot::FunctionCalls.new
|
||||
|
||||
# note anthropic API has a silly leading space, we need to make sure we can handle that
|
||||
prompt = +<<~REPLY.strip
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>search</tool_name>
|
||||
<parameters>
|
||||
<search_query>hello world</search_query>
|
||||
<random_stuff>77</random_stuff>
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls
|
||||
REPLY
|
||||
|
||||
bot.populate_functions(
|
||||
partial: nil,
|
||||
reply: prompt,
|
||||
functions: functions,
|
||||
done: false,
|
||||
current_delta: "",
|
||||
)
|
||||
|
||||
expect(functions.found?).to eq(true)
|
||||
expect(functions.cancel_completion?).to eq(false)
|
||||
|
||||
prompt << ">"
|
||||
|
||||
bot.populate_functions(
|
||||
partial: nil,
|
||||
reply: prompt,
|
||||
functions: functions,
|
||||
done: true,
|
||||
current_delta: "",
|
||||
)
|
||||
|
||||
expect(functions.found?).to eq(true)
|
||||
|
||||
expect(functions.to_a.length).to eq(1)
|
||||
|
||||
expect(functions.to_a).to eq(
|
||||
[{ name: "search", arguments: "{\"search_query\":\"hello world\"}" }],
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
describe "#update_with_delta" do
|
||||
describe "get_delta" do
|
||||
it "can properly remove first leading space" do
|
||||
context = {}
|
||||
reply = +""
|
||||
|
||||
reply << bot.get_delta({ completion: " Hello" }, context)
|
||||
reply << bot.get_delta({ completion: " World" }, context)
|
||||
expect(reply).to eq("Hello World")
|
||||
end
|
||||
|
||||
it "can properly remove Assistant prefix" do
|
||||
context = {}
|
||||
reply = +""
|
||||
|
||||
reply << bot.get_delta({ completion: "Hello " }, context)
|
||||
expect(reply).to eq("Hello ")
|
||||
|
||||
reply << bot.get_delta({ completion: "world" }, context)
|
||||
expect(reply).to eq("Hello world")
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,116 +1,16 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
class FakeBot < DiscourseAi::AiBot::Bot
|
||||
class Tokenizer
|
||||
def tokenize(text)
|
||||
text.split(" ")
|
||||
end
|
||||
end
|
||||
RSpec.describe DiscourseAi::AiBot::Bot do
|
||||
subject(:bot) { described_class.as(bot_user) }
|
||||
|
||||
def tokenizer
|
||||
Tokenizer.new
|
||||
end
|
||||
|
||||
def prompt_limit(allow_commands: false)
|
||||
10_000
|
||||
end
|
||||
|
||||
def build_message(poster_username, content, system: false, function: nil)
|
||||
role = poster_username == bot_user.username ? "Assistant" : "Human"
|
||||
|
||||
"#{role}: #{content}"
|
||||
end
|
||||
|
||||
def submit_prompt(prompt, post: nil, prefer_low_cost: false)
|
||||
rows = @responses.shift
|
||||
rows.each { |data| yield data, lambda {} }
|
||||
end
|
||||
|
||||
def get_delta(partial, context)
|
||||
partial
|
||||
end
|
||||
|
||||
def add_response(response)
|
||||
@responses ||= []
|
||||
@responses << response
|
||||
end
|
||||
end
|
||||
|
||||
describe FakeBot do
|
||||
before do
|
||||
SiteSetting.ai_bot_enabled_chat_bots = "gpt-4"
|
||||
SiteSetting.ai_bot_enabled = true
|
||||
end
|
||||
|
||||
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT4_ID) }
|
||||
fab!(:post) { Fabricate(:post, raw: "hello world") }
|
||||
|
||||
it "can handle command truncation for long messages" do
|
||||
bot = FakeBot.new(bot_user)
|
||||
|
||||
tags_command = <<~TEXT
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>tags</tool_name>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
TEXT
|
||||
|
||||
bot.add_response(["hello this is a big test I am testing 123\n", "#{tags_command}\nabc"])
|
||||
bot.add_response(["this is the reply"])
|
||||
|
||||
bot.reply_to(post)
|
||||
|
||||
reply = post.topic.posts.order(:post_number).last
|
||||
|
||||
expect(reply.raw).not_to include("abc")
|
||||
expect(reply.post_custom_prompt.custom_prompt.to_s).not_to include("abc")
|
||||
expect(reply.post_custom_prompt.custom_prompt.length).to eq(3)
|
||||
expect(reply.post_custom_prompt.custom_prompt[0][0]).to eq(
|
||||
"hello this is a big test I am testing 123\n#{tags_command.strip}",
|
||||
)
|
||||
end
|
||||
|
||||
it "can handle command truncation for short bot messages" do
|
||||
bot = FakeBot.new(bot_user)
|
||||
|
||||
tags_command = <<~TEXT
|
||||
_calls>
|
||||
<invoke>
|
||||
<tool_name>tags</tool_name>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
TEXT
|
||||
|
||||
bot.add_response(["hello\n<function", "#{tags_command}\nabc"])
|
||||
bot.add_response(["this is the reply"])
|
||||
|
||||
bot.reply_to(post)
|
||||
|
||||
reply = post.topic.posts.order(:post_number).last
|
||||
|
||||
expect(reply.raw).not_to include("abc")
|
||||
expect(reply.post_custom_prompt.custom_prompt.to_s).not_to include("abc")
|
||||
expect(reply.post_custom_prompt.custom_prompt.length).to eq(3)
|
||||
expect(reply.post_custom_prompt.custom_prompt[0][0]).to eq(
|
||||
"hello\n<function#{tags_command.strip}",
|
||||
)
|
||||
|
||||
# we don't want function leftovers
|
||||
expect(reply.raw).to start_with("hello\n\n<details>")
|
||||
end
|
||||
end
|
||||
|
||||
describe DiscourseAi::AiBot::Bot do
|
||||
before do
|
||||
SiteSetting.ai_bot_enabled_chat_bots = "gpt-4"
|
||||
SiteSetting.ai_bot_enabled = true
|
||||
end
|
||||
|
||||
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT4_ID) }
|
||||
let(:bot) { described_class.as(bot_user) }
|
||||
|
||||
fab!(:user) { Fabricate(:user) }
|
||||
let!(:user) { Fabricate(:user) }
|
||||
let!(:pm) do
|
||||
Fabricate(
|
||||
:private_message_topic,
|
||||
|
@ -122,101 +22,39 @@ describe DiscourseAi::AiBot::Bot do
|
|||
],
|
||||
)
|
||||
end
|
||||
let!(:first_post) { Fabricate(:post, topic: pm, user: user, raw: "This is a reply by the user") }
|
||||
let!(:second_post) do
|
||||
Fabricate(:post, topic: pm, user: user, raw: "This is a second reply by the user")
|
||||
end
|
||||
let!(:pm_post) { Fabricate(:post, topic: pm, user: user, raw: "Does my site has tags?") }
|
||||
|
||||
describe "#system_prompt" do
|
||||
it "includes relevant context in system prompt" do
|
||||
bot.system_prompt_style!(:standard)
|
||||
let(:function_call) { <<~TEXT }
|
||||
Let me try using a function to get more info:<function_calls>
|
||||
<invoke>
|
||||
<tool_name>categories</tool_name>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
TEXT
|
||||
|
||||
SiteSetting.title = "My Forum"
|
||||
SiteSetting.site_description = "My Forum Description"
|
||||
let(:response) { "As expected, your forum has multiple tags" }
|
||||
|
||||
system_prompt = bot.system_prompt(second_post, allow_commands: true)
|
||||
let(:llm_responses) { [function_call, response] }
|
||||
|
||||
expect(system_prompt).to include(SiteSetting.title)
|
||||
expect(system_prompt).to include(SiteSetting.site_description)
|
||||
describe "#reply" do
|
||||
context "when using function chaining" do
|
||||
it "yields a loading placeholder while proceeds to invoke the command" do
|
||||
tool = DiscourseAi::AiBot::Tools::ListCategories.new({})
|
||||
partial_placeholder = +(<<~HTML)
|
||||
<details>
|
||||
<summary>#{tool.summary}</summary>
|
||||
<p></p>
|
||||
</details>
|
||||
HTML
|
||||
|
||||
expect(system_prompt).to include(user.username)
|
||||
context = {}
|
||||
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses(llm_responses) do
|
||||
bot.reply(context) do |_bot_reply_post, cancel, placeholder|
|
||||
expect(placeholder).to eq(partial_placeholder) if placeholder
|
||||
end
|
||||
end
|
||||
|
||||
describe "#reply_to" do
|
||||
it "can respond to a search command" do
|
||||
bot.system_prompt_style!(:simple)
|
||||
|
||||
expected_response = {
|
||||
function_call: {
|
||||
name: "search",
|
||||
arguments: { query: "test search" }.to_json,
|
||||
},
|
||||
}
|
||||
|
||||
prompt = bot.bot_prompt_with_topic_context(second_post, allow_commands: true)
|
||||
|
||||
req_opts = bot.reply_params.merge({ functions: bot.available_functions, stream: true })
|
||||
|
||||
OpenAiCompletionsInferenceStubs.stub_streamed_response(
|
||||
prompt,
|
||||
[expected_response],
|
||||
model: bot.model_for,
|
||||
req_opts: req_opts,
|
||||
)
|
||||
|
||||
result =
|
||||
DiscourseAi::AiBot::Commands::SearchCommand
|
||||
.new(bot: nil, args: nil)
|
||||
.process(query: "test search")
|
||||
.to_json
|
||||
|
||||
prompt << { role: "function", content: result, name: "search" }
|
||||
|
||||
OpenAiCompletionsInferenceStubs.stub_streamed_response(
|
||||
prompt,
|
||||
[content: "I found nothing, sorry"],
|
||||
model: bot.model_for,
|
||||
req_opts: req_opts,
|
||||
)
|
||||
|
||||
bot.reply_to(second_post)
|
||||
|
||||
last = second_post.topic.posts.order("id desc").first
|
||||
|
||||
expect(last.raw).to include("<details>")
|
||||
expect(last.raw).to include("<summary>Search</summary>")
|
||||
expect(last.raw).not_to include("translation missing")
|
||||
expect(last.raw).to include("I found nothing")
|
||||
|
||||
expect(last.post_custom_prompt.custom_prompt).to eq(
|
||||
[[result, "search", "function"], ["I found nothing, sorry", bot_user.username]],
|
||||
)
|
||||
log = AiApiAuditLog.find_by(post_id: second_post.id)
|
||||
expect(log).to be_present
|
||||
end
|
||||
end
|
||||
|
||||
describe "#update_pm_title" do
|
||||
let(:expected_response) { "This is a suggested title" }
|
||||
|
||||
before { SiteSetting.min_personal_message_post_length = 5 }
|
||||
|
||||
it "updates the title using bot suggestions" do
|
||||
OpenAiCompletionsInferenceStubs.stub_response(
|
||||
bot.title_prompt(second_post),
|
||||
expected_response,
|
||||
model: bot.model_for,
|
||||
req_opts: {
|
||||
temperature: 0.7,
|
||||
top_p: 0.9,
|
||||
max_tokens: 40,
|
||||
},
|
||||
)
|
||||
|
||||
bot.update_pm_title(second_post)
|
||||
|
||||
expect(pm.reload.title).to eq(expected_response)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -1,13 +0,0 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Commands::CategoriesCommand do
|
||||
describe "#generate_categories_info" do
|
||||
it "can generate correct info" do
|
||||
Fabricate(:category, name: "america", posts_year: 999)
|
||||
|
||||
info = DiscourseAi::AiBot::Commands::CategoriesCommand.new(bot: nil, args: nil).process
|
||||
expect(info.to_s).to include("america")
|
||||
expect(info.to_s).to include("999")
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,36 +0,0 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Commands::Command do
|
||||
let(:command) { DiscourseAi::AiBot::Commands::GoogleCommand.new(bot: nil, args: nil) }
|
||||
|
||||
before { SiteSetting.ai_bot_enabled = true }
|
||||
|
||||
describe "#format_results" do
|
||||
it "can generate efficient tables of data" do
|
||||
rows = [1, 2, 3, 4, 5]
|
||||
column_names = %w[first second third]
|
||||
|
||||
formatted =
|
||||
command.format_results(rows, column_names) { |row| ["row ¦ 1", row + 1, "a|b,\nc"] }
|
||||
|
||||
expect(formatted[:column_names].length).to eq(3)
|
||||
expect(formatted[:rows].length).to eq(5)
|
||||
expect(formatted.to_s).to include("a|b,\\nc")
|
||||
end
|
||||
|
||||
it "can also generate results by returning hash per row" do
|
||||
rows = [1, 2, 3, 4, 5]
|
||||
column_names = %w[first second third]
|
||||
|
||||
formatted =
|
||||
command.format_results(rows, column_names) { |row| ["row ¦ 1", row + 1, "a|b,\nc"] }
|
||||
|
||||
formatted2 =
|
||||
command.format_results(rows) do |row|
|
||||
{ first: "row ¦ 1", second: row + 1, third: "a|b,\nc" }
|
||||
end
|
||||
|
||||
expect(formatted).to eq(formatted2)
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,207 +0,0 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Commands::SearchCommand do
|
||||
before { SearchIndexer.enable }
|
||||
after { SearchIndexer.disable }
|
||||
|
||||
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
||||
fab!(:admin)
|
||||
fab!(:parent_category) { Fabricate(:category, name: "animals") }
|
||||
fab!(:category) { Fabricate(:category, parent_category: parent_category, name: "amazing-cat") }
|
||||
fab!(:tag_funny) { Fabricate(:tag, name: "funny") }
|
||||
fab!(:tag_sad) { Fabricate(:tag, name: "sad") }
|
||||
fab!(:tag_hidden) { Fabricate(:tag, name: "hidden") }
|
||||
fab!(:staff_tag_group) do
|
||||
tag_group = Fabricate.build(:tag_group, name: "Staff only", tag_names: ["hidden"])
|
||||
|
||||
tag_group.permissions = [
|
||||
[Group::AUTO_GROUPS[:staff], TagGroupPermission.permission_types[:full]],
|
||||
]
|
||||
tag_group.save!
|
||||
tag_group
|
||||
end
|
||||
fab!(:topic_with_tags) do
|
||||
Fabricate(:topic, category: category, tags: [tag_funny, tag_sad, tag_hidden])
|
||||
end
|
||||
|
||||
before { SiteSetting.ai_bot_enabled = true }
|
||||
|
||||
it "can properly list options" do
|
||||
options = described_class.options.sort_by(&:name)
|
||||
expect(options.length).to eq(2)
|
||||
expect(options.first.name.to_s).to eq("base_query")
|
||||
expect(options.first.localized_name).not_to include("Translation missing:")
|
||||
expect(options.first.localized_description).not_to include("Translation missing:")
|
||||
|
||||
expect(options.second.name.to_s).to eq("max_results")
|
||||
expect(options.second.localized_name).not_to include("Translation missing:")
|
||||
expect(options.second.localized_description).not_to include("Translation missing:")
|
||||
end
|
||||
|
||||
describe "#process" do
|
||||
it "can retreive options from persona correctly" do
|
||||
persona =
|
||||
Fabricate(
|
||||
:ai_persona,
|
||||
allowed_group_ids: [Group::AUTO_GROUPS[:admins]],
|
||||
commands: [["SearchCommand", { "base_query" => "#funny" }]],
|
||||
)
|
||||
Group.refresh_automatic_groups!
|
||||
|
||||
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona_id: persona.id, user: admin)
|
||||
search_post = Fabricate(:post, topic: topic_with_tags)
|
||||
|
||||
bot_post = Fabricate(:post)
|
||||
|
||||
search = described_class.new(bot: bot, post: bot_post, args: nil)
|
||||
|
||||
results = search.process(order: "latest")
|
||||
expect(results[:rows].length).to eq(1)
|
||||
|
||||
search_post.topic.tags = []
|
||||
search_post.topic.save!
|
||||
|
||||
# no longer has the tag funny
|
||||
results = search.process(order: "latest")
|
||||
expect(results[:rows].length).to eq(0)
|
||||
end
|
||||
|
||||
it "can handle no results" do
|
||||
post1 = Fabricate(:post, topic: topic_with_tags)
|
||||
search = described_class.new(bot: nil, post: post1, args: nil)
|
||||
|
||||
results = search.process(query: "order:fake ABDDCDCEDGDG")
|
||||
|
||||
expect(results[:args]).to eq({ query: "order:fake ABDDCDCEDGDG" })
|
||||
expect(results[:rows]).to eq([])
|
||||
end
|
||||
|
||||
describe "semantic search" do
|
||||
let (:query) {
|
||||
"this is an expanded search"
|
||||
}
|
||||
after { DiscourseAi::Embeddings::SemanticSearch.clear_cache_for(query) }
|
||||
|
||||
it "supports semantic search when enabled" do
|
||||
SiteSetting.ai_embeddings_semantic_search_enabled = true
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||
|
||||
WebMock.stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
|
||||
status: 200,
|
||||
body: JSON.dump(OpenAiCompletionsInferenceStubs.response(query)),
|
||||
)
|
||||
|
||||
hyde_embedding = [0.049382, 0.9999]
|
||||
EmbeddingsGenerationStubs.discourse_service(
|
||||
SiteSetting.ai_embeddings_model,
|
||||
query,
|
||||
hyde_embedding,
|
||||
)
|
||||
|
||||
post1 = Fabricate(:post, topic: topic_with_tags)
|
||||
search = described_class.new(bot: nil, post: post1, args: nil)
|
||||
|
||||
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2
|
||||
.any_instance
|
||||
.expects(:asymmetric_topics_similarity_search)
|
||||
.returns([post1.topic_id])
|
||||
|
||||
results =
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses(["<ai>#{query}</ai>"]) do
|
||||
search.process(search_query: "hello world, sam", status: "public")
|
||||
end
|
||||
|
||||
expect(results[:args]).to eq({ search_query: "hello world, sam", status: "public" })
|
||||
expect(results[:rows].length).to eq(1)
|
||||
end
|
||||
end
|
||||
|
||||
it "supports subfolder properly" do
|
||||
Discourse.stubs(:base_path).returns("/subfolder")
|
||||
|
||||
post1 = Fabricate(:post, topic: topic_with_tags)
|
||||
|
||||
search = described_class.new(bot: nil, post: post1, args: nil)
|
||||
|
||||
results = search.process(limit: 1, user: post1.user.username)
|
||||
expect(results[:rows].to_s).to include("/subfolder" + post1.url)
|
||||
end
|
||||
|
||||
it "returns rich topic information" do
|
||||
post1 = Fabricate(:post, like_count: 1, topic: topic_with_tags)
|
||||
search = described_class.new(bot: nil, post: post1, args: nil)
|
||||
post1.topic.update!(views: 100, posts_count: 2, like_count: 10)
|
||||
|
||||
results = search.process(user: post1.user.username)
|
||||
|
||||
row = results[:rows].first
|
||||
category = row[results[:column_names].index("category")]
|
||||
|
||||
expect(category).to eq("animals > amazing-cat")
|
||||
|
||||
tags = row[results[:column_names].index("tags")]
|
||||
expect(tags).to eq("funny, sad")
|
||||
|
||||
likes = row[results[:column_names].index("likes")]
|
||||
expect(likes).to eq(1)
|
||||
|
||||
username = row[results[:column_names].index("username")]
|
||||
expect(username).to eq(post1.user.username)
|
||||
|
||||
likes = row[results[:column_names].index("topic_likes")]
|
||||
expect(likes).to eq(10)
|
||||
|
||||
views = row[results[:column_names].index("topic_views")]
|
||||
expect(views).to eq(100)
|
||||
|
||||
replies = row[results[:column_names].index("topic_replies")]
|
||||
expect(replies).to eq(1)
|
||||
end
|
||||
|
||||
it "scales results to number of tokens" do
|
||||
SiteSetting.ai_bot_enabled_chat_bots = "gpt-3.5-turbo|gpt-4|claude-2"
|
||||
|
||||
post1 = Fabricate(:post)
|
||||
|
||||
gpt_3_5_turbo =
|
||||
DiscourseAi::AiBot::Bot.as(User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID))
|
||||
gpt4 = DiscourseAi::AiBot::Bot.as(User.find(DiscourseAi::AiBot::EntryPoint::GPT4_ID))
|
||||
claude = DiscourseAi::AiBot::Bot.as(User.find(DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID))
|
||||
|
||||
expect(described_class.new(bot: claude, post: post1, args: nil).max_results).to eq(60)
|
||||
expect(described_class.new(bot: gpt_3_5_turbo, post: post1, args: nil).max_results).to eq(40)
|
||||
expect(described_class.new(bot: gpt4, post: post1, args: nil).max_results).to eq(20)
|
||||
|
||||
persona =
|
||||
Fabricate(
|
||||
:ai_persona,
|
||||
commands: [["SearchCommand", { "max_results" => 6 }]],
|
||||
enabled: true,
|
||||
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
|
||||
)
|
||||
|
||||
Group.refresh_automatic_groups!
|
||||
|
||||
custom_bot = DiscourseAi::AiBot::Bot.as(bot_user, persona_id: persona.id, user: admin)
|
||||
|
||||
expect(described_class.new(bot: custom_bot, post: post1, args: nil).max_results).to eq(6)
|
||||
end
|
||||
|
||||
it "can handle limits" do
|
||||
post1 = Fabricate(:post, topic: topic_with_tags)
|
||||
_post2 = Fabricate(:post, user: post1.user)
|
||||
_post3 = Fabricate(:post, user: post1.user)
|
||||
|
||||
# search has no built in support for limit: so handle it from the outside
|
||||
search = described_class.new(bot: nil, post: post1, args: nil)
|
||||
|
||||
results = search.process(limit: 2, user: post1.user.username)
|
||||
|
||||
expect(results[:rows].length).to eq(2)
|
||||
|
||||
# just searching for everything
|
||||
results = search.process(order: "latest_topic")
|
||||
expect(results[:rows].length).to be > 1
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,29 +0,0 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Commands::SearchSettingsCommand do
|
||||
let(:search) { described_class.new(bot: nil, args: nil) }
|
||||
|
||||
describe "#process" do
|
||||
it "can handle no results" do
|
||||
results = search.process(query: "this will not exist frogs")
|
||||
expect(results[:args]).to eq({ query: "this will not exist frogs" })
|
||||
expect(results[:rows]).to eq([])
|
||||
end
|
||||
|
||||
it "can return more many settings with no descriptions if there are lots of hits" do
|
||||
results = search.process(query: "a")
|
||||
|
||||
expect(results[:rows].length).to be > 30
|
||||
expect(results[:rows][0].length).to eq(1)
|
||||
end
|
||||
|
||||
it "can return descriptions if there are few matches" do
|
||||
results =
|
||||
search.process(query: "this will not be found!@,default_locale,ai_bot_enabled_chat_bots")
|
||||
|
||||
expect(results[:rows].length).to eq(2)
|
||||
|
||||
expect(results[:rows][0][1]).not_to eq(nil)
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,42 +0,0 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Commands::SummarizeCommand do
|
||||
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
||||
let(:bot) { DiscourseAi::AiBot::OpenAiBot.new(bot_user) }
|
||||
|
||||
before { SiteSetting.ai_bot_enabled = true }
|
||||
|
||||
describe "#process" do
|
||||
it "can generate correct info" do
|
||||
post = Fabricate(:post)
|
||||
|
||||
WebMock.stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
|
||||
status: 200,
|
||||
body: JSON.dump({ choices: [{ message: { content: "summary stuff" } }] }),
|
||||
)
|
||||
|
||||
summarizer = described_class.new(bot: bot, args: nil, post: post)
|
||||
info = summarizer.process(topic_id: post.topic_id, guidance: "why did it happen?")
|
||||
|
||||
expect(info).to include("Topic summarized")
|
||||
expect(summarizer.custom_raw).to include("summary stuff")
|
||||
expect(summarizer.chain_next_response).to eq(false)
|
||||
end
|
||||
|
||||
it "protects hidden data" do
|
||||
category = Fabricate(:category)
|
||||
category.set_permissions({})
|
||||
category.save!
|
||||
|
||||
topic = Fabricate(:topic, category_id: category.id)
|
||||
post = Fabricate(:post, topic: topic)
|
||||
|
||||
summarizer = described_class.new(bot: bot, post: post, args: nil)
|
||||
info = summarizer.process(topic_id: post.topic_id, guidance: "why did it happen?")
|
||||
|
||||
expect(info).not_to include(post.raw)
|
||||
|
||||
expect(summarizer.custom_raw).to eq(I18n.t("discourse_ai.ai_bot.topic_not_found"))
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,17 +0,0 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Commands::TagsCommand do
|
||||
describe "#process" do
|
||||
it "can generate correct info" do
|
||||
SiteSetting.tagging_enabled = true
|
||||
|
||||
Fabricate(:tag, name: "america", public_topic_count: 100)
|
||||
Fabricate(:tag, name: "not_here", public_topic_count: 0)
|
||||
|
||||
info = DiscourseAi::AiBot::Commands::TagsCommand.new(bot: nil, args: nil).process
|
||||
|
||||
expect(info.to_s).to include("america")
|
||||
expect(info.to_s).not_to include("not_here")
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,11 +1,7 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe Jobs::CreateAiReply do
|
||||
before do
|
||||
# got to do this cause we include times in system message
|
||||
freeze_time
|
||||
SiteSetting.ai_bot_enabled = true
|
||||
end
|
||||
before { SiteSetting.ai_bot_enabled = true }
|
||||
|
||||
describe "#execute" do
|
||||
fab!(:topic) { Fabricate(:topic) }
|
||||
|
@ -17,95 +13,15 @@ RSpec.describe Jobs::CreateAiReply do
|
|||
|
||||
before { SiteSetting.min_personal_message_post_length = 5 }
|
||||
|
||||
context "when chatting with the Open AI bot" do
|
||||
let(:deltas) { expected_response.split(" ").map { |w| { content: "#{w} " } } }
|
||||
|
||||
before do
|
||||
bot_user = User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID)
|
||||
bot = DiscourseAi::AiBot::Bot.as(bot_user)
|
||||
|
||||
# time needs to be frozen so time in prompt does not drift
|
||||
freeze_time
|
||||
|
||||
OpenAiCompletionsInferenceStubs.stub_streamed_response(
|
||||
DiscourseAi::AiBot::OpenAiBot.new(bot_user).bot_prompt_with_topic_context(
|
||||
post,
|
||||
allow_commands: true,
|
||||
),
|
||||
deltas,
|
||||
model: bot.model_for,
|
||||
req_opts: {
|
||||
temperature: 0.4,
|
||||
top_p: 0.9,
|
||||
max_tokens: 2500,
|
||||
functions: bot.available_functions,
|
||||
stream: true,
|
||||
},
|
||||
)
|
||||
end
|
||||
|
||||
it "adds a reply from the GPT bot" do
|
||||
subject.execute(
|
||||
post_id: topic.first_post.id,
|
||||
bot_user_id: DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID,
|
||||
)
|
||||
|
||||
expect(topic.posts.last.raw).to eq(expected_response)
|
||||
end
|
||||
|
||||
it "streams the reply on the fly to the client through MB" do
|
||||
messages =
|
||||
MessageBus.track_publish("discourse-ai/ai-bot/topic/#{topic.id}") do
|
||||
it "adds a reply from the bot" do
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses([expected_response]) do
|
||||
subject.execute(
|
||||
post_id: topic.first_post.id,
|
||||
bot_user_id: DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID,
|
||||
)
|
||||
end
|
||||
|
||||
done_signal = messages.pop
|
||||
|
||||
expect(messages.length).to eq(deltas.length)
|
||||
|
||||
messages.each_with_index do |m, idx|
|
||||
expect(m.data[:raw]).to eq(deltas[0..(idx + 1)].map { |d| d[:content] }.join)
|
||||
end
|
||||
|
||||
expect(done_signal.data[:done]).to eq(true)
|
||||
end
|
||||
end
|
||||
|
||||
context "when chatting with Claude from Anthropic" do
|
||||
let(:claude_response) { "#{expected_response}" }
|
||||
let(:deltas) { claude_response.split(" ").map { |w| "#{w} " } }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_bot_enabled_chat_bots = "claude-2"
|
||||
bot_user = User.find(DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID)
|
||||
|
||||
AnthropicCompletionStubs.stub_streamed_response(
|
||||
DiscourseAi::AiBot::AnthropicBot.new(bot_user).bot_prompt_with_topic_context(
|
||||
post,
|
||||
allow_commands: true,
|
||||
),
|
||||
deltas,
|
||||
model: "claude-2",
|
||||
req_opts: {
|
||||
max_tokens_to_sample: 3000,
|
||||
temperature: 0.4,
|
||||
stream: true,
|
||||
stop_sequences: ["\n\nHuman:", "</function_calls>"],
|
||||
},
|
||||
)
|
||||
end
|
||||
|
||||
it "adds a reply from the Claude bot" do
|
||||
subject.execute(
|
||||
post_id: topic.first_post.id,
|
||||
bot_user_id: DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID,
|
||||
)
|
||||
|
||||
expect(topic.posts.last.raw).to eq(expected_response)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -12,23 +12,6 @@ RSpec.describe Jobs::UpdateAiBotPmTitle do
|
|||
it "will properly update title on bot PMs" do
|
||||
SiteSetting.ai_bot_allowed_groups = Group::AUTO_GROUPS[:staff]
|
||||
|
||||
Jobs.run_immediately!
|
||||
|
||||
WebMock
|
||||
.stub_request(:post, "https://api.anthropic.com/v1/complete")
|
||||
.with(body: /You are a helpful Discourse assistant/)
|
||||
.to_return(status: 200, body: "data: {\"completion\": \"Hello back at you\"}", headers: {})
|
||||
|
||||
WebMock
|
||||
.stub_request(:post, "https://api.anthropic.com/v1/complete")
|
||||
.with(body: /Suggest a 7 word title/)
|
||||
.to_return(
|
||||
status: 200,
|
||||
body: "{\"completion\": \"A great title would be:\n\nMy amazing title\n\n\"}",
|
||||
headers: {
|
||||
},
|
||||
)
|
||||
|
||||
post =
|
||||
create_post(
|
||||
user: user,
|
||||
|
@ -38,11 +21,20 @@ RSpec.describe Jobs::UpdateAiBotPmTitle do
|
|||
target_usernames: bot_user.username,
|
||||
)
|
||||
|
||||
title_result = "A great title would be:\n\nMy amazing title\n\n"
|
||||
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses([title_result]) do
|
||||
subject.execute(bot_user_id: bot_user.id, post_id: post.id)
|
||||
|
||||
expect(post.reload.topic.title).to eq("My amazing title")
|
||||
end
|
||||
|
||||
WebMock.reset!
|
||||
another_title = "I'm a different title"
|
||||
|
||||
Jobs::UpdateAiBotPmTitle.new.execute(bot_user_id: bot_user.id, post_id: post.id)
|
||||
# should be a no op cause title is updated
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses([another_title]) do
|
||||
subject.execute(bot_user_id: bot_user.id, post_id: post.id)
|
||||
|
||||
expect(post.reload.topic.title).to eq("My amazing title")
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -1,80 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::OpenAiBot do
|
||||
describe "#bot_prompt_with_topic_context" do
|
||||
fab!(:topic) { Fabricate(:topic) }
|
||||
|
||||
def post_body(post_number)
|
||||
"This is post #{post_number}"
|
||||
end
|
||||
|
||||
def bot_user
|
||||
User.find(DiscourseAi::AiBot::EntryPoint::GPT4_ID)
|
||||
end
|
||||
|
||||
subject { described_class.new(bot_user) }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_bot_enabled_chat_bots = "gpt-4"
|
||||
SiteSetting.ai_bot_enabled = true
|
||||
end
|
||||
|
||||
context "when cleaning usernames" do
|
||||
it "can properly clean usernames so OpenAI allows it" do
|
||||
expect(subject.clean_username("test test")).to eq("test_test")
|
||||
expect(subject.clean_username("test.test")).to eq("test_test")
|
||||
expect(subject.clean_username("test😀test")).to eq("test_test")
|
||||
end
|
||||
end
|
||||
|
||||
context "when the topic has one post" do
|
||||
fab!(:post_1) { Fabricate(:post, topic: topic, raw: post_body(1), post_number: 1) }
|
||||
|
||||
it "includes it in the prompt" do
|
||||
prompt_messages = subject.bot_prompt_with_topic_context(post_1, allow_commands: true)
|
||||
|
||||
post_1_message = prompt_messages[-1]
|
||||
|
||||
expect(post_1_message[:role]).to eq("user")
|
||||
expect(post_1_message[:content]).to eq(post_body(1))
|
||||
expect(post_1_message[:name]).to eq(post_1.user.username)
|
||||
end
|
||||
end
|
||||
|
||||
context "when prompt gets very long" do
|
||||
fab!(:post_1) { Fabricate(:post, topic: topic, raw: "test " * 6000, post_number: 1) }
|
||||
|
||||
it "trims the prompt" do
|
||||
prompt_messages = subject.bot_prompt_with_topic_context(post_1, allow_commands: true)
|
||||
|
||||
# trimming is tricky... it needs to account for system message as
|
||||
# well... just make sure we trim for now
|
||||
expect(prompt_messages[-1][:content].length).to be < post_1.raw.length
|
||||
end
|
||||
end
|
||||
|
||||
context "when the topic has multiple posts" do
|
||||
let!(:post_1) { Fabricate(:post, topic: topic, raw: post_body(1), post_number: 1) }
|
||||
let!(:post_2) do
|
||||
Fabricate(:post, topic: topic, user: bot_user, raw: post_body(2), post_number: 2)
|
||||
end
|
||||
let!(:post_3) { Fabricate(:post, topic: topic, raw: post_body(3), post_number: 3) }
|
||||
|
||||
it "includes them in the prompt respecting the post number order" do
|
||||
prompt_messages = subject.bot_prompt_with_topic_context(post_3, allow_commands: true)
|
||||
|
||||
# negative cause we may have grounding prompts
|
||||
expect(prompt_messages[-3][:role]).to eq("user")
|
||||
expect(prompt_messages[-3][:content]).to eq(post_body(1))
|
||||
expect(prompt_messages[-3][:name]).to eq(post_1.username)
|
||||
|
||||
expect(prompt_messages[-2][:role]).to eq("assistant")
|
||||
expect(prompt_messages[-2][:content]).to eq(post_body(2))
|
||||
|
||||
expect(prompt_messages[-1][:role]).to eq("user")
|
||||
expect(prompt_messages[-1][:content]).to eq(post_body(3))
|
||||
expect(prompt_messages[-1][:name]).to eq(post_3.username)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,11 +1,11 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
class TestPersona < DiscourseAi::AiBot::Personas::Persona
|
||||
def commands
|
||||
def tools
|
||||
[
|
||||
DiscourseAi::AiBot::Commands::TagsCommand,
|
||||
DiscourseAi::AiBot::Commands::SearchCommand,
|
||||
DiscourseAi::AiBot::Commands::ImageCommand,
|
||||
DiscourseAi::AiBot::Tools::ListTags,
|
||||
DiscourseAi::AiBot::Tools::Search,
|
||||
DiscourseAi::AiBot::Tools::Image,
|
||||
]
|
||||
end
|
||||
|
||||
|
@ -37,41 +37,36 @@ module DiscourseAi::AiBot::Personas
|
|||
AiPersona.persona_cache.flush!
|
||||
end
|
||||
|
||||
fab!(:user)
|
||||
|
||||
it "can disable commands" do
|
||||
persona = TestPersona.new
|
||||
|
||||
rendered = persona.render_system_prompt(topic: topic_with_users, allow_commands: false)
|
||||
|
||||
expect(rendered).not_to include("!tags")
|
||||
expect(rendered).not_to include("!search")
|
||||
let(:context) do
|
||||
{
|
||||
site_url: Discourse.base_url,
|
||||
site_title: "test site title",
|
||||
site_description: "test site description",
|
||||
time: Time.zone.now,
|
||||
participants: topic_with_users.allowed_users.map(&:username).join(", "),
|
||||
}
|
||||
end
|
||||
|
||||
fab!(:user)
|
||||
|
||||
it "renders the system prompt" do
|
||||
freeze_time
|
||||
|
||||
SiteSetting.title = "test site title"
|
||||
SiteSetting.site_description = "test site description"
|
||||
rendered = persona.craft_prompt(context)
|
||||
|
||||
rendered =
|
||||
persona.render_system_prompt(topic: topic_with_users, render_function_instructions: true)
|
||||
expect(rendered[:insts]).to include(Discourse.base_url)
|
||||
expect(rendered[:insts]).to include("test site title")
|
||||
expect(rendered[:insts]).to include("test site description")
|
||||
expect(rendered[:insts]).to include("joe, jane")
|
||||
expect(rendered[:insts]).to include(Time.zone.now.to_s)
|
||||
|
||||
tools = rendered[:tools]
|
||||
|
||||
expect(tools.find { |t| t[:name] == "search" }).to be_present
|
||||
expect(tools.find { |t| t[:name] == "tags" }).to be_present
|
||||
|
||||
expect(rendered).to include(Discourse.base_url)
|
||||
expect(rendered).to include("test site title")
|
||||
expect(rendered).to include("test site description")
|
||||
expect(rendered).to include("joe, jane")
|
||||
expect(rendered).to include(Time.zone.now.to_s)
|
||||
expect(rendered).to include("<tool_name>search</tool_name>")
|
||||
expect(rendered).to include("<tool_name>tags</tool_name>")
|
||||
# needs to be configured so it is not available
|
||||
expect(rendered).not_to include("<tool_name>image</tool_name>")
|
||||
|
||||
rendered =
|
||||
persona.render_system_prompt(topic: topic_with_users, render_function_instructions: false)
|
||||
|
||||
expect(rendered).not_to include("<tool_name>search</tool_name>")
|
||||
expect(rendered).not_to include("<tool_name>tags</tool_name>")
|
||||
expect(tools.find { |t| t[:name] == "image" }).to be_nil
|
||||
end
|
||||
|
||||
describe "custom personas" do
|
||||
|
@ -88,31 +83,29 @@ module DiscourseAi::AiBot::Personas
|
|||
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
|
||||
)
|
||||
|
||||
custom_persona = DiscourseAi::AiBot::Personas.all(user: user).last
|
||||
custom_persona = DiscourseAi::AiBot::Personas::Persona.all(user: user).last
|
||||
expect(custom_persona.name).to eq("zzzpun_bot")
|
||||
expect(custom_persona.description).to eq("you write puns")
|
||||
|
||||
instance = custom_persona.new
|
||||
expect(instance.commands).to eq([DiscourseAi::AiBot::Commands::ImageCommand])
|
||||
expect(instance.render_system_prompt(render_function_instructions: true)).to eq(
|
||||
"you are pun bot",
|
||||
)
|
||||
expect(instance.tools).to eq([DiscourseAi::AiBot::Tools::Image])
|
||||
expect(instance.craft_prompt(context).dig(:insts)).to eq("you are pun bot\n\n")
|
||||
|
||||
# should update
|
||||
persona.update!(name: "zzzpun_bot2")
|
||||
custom_persona = DiscourseAi::AiBot::Personas.all(user: user).last
|
||||
custom_persona = DiscourseAi::AiBot::Personas::Persona.all(user: user).last
|
||||
expect(custom_persona.name).to eq("zzzpun_bot2")
|
||||
|
||||
# can be disabled
|
||||
persona.update!(enabled: false)
|
||||
last_persona = DiscourseAi::AiBot::Personas.all(user: user).last
|
||||
last_persona = DiscourseAi::AiBot::Personas::Persona.all(user: user).last
|
||||
expect(last_persona.name).not_to eq("zzzpun_bot2")
|
||||
|
||||
persona.update!(enabled: true)
|
||||
# no groups have access
|
||||
persona.update!(allowed_group_ids: [])
|
||||
|
||||
last_persona = DiscourseAi::AiBot::Personas.all(user: user).last
|
||||
last_persona = DiscourseAi::AiBot::Personas::Persona.all(user: user).last
|
||||
expect(last_persona.name).not_to eq("zzzpun_bot2")
|
||||
end
|
||||
end
|
||||
|
@ -127,7 +120,7 @@ module DiscourseAi::AiBot::Personas
|
|||
SiteSetting.ai_google_custom_search_cx = "abc123"
|
||||
|
||||
# should be ordered by priority and then alpha
|
||||
expect(DiscourseAi::AiBot::Personas.all(user: user)).to eq(
|
||||
expect(DiscourseAi::AiBot::Personas::Persona.all(user: user)).to eq(
|
||||
[General, Artist, Creative, Researcher, SettingsExplorer, SqlHelper],
|
||||
)
|
||||
|
||||
|
@ -135,18 +128,18 @@ module DiscourseAi::AiBot::Personas
|
|||
SiteSetting.ai_stability_api_key = ""
|
||||
SiteSetting.ai_google_custom_search_api_key = ""
|
||||
|
||||
expect(DiscourseAi::AiBot::Personas.all(user: user)).to contain_exactly(
|
||||
expect(DiscourseAi::AiBot::Personas::Persona.all(user: user)).to contain_exactly(
|
||||
General,
|
||||
SqlHelper,
|
||||
SettingsExplorer,
|
||||
Creative,
|
||||
)
|
||||
|
||||
AiPersona.find(DiscourseAi::AiBot::Personas.system_personas[General]).update!(
|
||||
AiPersona.find(DiscourseAi::AiBot::Personas::Persona.system_personas[General]).update!(
|
||||
enabled: false,
|
||||
)
|
||||
|
||||
expect(DiscourseAi::AiBot::Personas.all(user: user)).to contain_exactly(
|
||||
expect(DiscourseAi::AiBot::Personas::Persona.all(user: user)).to contain_exactly(
|
||||
SqlHelper,
|
||||
SettingsExplorer,
|
||||
Creative,
|
||||
|
|
|
@ -6,6 +6,6 @@ RSpec.describe DiscourseAi::AiBot::Personas::Researcher do
|
|||
end
|
||||
|
||||
it "renders schema" do
|
||||
expect(researcher.commands).to eq([DiscourseAi::AiBot::Commands::GoogleCommand])
|
||||
expect(researcher.tools).to eq([DiscourseAi::AiBot::Tools::Google])
|
||||
end
|
||||
end
|
||||
|
|
|
@ -6,18 +6,15 @@ RSpec.describe DiscourseAi::AiBot::Personas::SettingsExplorer do
|
|||
end
|
||||
|
||||
it "renders schema" do
|
||||
prompt = settings_explorer.render_system_prompt
|
||||
prompt = settings_explorer.system_prompt
|
||||
|
||||
# check we do not render plugin settings
|
||||
expect(prompt).not_to include("ai_bot_enabled_personas")
|
||||
|
||||
expect(prompt).to include("site_description")
|
||||
|
||||
expect(settings_explorer.available_commands).to eq(
|
||||
[
|
||||
DiscourseAi::AiBot::Commands::SettingContextCommand,
|
||||
DiscourseAi::AiBot::Commands::SearchSettingsCommand,
|
||||
],
|
||||
expect(settings_explorer.tools).to eq(
|
||||
[DiscourseAi::AiBot::Tools::SettingContext, DiscourseAi::AiBot::Tools::SearchSettings],
|
||||
)
|
||||
end
|
||||
end
|
||||
|
|
|
@ -6,12 +6,12 @@ RSpec.describe DiscourseAi::AiBot::Personas::SqlHelper do
|
|||
end
|
||||
|
||||
it "renders schema" do
|
||||
prompt = sql_helper.render_system_prompt
|
||||
prompt = sql_helper.system_prompt
|
||||
expect(prompt).to include("posts(")
|
||||
expect(prompt).to include("topics(")
|
||||
expect(prompt).not_to include("translation_key") # not a priority table
|
||||
expect(prompt).to include("user_api_keys") # not a priority table
|
||||
|
||||
expect(sql_helper.available_commands).to eq([DiscourseAi::AiBot::Commands::DbSchemaCommand])
|
||||
expect(sql_helper.tools).to eq([DiscourseAi::AiBot::Tools::DbSchema])
|
||||
end
|
||||
end
|
||||
|
|
|
@ -0,0 +1,168 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Playground do
|
||||
subject(:playground) { described_class.new(bot) }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_bot_enabled_chat_bots = "gpt-4"
|
||||
SiteSetting.ai_bot_enabled = true
|
||||
end
|
||||
|
||||
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT4_ID) }
|
||||
let(:bot) { DiscourseAi::AiBot::Bot.as(bot_user) }
|
||||
|
||||
fab!(:user) { Fabricate(:user) }
|
||||
let!(:pm) do
|
||||
Fabricate(
|
||||
:private_message_topic,
|
||||
title: "This is my special PM",
|
||||
user: user,
|
||||
topic_allowed_users: [
|
||||
Fabricate.build(:topic_allowed_user, user: user),
|
||||
Fabricate.build(:topic_allowed_user, user: bot_user),
|
||||
],
|
||||
)
|
||||
end
|
||||
let!(:first_post) do
|
||||
Fabricate(:post, topic: pm, user: user, post_number: 1, raw: "This is a reply by the user")
|
||||
end
|
||||
let!(:second_post) do
|
||||
Fabricate(:post, topic: pm, user: bot_user, post_number: 2, raw: "This is a bot reply")
|
||||
end
|
||||
let!(:third_post) do
|
||||
Fabricate(
|
||||
:post,
|
||||
topic: pm,
|
||||
user: user,
|
||||
post_number: 3,
|
||||
raw: "This is a second reply by the user",
|
||||
)
|
||||
end
|
||||
|
||||
describe "#title_playground" do
|
||||
let(:expected_response) { "This is a suggested title" }
|
||||
|
||||
before { SiteSetting.min_personal_message_post_length = 5 }
|
||||
|
||||
it "updates the title using bot suggestions" do
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses([expected_response]) do
|
||||
playground.title_playground(third_post)
|
||||
|
||||
expect(pm.reload.title).to eq(expected_response)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "#reply_to" do
|
||||
it "streams the bot reply through MB and create a new post in the PM with a cooked responses" do
|
||||
expected_bot_response =
|
||||
"Hello this is a bot and what you just said is an interesting question"
|
||||
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses([expected_bot_response]) do
|
||||
messages =
|
||||
MessageBus.track_publish("discourse-ai/ai-bot/topic/#{pm.id}") do
|
||||
playground.reply_to(third_post)
|
||||
end
|
||||
|
||||
done_signal = messages.pop
|
||||
expect(done_signal.data[:done]).to eq(true)
|
||||
|
||||
messages.each_with_index do |m, idx|
|
||||
expect(m.data[:raw]).to eq(expected_bot_response[0..idx])
|
||||
end
|
||||
|
||||
expect(pm.reload.posts.last.cooked).to eq(PrettyText.cook(expected_bot_response))
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "#conversation_context" do
|
||||
it "includes previous posts ordered by post_number" do
|
||||
context = playground.conversation_context(third_post)
|
||||
|
||||
expect(context).to contain_exactly(
|
||||
*[
|
||||
{ type: "user", name: user.username, content: third_post.raw },
|
||||
{ type: "assistant", content: second_post.raw },
|
||||
{ type: "user", name: user.username, content: first_post.raw },
|
||||
],
|
||||
)
|
||||
end
|
||||
|
||||
it "only include regular posts" do
|
||||
first_post.update!(post_type: Post.types[:whisper])
|
||||
|
||||
context = playground.conversation_context(third_post)
|
||||
|
||||
expect(context).to contain_exactly(
|
||||
*[
|
||||
{ type: "user", name: user.username, content: third_post.raw },
|
||||
{ type: "assistant", content: second_post.raw },
|
||||
],
|
||||
)
|
||||
end
|
||||
|
||||
context "with custom prompts" do
|
||||
it "When post custom prompt is present, we use that instead of the post content" do
|
||||
custom_prompt = [
|
||||
[
|
||||
{ args: { timezone: "Buenos Aires" }, time: "2023-12-14 17:24:00 -0300" }.to_json,
|
||||
"time",
|
||||
"tool",
|
||||
],
|
||||
[
|
||||
{ name: "time", arguments: { name: "time", timezone: "Buenos Aires" } }.to_json,
|
||||
"time",
|
||||
"tool_call",
|
||||
],
|
||||
["I replied this thanks to the time command", bot_user.username],
|
||||
]
|
||||
|
||||
PostCustomPrompt.create!(post: second_post, custom_prompt: custom_prompt)
|
||||
|
||||
context = playground.conversation_context(third_post)
|
||||
|
||||
expect(context).to contain_exactly(
|
||||
*[
|
||||
{ type: "user", name: user.username, content: third_post.raw },
|
||||
{ type: "assistant", content: custom_prompt.third.first },
|
||||
{ type: "tool_call", content: custom_prompt.second.first, name: "time" },
|
||||
{ type: "tool", name: "time", content: custom_prompt.first.first },
|
||||
{ type: "user", name: user.username, content: first_post.raw },
|
||||
],
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
it "include replies generated from tools only once" do
|
||||
custom_prompt = [
|
||||
[
|
||||
{ args: { timezone: "Buenos Aires" }, time: "2023-12-14 17:24:00 -0300" }.to_json,
|
||||
"time",
|
||||
"tool",
|
||||
],
|
||||
[
|
||||
{ name: "time", arguments: { name: "time", timezone: "Buenos Aires" }.to_json }.to_json,
|
||||
"time",
|
||||
"tool_call",
|
||||
],
|
||||
["I replied this thanks to the time command", bot_user.username],
|
||||
]
|
||||
PostCustomPrompt.create!(post: second_post, custom_prompt: custom_prompt)
|
||||
PostCustomPrompt.create!(post: first_post, custom_prompt: custom_prompt)
|
||||
|
||||
context = playground.conversation_context(third_post)
|
||||
|
||||
expect(context).to contain_exactly(
|
||||
*[
|
||||
{ type: "user", name: user.username, content: third_post.raw },
|
||||
{ type: "assistant", content: custom_prompt.third.first },
|
||||
{ type: "tool_call", content: custom_prompt.second.first, name: "time" },
|
||||
{ type: "tool", name: "time", content: custom_prompt.first.first },
|
||||
{ type: "tool_call", content: custom_prompt.second.first, name: "time" },
|
||||
{ type: "tool", name: "time", content: custom_prompt.first.first },
|
||||
],
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,8 +1,13 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Commands::DallECommand do
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::DallE do
|
||||
subject(:dall_e) { described_class.new({ prompts: prompts }) }
|
||||
|
||||
let(:prompts) { ["a pink cow", "a red cow"] }
|
||||
|
||||
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
||||
let(:bot) { DiscourseAi::AiBot::OpenAiBot.new(bot_user) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
|
||||
let(:progress_blk) { Proc.new {} }
|
||||
|
||||
before { SiteSetting.ai_bot_enabled = true }
|
||||
|
||||
|
@ -17,7 +22,6 @@ RSpec.describe DiscourseAi::AiBot::Commands::DallECommand do
|
|||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
|
||||
|
||||
data = [{ b64_json: image, revised_prompt: "a pink cow 1" }]
|
||||
prompts = ["a pink cow", "a red cow"]
|
||||
|
||||
WebMock
|
||||
.stub_request(:post, SiteSetting.ai_openai_dall_e_3_url)
|
||||
|
@ -30,14 +34,12 @@ RSpec.describe DiscourseAi::AiBot::Commands::DallECommand do
|
|||
end
|
||||
.to_return(status: 200, body: { data: data }.to_json)
|
||||
|
||||
image = described_class.new(bot: bot, post: post, args: nil)
|
||||
|
||||
info = image.process(prompts: prompts).to_json
|
||||
info = dall_e.invoke(bot_user, llm, &progress_blk).to_json
|
||||
|
||||
expect(JSON.parse(info)).to eq("prompts" => ["a pink cow 1", "a pink cow 1"])
|
||||
expect(image.custom_raw).to include("upload://")
|
||||
expect(image.custom_raw).to include("[grid]")
|
||||
expect(image.custom_raw).to include("a pink cow 1")
|
||||
expect(subject.custom_raw).to include("upload://")
|
||||
expect(subject.custom_raw).to include("[grid]")
|
||||
expect(subject.custom_raw).to include("a pink cow 1")
|
||||
end
|
||||
|
||||
it "can generate correct info" do
|
||||
|
@ -49,7 +51,6 @@ RSpec.describe DiscourseAi::AiBot::Commands::DallECommand do
|
|||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
|
||||
|
||||
data = [{ b64_json: image, revised_prompt: "a pink cow 1" }]
|
||||
prompts = ["a pink cow", "a red cow"]
|
||||
|
||||
WebMock
|
||||
.stub_request(:post, "https://api.openai.com/v1/images/generations")
|
||||
|
@ -60,14 +61,12 @@ RSpec.describe DiscourseAi::AiBot::Commands::DallECommand do
|
|||
end
|
||||
.to_return(status: 200, body: { data: data }.to_json)
|
||||
|
||||
image = described_class.new(bot: bot, post: post, args: nil)
|
||||
|
||||
info = image.process(prompts: prompts).to_json
|
||||
info = dall_e.invoke(bot_user, llm, &progress_blk).to_json
|
||||
|
||||
expect(JSON.parse(info)).to eq("prompts" => ["a pink cow 1", "a pink cow 1"])
|
||||
expect(image.custom_raw).to include("upload://")
|
||||
expect(image.custom_raw).to include("[grid]")
|
||||
expect(image.custom_raw).to include("a pink cow 1")
|
||||
expect(subject.custom_raw).to include("upload://")
|
||||
expect(subject.custom_raw).to include("[grid]")
|
||||
expect(subject.custom_raw).to include("a pink cow 1")
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,10 +1,14 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Commands::DbSchemaCommand do
|
||||
let(:command) { DiscourseAi::AiBot::Commands::DbSchemaCommand.new(bot: nil, args: nil) }
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::DbSchema do
|
||||
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
|
||||
|
||||
before { SiteSetting.ai_bot_enabled = true }
|
||||
describe "#process" do
|
||||
it "returns rich schema for tables" do
|
||||
result = command.process(tables: "posts,topics")
|
||||
result = described_class.new({ tables: "posts,topics" }).invoke(bot_user, llm)
|
||||
|
||||
expect(result[:schema_info]).to include("raw text")
|
||||
expect(result[:schema_info]).to include("views integer")
|
||||
expect(result[:schema_info]).to include("posts")
|
|
@ -1,8 +1,11 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Commands::GoogleCommand do
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::Google do
|
||||
subject(:search) { described_class.new({ query: "some search term" }) }
|
||||
|
||||
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
||||
let(:bot) { DiscourseAi::AiBot::OpenAiBot.new(bot_user) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
|
||||
let(:progress_blk) { Proc.new {} }
|
||||
|
||||
before { SiteSetting.ai_bot_enabled = true }
|
||||
|
||||
|
@ -20,10 +23,9 @@ RSpec.describe DiscourseAi::AiBot::Commands::GoogleCommand do
|
|||
"https://www.googleapis.com/customsearch/v1?cx=cx&key=abc&num=10&q=some%20search%20term",
|
||||
).to_return(status: 200, body: json_text, headers: {})
|
||||
|
||||
google = described_class.new(bot: nil, post: post, args: {}.to_json)
|
||||
info = google.process(query: "some search term").to_json
|
||||
info = search.invoke(bot_user, llm, &progress_blk).to_json
|
||||
|
||||
expect(google.description_args[:count]).to eq(0)
|
||||
expect(search.results_count).to eq(0)
|
||||
expect(info).to_not include("oops")
|
||||
end
|
||||
|
||||
|
@ -61,23 +63,14 @@ RSpec.describe DiscourseAi::AiBot::Commands::GoogleCommand do
|
|||
"https://www.googleapis.com/customsearch/v1?cx=cx&key=abc&num=10&q=some%20search%20term",
|
||||
).to_return(status: 200, body: json_text, headers: {})
|
||||
|
||||
google =
|
||||
described_class.new(bot: bot, post: post, args: { query: "some search term" }.to_json)
|
||||
info = search.invoke(bot_user, llm, &progress_blk).to_json
|
||||
|
||||
info = google.process(query: "some search term").to_json
|
||||
|
||||
expect(google.description_args[:count]).to eq(2)
|
||||
expect(search.results_count).to eq(2)
|
||||
expect(info).to include("title1")
|
||||
expect(info).to include("snippet1")
|
||||
expect(info).to include("some+search+term")
|
||||
expect(info).to include("title2")
|
||||
expect(info).to_not include("oops")
|
||||
|
||||
google.invoke!
|
||||
|
||||
expect(post.reload.raw).to include("some search term")
|
||||
|
||||
expect { google.invoke! }.to raise_error(StandardError)
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,8 +1,14 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::Image do
|
||||
subject(:tool) { described_class.new({ prompts: prompts, seeds: [99, 32] }) }
|
||||
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
|
||||
let(:progress_blk) { Proc.new {} }
|
||||
|
||||
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
||||
let(:bot) { DiscourseAi::AiBot::OpenAiBot.new(bot_user) }
|
||||
|
||||
let(:prompts) { ["a pink cow", "a red cow"] }
|
||||
|
||||
before { SiteSetting.ai_bot_enabled = true }
|
||||
|
||||
|
@ -17,7 +23,6 @@ RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do
|
|||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
|
||||
|
||||
artifacts = [{ base64: image, seed: 99 }]
|
||||
prompts = ["a pink cow", "a red cow"]
|
||||
|
||||
WebMock
|
||||
.stub_request(
|
||||
|
@ -31,15 +36,13 @@ RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do
|
|||
end
|
||||
.to_return(status: 200, body: { artifacts: artifacts }.to_json)
|
||||
|
||||
image = described_class.new(bot: bot, post: post, args: nil)
|
||||
|
||||
info = image.process(prompts: prompts).to_json
|
||||
info = tool.invoke(bot_user, llm, &progress_blk).to_json
|
||||
|
||||
expect(JSON.parse(info)).to eq("prompts" => ["a pink cow", "a red cow"], "seeds" => [99, 99])
|
||||
expect(image.custom_raw).to include("upload://")
|
||||
expect(image.custom_raw).to include("[grid]")
|
||||
expect(image.custom_raw).to include("a pink cow")
|
||||
expect(image.custom_raw).to include("a red cow")
|
||||
expect(tool.custom_raw).to include("upload://")
|
||||
expect(tool.custom_raw).to include("[grid]")
|
||||
expect(tool.custom_raw).to include("a pink cow")
|
||||
expect(tool.custom_raw).to include("a red cow")
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,19 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::ListCategories do
|
||||
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
|
||||
|
||||
before { SiteSetting.ai_bot_enabled = true }
|
||||
|
||||
describe "#process" do
|
||||
it "list available categories" do
|
||||
Fabricate(:category, name: "america", posts_year: 999)
|
||||
|
||||
info = described_class.new({}).invoke(bot_user, llm).to_s
|
||||
|
||||
expect(info).to include("america")
|
||||
expect(info).to include("999")
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,23 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::ListTags do
|
||||
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_bot_enabled = true
|
||||
SiteSetting.tagging_enabled = true
|
||||
end
|
||||
|
||||
describe "#process" do
|
||||
it "can generate correct info" do
|
||||
Fabricate(:tag, name: "america", public_topic_count: 100)
|
||||
Fabricate(:tag, name: "not_here", public_topic_count: 0)
|
||||
|
||||
info = described_class.new({}).invoke(bot_user, llm)
|
||||
|
||||
expect(info.to_s).to include("america")
|
||||
expect(info.to_s).not_to include("not_here")
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,8 +1,10 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Commands::ReadCommand do
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::Read do
|
||||
subject(:tool) { described_class.new({ topic_id: topic_with_tags.id }) }
|
||||
|
||||
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
||||
let(:bot) { DiscourseAi::AiBot::Bot.as(bot_user) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
|
||||
|
||||
fab!(:parent_category) { Fabricate(:category, name: "animals") }
|
||||
fab!(:category) { Fabricate(:category, parent_category: parent_category, name: "amazing-cat") }
|
||||
|
@ -32,9 +34,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::ReadCommand do
|
|||
Fabricate(:post, topic: topic_with_tags, raw: "hello there")
|
||||
Fabricate(:post, topic: topic_with_tags, raw: "mister sam")
|
||||
|
||||
read = described_class.new(bot: bot, args: nil)
|
||||
|
||||
results = read.process(topic_id: topic_id)
|
||||
results = tool.invoke(bot_user, llm)
|
||||
|
||||
expect(results[:topic_id]).to eq(topic_id)
|
||||
expect(results[:content]).to include("hello")
|
||||
|
@ -44,10 +44,8 @@ RSpec.describe DiscourseAi::AiBot::Commands::ReadCommand do
|
|||
expect(results[:content]).to include("sad")
|
||||
expect(results[:content]).to include("animals")
|
||||
expect(results[:content]).not_to include("hidden")
|
||||
expect(read.description_args).to eq(
|
||||
title: topic_with_tags.title,
|
||||
url: topic_with_tags.relative_url,
|
||||
)
|
||||
expect(tool.title).to eq(topic_with_tags.title)
|
||||
expect(tool.url).to eq(topic_with_tags.relative_url)
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,39 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::SearchSettings do
|
||||
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
|
||||
|
||||
before { SiteSetting.ai_bot_enabled = true }
|
||||
|
||||
def search_settings(query)
|
||||
described_class.new({ query: query })
|
||||
end
|
||||
|
||||
describe "#process" do
|
||||
it "can handle no results" do
|
||||
results = search_settings("this will not exist frogs").invoke(bot_user, llm)
|
||||
expect(results[:args]).to eq({ query: "this will not exist frogs" })
|
||||
expect(results[:rows]).to eq([])
|
||||
end
|
||||
|
||||
it "can return more many settings with no descriptions if there are lots of hits" do
|
||||
results = search_settings("a").invoke(bot_user, llm)
|
||||
|
||||
expect(results[:rows].length).to be > 30
|
||||
expect(results[:rows][0].length).to eq(1)
|
||||
end
|
||||
|
||||
it "can return descriptions if there are few matches" do
|
||||
results =
|
||||
search_settings("this will not be found!@,default_locale,ai_bot_enabled_chat_bots").invoke(
|
||||
bot_user,
|
||||
llm,
|
||||
)
|
||||
|
||||
expect(results[:rows].length).to eq(2)
|
||||
|
||||
expect(results[:rows][0][1]).not_to eq(nil)
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,140 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::Search do
|
||||
before { SearchIndexer.enable }
|
||||
after { SearchIndexer.disable }
|
||||
|
||||
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
|
||||
let(:progress_blk) { Proc.new {} }
|
||||
|
||||
fab!(:admin)
|
||||
fab!(:parent_category) { Fabricate(:category, name: "animals") }
|
||||
fab!(:category) { Fabricate(:category, parent_category: parent_category, name: "amazing-cat") }
|
||||
fab!(:tag_funny) { Fabricate(:tag, name: "funny") }
|
||||
fab!(:tag_sad) { Fabricate(:tag, name: "sad") }
|
||||
fab!(:tag_hidden) { Fabricate(:tag, name: "hidden") }
|
||||
fab!(:staff_tag_group) do
|
||||
tag_group = Fabricate.build(:tag_group, name: "Staff only", tag_names: ["hidden"])
|
||||
|
||||
tag_group.permissions = [
|
||||
[Group::AUTO_GROUPS[:staff], TagGroupPermission.permission_types[:full]],
|
||||
]
|
||||
tag_group.save!
|
||||
tag_group
|
||||
end
|
||||
fab!(:topic_with_tags) do
|
||||
Fabricate(:topic, category: category, tags: [tag_funny, tag_sad, tag_hidden])
|
||||
end
|
||||
|
||||
before { SiteSetting.ai_bot_enabled = true }
|
||||
|
||||
describe "#invoke" do
|
||||
it "can retreive options from persona correctly" do
|
||||
persona_options = { "base_query" => "#funny" }
|
||||
|
||||
search_post = Fabricate(:post, topic: topic_with_tags)
|
||||
|
||||
bot_post = Fabricate(:post)
|
||||
|
||||
search = described_class.new({ order: "latest" }, persona_options: persona_options)
|
||||
|
||||
results = search.invoke(bot_user, llm, &progress_blk)
|
||||
expect(results[:rows].length).to eq(1)
|
||||
|
||||
search_post.topic.tags = []
|
||||
search_post.topic.save!
|
||||
|
||||
# no longer has the tag funny
|
||||
results = search.invoke(bot_user, llm, &progress_blk)
|
||||
expect(results[:rows].length).to eq(0)
|
||||
end
|
||||
|
||||
it "can handle no results" do
|
||||
post1 = Fabricate(:post, topic: topic_with_tags)
|
||||
search = described_class.new({ search_query: "ABDDCDCEDGDG", order: "fake" })
|
||||
|
||||
results = search.invoke(bot_user, llm, &progress_blk)
|
||||
|
||||
expect(results[:args]).to eq({ search_query: "ABDDCDCEDGDG", order: "fake" })
|
||||
expect(results[:rows]).to eq([])
|
||||
end
|
||||
|
||||
describe "semantic search" do
|
||||
let (:query) {
|
||||
"this is an expanded search"
|
||||
}
|
||||
after { DiscourseAi::Embeddings::SemanticSearch.clear_cache_for(query) }
|
||||
|
||||
it "supports semantic search when enabled" do
|
||||
SiteSetting.ai_embeddings_semantic_search_enabled = true
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||
|
||||
hyde_embedding = [0.049382, 0.9999]
|
||||
EmbeddingsGenerationStubs.discourse_service(
|
||||
SiteSetting.ai_embeddings_model,
|
||||
query,
|
||||
hyde_embedding,
|
||||
)
|
||||
|
||||
post1 = Fabricate(:post, topic: topic_with_tags)
|
||||
search = described_class.new({ search_query: "hello world, sam", status: "public" })
|
||||
|
||||
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2
|
||||
.any_instance
|
||||
.expects(:asymmetric_topics_similarity_search)
|
||||
.returns([post1.topic_id])
|
||||
|
||||
results =
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses(["<ai>#{query}</ai>"]) do
|
||||
search.invoke(bot_user, llm, &progress_blk)
|
||||
end
|
||||
|
||||
expect(results[:args]).to eq({ search_query: "hello world, sam", status: "public" })
|
||||
expect(results[:rows].length).to eq(1)
|
||||
end
|
||||
end
|
||||
|
||||
it "supports subfolder properly" do
|
||||
Discourse.stubs(:base_path).returns("/subfolder")
|
||||
|
||||
post1 = Fabricate(:post, topic: topic_with_tags)
|
||||
|
||||
search = described_class.new({ limit: 1, user: post1.user.username })
|
||||
|
||||
results = search.invoke(bot_user, llm, &progress_blk)
|
||||
expect(results[:rows].to_s).to include("/subfolder" + post1.url)
|
||||
end
|
||||
|
||||
it "returns rich topic information" do
|
||||
post1 = Fabricate(:post, like_count: 1, topic: topic_with_tags)
|
||||
search = described_class.new({ user: post1.user.username })
|
||||
post1.topic.update!(views: 100, posts_count: 2, like_count: 10)
|
||||
|
||||
results = search.invoke(bot_user, llm, &progress_blk)
|
||||
|
||||
row = results[:rows].first
|
||||
category = row[results[:column_names].index("category")]
|
||||
|
||||
expect(category).to eq("animals > amazing-cat")
|
||||
|
||||
tags = row[results[:column_names].index("tags")]
|
||||
expect(tags).to eq("funny, sad")
|
||||
|
||||
likes = row[results[:column_names].index("likes")]
|
||||
expect(likes).to eq(1)
|
||||
|
||||
username = row[results[:column_names].index("username")]
|
||||
expect(username).to eq(post1.user.username)
|
||||
|
||||
likes = row[results[:column_names].index("topic_likes")]
|
||||
expect(likes).to eq(10)
|
||||
|
||||
views = row[results[:column_names].index("topic_views")]
|
||||
expect(views).to eq(100)
|
||||
|
||||
replies = row[results[:column_names].index("topic_replies")]
|
||||
expect(replies).to eq(1)
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,21 +1,26 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Commands::SettingContextCommand do
|
||||
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
||||
let(:command) { described_class.new(bot_user: bot_user, args: nil) }
|
||||
|
||||
def has_rg?
|
||||
def has_rg?
|
||||
if defined?(@has_rg)
|
||||
@has_rg
|
||||
else
|
||||
@has_rg |= system("which rg")
|
||||
end
|
||||
end
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::SettingContext, if: has_rg? do
|
||||
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
|
||||
|
||||
before { SiteSetting.ai_bot_enabled = true }
|
||||
|
||||
def setting_context(setting_name)
|
||||
described_class.new({ setting_name: setting_name })
|
||||
end
|
||||
|
||||
describe "#execute" do
|
||||
skip("rg is needed for these tests") if !has_rg?
|
||||
it "returns the context for core setting" do
|
||||
result = command.process(setting_name: "moderators_view_emails")
|
||||
result = setting_context("moderators_view_emails").invoke(bot_user, llm)
|
||||
|
||||
expect(result[:setting_name]).to eq("moderators_view_emails")
|
||||
|
||||
|
@ -23,18 +28,17 @@ RSpec.describe DiscourseAi::AiBot::Commands::SettingContextCommand do
|
|||
expect(result[:context]).to include("moderators_view_emails")
|
||||
end
|
||||
|
||||
skip("rg is needed for these tests") if !has_rg?
|
||||
it "returns the context for plugin setting" do
|
||||
result = command.process(setting_name: "ai_bot_enabled")
|
||||
result = setting_context("ai_bot_enabled").invoke(bot_user, llm)
|
||||
|
||||
expect(result[:setting_name]).to eq("ai_bot_enabled")
|
||||
expect(result[:context]).to include("ai_bot_enabled:")
|
||||
end
|
||||
|
||||
context "when the setting does not exist" do
|
||||
skip("rg is needed for these tests") if !has_rg?
|
||||
it "returns an error message" do
|
||||
result = command.process(setting_name: "this_setting_does_not_exist")
|
||||
result = setting_context("this_setting_does_not_exist").invoke(bot_user, llm)
|
||||
|
||||
expect(result[:context]).to eq("This setting does not exist")
|
||||
end
|
||||
end
|
|
@ -0,0 +1,46 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::Summarize do
|
||||
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
|
||||
let(:progress_blk) { Proc.new {} }
|
||||
|
||||
before { SiteSetting.ai_bot_enabled = true }
|
||||
|
||||
let(:summary) { "summary stuff" }
|
||||
|
||||
describe "#process" do
|
||||
it "can generate correct info" do
|
||||
post = Fabricate(:post)
|
||||
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do
|
||||
summarization =
|
||||
described_class.new({ topic_id: post.topic_id, guidance: "why did it happen?" })
|
||||
info = summarization.invoke(bot_user, llm, &progress_blk)
|
||||
|
||||
expect(info).to include("Topic summarized")
|
||||
expect(summarization.custom_raw).to include(summary)
|
||||
expect(summarization.chain_next_response?).to eq(false)
|
||||
end
|
||||
end
|
||||
|
||||
it "protects hidden data" do
|
||||
category = Fabricate(:category)
|
||||
category.set_permissions({})
|
||||
category.save!
|
||||
|
||||
topic = Fabricate(:topic, category_id: category.id)
|
||||
post = Fabricate(:post, topic: topic)
|
||||
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do
|
||||
summarization =
|
||||
described_class.new({ topic_id: post.topic_id, guidance: "why did it happen?" })
|
||||
info = summarization.invoke(bot_user, llm, &progress_blk)
|
||||
|
||||
expect(info).not_to include(post.raw)
|
||||
|
||||
expect(summarization.custom_raw).to eq(I18n.t("discourse_ai.ai_bot.topic_not_found"))
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,12 +1,17 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Commands::TimeCommand do
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::Time do
|
||||
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
|
||||
|
||||
before { SiteSetting.ai_bot_enabled = true }
|
||||
|
||||
describe "#process" do
|
||||
it "can generate correct info" do
|
||||
freeze_time
|
||||
|
||||
args = { timezone: "America/Los_Angeles" }
|
||||
info = DiscourseAi::AiBot::Commands::TimeCommand.new(bot: nil, args: nil).process(**args)
|
||||
info = described_class.new(args).invoke(bot_user, llm)
|
||||
|
||||
expect(info).to eq({ args: args, time: Time.now.in_time_zone("America/Los_Angeles").to_s })
|
||||
expect(info.to_s).not_to include("not_here")
|
|
@ -14,7 +14,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
|||
|
||||
expect(response.parsed_body["ai_personas"].length).to eq(AiPersona.count)
|
||||
expect(response.parsed_body["meta"]["commands"].length).to eq(
|
||||
DiscourseAi::AiBot::Personas::Persona.all_available_commands.length,
|
||||
DiscourseAi::AiBot::Personas::Persona.all_available_tools.length,
|
||||
)
|
||||
end
|
||||
|
||||
|
@ -34,7 +34,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
|||
serializer_persona2 = response.parsed_body["ai_personas"].find { |p| p["id"] == persona2.id }
|
||||
|
||||
commands = response.parsed_body["meta"]["commands"]
|
||||
search_command = commands.find { |c| c["id"] == "SearchCommand" }
|
||||
search_command = commands.find { |c| c["id"] == "Search" }
|
||||
|
||||
expect(search_command["help"]).to eq(I18n.t("discourse_ai.ai_bot.command_help.search"))
|
||||
|
||||
|
@ -71,7 +71,8 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
|||
"Général Description",
|
||||
)
|
||||
|
||||
id = DiscourseAi::AiBot::Personas.system_personas[DiscourseAi::AiBot::Personas::General]
|
||||
id =
|
||||
DiscourseAi::AiBot::Personas::Persona.system_personas[DiscourseAi::AiBot::Personas::General]
|
||||
name = I18n.t("discourse_ai.ai_bot.personas.general.name")
|
||||
description = I18n.t("discourse_ai.ai_bot.personas.general.description")
|
||||
persona = response.parsed_body["ai_personas"].find { |p| p["id"] == id }
|
||||
|
@ -147,7 +148,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
|||
|
||||
context "with system personas" do
|
||||
it "does not allow editing of system prompts" do
|
||||
put "/admin/plugins/discourse-ai/ai_personas/#{DiscourseAi::AiBot::Personas.system_personas.values.first}.json",
|
||||
put "/admin/plugins/discourse-ai/ai_personas/#{DiscourseAi::AiBot::Personas::Persona.system_personas.values.first}.json",
|
||||
params: {
|
||||
ai_persona: {
|
||||
system_prompt: "you are not a helpful bot",
|
||||
|
@ -160,7 +161,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
|||
end
|
||||
|
||||
it "does not allow editing of commands" do
|
||||
put "/admin/plugins/discourse-ai/ai_personas/#{DiscourseAi::AiBot::Personas.system_personas.values.first}.json",
|
||||
put "/admin/plugins/discourse-ai/ai_personas/#{DiscourseAi::AiBot::Personas::Persona.system_personas.values.first}.json",
|
||||
params: {
|
||||
ai_persona: {
|
||||
commands: %w[SearchCommand ImageCommand],
|
||||
|
@ -173,7 +174,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
|||
end
|
||||
|
||||
it "does not allow editing of name and description cause it is localized" do
|
||||
put "/admin/plugins/discourse-ai/ai_personas/#{DiscourseAi::AiBot::Personas.system_personas.values.first}.json",
|
||||
put "/admin/plugins/discourse-ai/ai_personas/#{DiscourseAi::AiBot::Personas::Persona.system_personas.values.first}.json",
|
||||
params: {
|
||||
ai_persona: {
|
||||
name: "bob",
|
||||
|
@ -187,7 +188,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
|||
end
|
||||
|
||||
it "does allow some actions" do
|
||||
put "/admin/plugins/discourse-ai/ai_personas/#{DiscourseAi::AiBot::Personas.system_personas.values.first}.json",
|
||||
put "/admin/plugins/discourse-ai/ai_personas/#{DiscourseAi::AiBot::Personas::Persona.system_personas.values.first}.json",
|
||||
params: {
|
||||
ai_persona: {
|
||||
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_1]],
|
||||
|
@ -225,7 +226,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
|||
|
||||
it "is not allowed to delete system personas" do
|
||||
expect {
|
||||
delete "/admin/plugins/discourse-ai/ai_personas/#{DiscourseAi::AiBot::Personas.system_personas.values.first}.json"
|
||||
delete "/admin/plugins/discourse-ai/ai_personas/#{DiscourseAi::AiBot::Personas::Persona.system_personas.values.first}.json"
|
||||
expect(response).to have_http_status(:unprocessable_entity)
|
||||
expect(response.parsed_body["errors"].join).not_to be_blank
|
||||
# let's make sure this is translated
|
||||
|
|
|
@ -1,69 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::Inference::AnthropicCompletions do
|
||||
before { SiteSetting.ai_anthropic_api_key = "abc-123" }
|
||||
|
||||
it "can complete a trivial prompt" do
|
||||
response_text = "1. Serenity\\n2. Laughter\\n3. Adventure"
|
||||
prompt = "Human: write 3 words\n\n"
|
||||
user_id = 183
|
||||
req_opts = { max_tokens_to_sample: 700, temperature: 0.5 }
|
||||
|
||||
AnthropicCompletionStubs.stub_response(prompt, response_text, req_opts: req_opts)
|
||||
|
||||
completions =
|
||||
DiscourseAi::Inference::AnthropicCompletions.perform!(
|
||||
prompt,
|
||||
"claude-2",
|
||||
temperature: req_opts[:temperature],
|
||||
max_tokens: req_opts[:max_tokens_to_sample],
|
||||
user_id: user_id,
|
||||
)
|
||||
|
||||
expect(completions[:completion]).to eq(response_text)
|
||||
|
||||
expect(AiApiAuditLog.count).to eq(1)
|
||||
log = AiApiAuditLog.first
|
||||
|
||||
request_body = { model: "claude-2", prompt: prompt }.merge(req_opts).to_json
|
||||
response_body = AnthropicCompletionStubs.response(response_text).to_json
|
||||
|
||||
expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
|
||||
expect(log.request_tokens).to eq(6)
|
||||
expect(log.response_tokens).to eq(16)
|
||||
expect(log.raw_request_payload).to eq(request_body)
|
||||
expect(log.raw_response_payload).to eq(response_body)
|
||||
end
|
||||
|
||||
it "supports streaming mode" do
|
||||
deltas = ["Mount", "ain", " ", "Tree ", "Frog"]
|
||||
prompt = "Human: write 3 words\n\n"
|
||||
req_opts = { max_tokens_to_sample: 300, stream: true }
|
||||
content = +""
|
||||
|
||||
AnthropicCompletionStubs.stub_streamed_response(prompt, deltas, req_opts: req_opts)
|
||||
|
||||
DiscourseAi::Inference::AnthropicCompletions.perform!(
|
||||
prompt,
|
||||
"claude-2",
|
||||
max_tokens: req_opts[:max_tokens_to_sample],
|
||||
) do |partial, cancel|
|
||||
data = partial[:completion]
|
||||
content << data if data
|
||||
cancel.call if content.split(" ").length == 2
|
||||
end
|
||||
|
||||
expect(content).to eq("Mountain Tree ")
|
||||
|
||||
expect(AiApiAuditLog.count).to eq(1)
|
||||
log = AiApiAuditLog.first
|
||||
|
||||
request_body = { model: "claude-2", prompt: prompt }.merge(req_opts).to_json
|
||||
|
||||
expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
|
||||
expect(log.request_tokens).to eq(6)
|
||||
expect(log.response_tokens).to eq(3)
|
||||
expect(log.raw_request_payload).to eq(request_body)
|
||||
expect(log.raw_response_payload).to be_present
|
||||
end
|
||||
end
|
|
@ -1,108 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
require "rails_helper"
|
||||
|
||||
module DiscourseAi::Inference
|
||||
describe FunctionList do
|
||||
let :function_list do
|
||||
function =
|
||||
Function.new(name: "get_weather", description: "Get the weather in a city (default to c)")
|
||||
|
||||
function.add_parameter(
|
||||
name: "location",
|
||||
type: "string",
|
||||
description: "the city name",
|
||||
required: true,
|
||||
)
|
||||
|
||||
function.add_parameter(
|
||||
name: "unit",
|
||||
type: "string",
|
||||
description: "the unit of measurement celcius c or fahrenheit f",
|
||||
enum: %w[c f],
|
||||
required: false,
|
||||
)
|
||||
|
||||
list = FunctionList.new
|
||||
list << function
|
||||
list
|
||||
end
|
||||
|
||||
let :image_function_list do
|
||||
function = Function.new(name: "image", description: "generates an image")
|
||||
|
||||
function.add_parameter(
|
||||
name: "prompts",
|
||||
type: "array",
|
||||
item_type: "string",
|
||||
required: true,
|
||||
description: "the prompts",
|
||||
)
|
||||
|
||||
list = FunctionList.new
|
||||
list << function
|
||||
list
|
||||
end
|
||||
|
||||
it "can handle function call parsing" do
|
||||
raw_prompt = <<~PROMPT
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>image</tool_name>
|
||||
<parameters>
|
||||
<prompts>
|
||||
[
|
||||
"an oil painting",
|
||||
"a cute fluffy orange",
|
||||
"3 apple's",
|
||||
"a cat"
|
||||
]
|
||||
</prompts>
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
PROMPT
|
||||
parsed = image_function_list.parse_prompt(raw_prompt)
|
||||
expect(parsed).to eq(
|
||||
[
|
||||
{
|
||||
name: "image",
|
||||
arguments: {
|
||||
prompts: ["an oil painting", "a cute fluffy orange", "3 apple's", "a cat"],
|
||||
},
|
||||
},
|
||||
],
|
||||
)
|
||||
end
|
||||
|
||||
it "can generate a general custom system prompt" do
|
||||
prompt = function_list.system_prompt
|
||||
|
||||
# this is fragile, by design, we need to test something here
|
||||
#
|
||||
expected = <<~PROMPT
|
||||
<tools>
|
||||
<tool_description>
|
||||
<tool_name>get_weather</tool_name>
|
||||
<description>Get the weather in a city (default to c)</description>
|
||||
<parameters>
|
||||
<parameter>
|
||||
<name>location</name>
|
||||
<type>string</type>
|
||||
<description>the city name</description>
|
||||
<required>true</required>
|
||||
</parameter>
|
||||
<parameter>
|
||||
<name>unit</name>
|
||||
<type>string</type>
|
||||
<description>the unit of measurement celcius c or fahrenheit f</description>
|
||||
<required>false</required>
|
||||
<options>c,f</options>
|
||||
</parameter>
|
||||
</parameters>
|
||||
</tool_description>
|
||||
</tools>
|
||||
PROMPT
|
||||
expect(prompt).to include(expected)
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,417 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
require "rails_helper"
|
||||
|
||||
describe DiscourseAi::Inference::OpenAiCompletions do
|
||||
before { SiteSetting.ai_openai_api_key = "abc-123" }
|
||||
|
||||
fab!(:user)
|
||||
|
||||
it "supports sending an organization id" do
|
||||
SiteSetting.ai_openai_organization = "org_123"
|
||||
|
||||
stub_request(:post, "https://api.openai.com/v1/chat/completions").with(
|
||||
body:
|
||||
"{\"model\":\"gpt-3.5-turbo-0613\",\"messages\":[{\"role\":\"system\",\"content\":\"hello\"}]}",
|
||||
headers: {
|
||||
"Authorization" => "Bearer abc-123",
|
||||
"Content-Type" => "application/json",
|
||||
"Host" => "api.openai.com",
|
||||
"User-Agent" => "Ruby",
|
||||
"OpenAI-Organization" => "org_123",
|
||||
},
|
||||
).to_return(
|
||||
status: 200,
|
||||
body: { choices: [message: { content: "world" }] }.to_json,
|
||||
headers: {
|
||||
},
|
||||
)
|
||||
|
||||
result =
|
||||
DiscourseAi::Inference::OpenAiCompletions.perform!(
|
||||
[{ role: "system", content: "hello" }],
|
||||
"gpt-3.5-turbo-0613",
|
||||
)
|
||||
|
||||
expect(result.dig(:choices, 0, :message, :content)).to eq("world")
|
||||
end
|
||||
|
||||
context "when configured using Azure" do
|
||||
it "Supports custom Azure endpoints for completions" do
|
||||
gpt_url_base =
|
||||
"https://company.openai.azure.com/openai/deployments/deployment/chat/completions?api-version=2023-03-15-preview"
|
||||
key = "12345"
|
||||
SiteSetting.ai_openai_api_key = key
|
||||
|
||||
[
|
||||
{ setting_name: "ai_openai_gpt35_url", model: "gpt-35-turbo" },
|
||||
{ setting_name: "ai_openai_gpt35_16k_url", model: "gpt-35-16k-turbo" },
|
||||
{ setting_name: "ai_openai_gpt4_url", model: "gpt-4" },
|
||||
{ setting_name: "ai_openai_gpt4_32k_url", model: "gpt-4-32k" },
|
||||
{ setting_name: "ai_openai_gpt4_turbo_url", model: "gpt-4-1106-preview" },
|
||||
].each do |config|
|
||||
gpt_url = "#{gpt_url_base}/#{config[:model]}"
|
||||
setting_name = config[:setting_name]
|
||||
model = config[:model]
|
||||
|
||||
SiteSetting.public_send("#{setting_name}=".to_sym, gpt_url)
|
||||
|
||||
expected = {
|
||||
id: "chatcmpl-7TfPzOyBGW5K6dyWp3NPU0mYLGZRQ",
|
||||
object: "chat.completion",
|
||||
created: 1_687_305_079,
|
||||
model: model,
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
finish_reason: "stop",
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: "Hi there! How can I assist you today?",
|
||||
},
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
completion_tokens: 10,
|
||||
prompt_tokens: 9,
|
||||
total_tokens: 19,
|
||||
},
|
||||
}
|
||||
|
||||
stub_request(:post, gpt_url).with(
|
||||
body: "{\"model\":\"#{model}\",\"messages\":[{\"role\":\"user\",\"content\":\"hello\"}]}",
|
||||
headers: {
|
||||
"Api-Key" => "12345",
|
||||
"Content-Type" => "application/json",
|
||||
"Host" => "company.openai.azure.com",
|
||||
},
|
||||
).to_return(status: 200, body: expected.to_json, headers: {})
|
||||
|
||||
result =
|
||||
DiscourseAi::Inference::OpenAiCompletions.perform!(
|
||||
[role: "user", content: "hello"],
|
||||
model,
|
||||
)
|
||||
|
||||
expect(result).to eq(expected)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
it "supports function calling" do
|
||||
prompt = [role: "system", content: "you are weatherbot"]
|
||||
prompt << { role: "user", content: "what is the weather in sydney?" }
|
||||
|
||||
functions = []
|
||||
|
||||
function =
|
||||
DiscourseAi::Inference::Function.new(
|
||||
name: "get_weather",
|
||||
description: "Get the weather in a city",
|
||||
)
|
||||
|
||||
function.add_parameter(
|
||||
name: "location",
|
||||
type: "string",
|
||||
description: "the city name",
|
||||
required: true,
|
||||
)
|
||||
|
||||
function.add_parameter(
|
||||
name: "unit",
|
||||
type: "string",
|
||||
description: "the unit of measurement celcius c or fahrenheit f",
|
||||
enum: %w[c f],
|
||||
required: true,
|
||||
)
|
||||
|
||||
functions << function
|
||||
|
||||
function_calls = []
|
||||
current_function_call = nil
|
||||
|
||||
deltas = [
|
||||
{ role: "assistant" },
|
||||
{ function_call: { name: "get_weather", arguments: "" } },
|
||||
{ function_call: { arguments: "{ \"location\": " } },
|
||||
{ function_call: { arguments: "\"sydney\", \"unit\": \"c\" }" } },
|
||||
]
|
||||
|
||||
OpenAiCompletionsInferenceStubs.stub_streamed_response(
|
||||
prompt,
|
||||
deltas,
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
req_opts: {
|
||||
functions: functions,
|
||||
stream: true,
|
||||
},
|
||||
)
|
||||
|
||||
DiscourseAi::Inference::OpenAiCompletions.perform!(
|
||||
prompt,
|
||||
"gpt-3.5-turbo-0613",
|
||||
functions: functions,
|
||||
) do |json, cancel|
|
||||
fn = json.dig(:choices, 0, :delta, :function_call)
|
||||
if fn && fn[:name]
|
||||
current_function_call = { name: fn[:name], arguments: +fn[:arguments].to_s.dup }
|
||||
function_calls << current_function_call
|
||||
elsif fn && fn[:arguments] && current_function_call
|
||||
current_function_call[:arguments] << fn[:arguments]
|
||||
end
|
||||
end
|
||||
|
||||
expect(function_calls.length).to eq(1)
|
||||
expect(function_calls[0][:name]).to eq("get_weather")
|
||||
expect(JSON.parse(function_calls[0][:arguments])).to eq(
|
||||
{ "location" => "sydney", "unit" => "c" },
|
||||
)
|
||||
|
||||
prompt << { role: "function", name: "get_weather", content: 22.to_json }
|
||||
|
||||
OpenAiCompletionsInferenceStubs.stub_response(
|
||||
prompt,
|
||||
"The current temperature in Sydney is 22 degrees Celsius.",
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
req_opts: {
|
||||
functions: functions,
|
||||
},
|
||||
)
|
||||
|
||||
result =
|
||||
DiscourseAi::Inference::OpenAiCompletions.perform!(
|
||||
prompt,
|
||||
"gpt-3.5-turbo-0613",
|
||||
functions: functions,
|
||||
)
|
||||
|
||||
expect(result.dig(:choices, 0, :message, :content)).to eq(
|
||||
"The current temperature in Sydney is 22 degrees Celsius.",
|
||||
)
|
||||
end
|
||||
|
||||
it "supports rate limits" do
|
||||
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
|
||||
[
|
||||
{ status: 429, body: "", headers: {} },
|
||||
{ status: 429, body: "", headers: {} },
|
||||
{ status: 200, body: { choices: [message: { content: "ok" }] }.to_json, headers: {} },
|
||||
],
|
||||
)
|
||||
completions =
|
||||
DiscourseAi::Inference::OpenAiCompletions.perform!(
|
||||
[{ role: "user", content: "hello" }],
|
||||
"gpt-3.5-turbo",
|
||||
temperature: 0.5,
|
||||
top_p: 0.8,
|
||||
max_tokens: 700,
|
||||
retries: 3,
|
||||
retry_timeout: 0,
|
||||
)
|
||||
|
||||
expect(completions.dig(:choices, 0, :message, :content)).to eq("ok")
|
||||
end
|
||||
|
||||
it "supports will raise once rate limit is met" do
|
||||
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
|
||||
[
|
||||
{ status: 429, body: "", headers: {} },
|
||||
{ status: 429, body: "", headers: {} },
|
||||
{ status: 429, body: "", headers: {} },
|
||||
],
|
||||
)
|
||||
|
||||
expect do
|
||||
DiscourseAi::Inference::OpenAiCompletions.perform!(
|
||||
[{ role: "user", content: "hello" }],
|
||||
"gpt-3.5-turbo",
|
||||
temperature: 0.5,
|
||||
top_p: 0.8,
|
||||
max_tokens: 700,
|
||||
retries: 3,
|
||||
retry_timeout: 0,
|
||||
)
|
||||
end.to raise_error(DiscourseAi::Inference::OpenAiCompletions::CompletionFailed)
|
||||
end
|
||||
|
||||
it "can complete a trivial prompt" do
|
||||
response_text = "1. Serenity\\n2. Laughter\\n3. Adventure"
|
||||
prompt = [role: "user", content: "write 3 words"]
|
||||
user_id = 183
|
||||
req_opts = { temperature: 0.5, top_p: 0.8, max_tokens: 700 }
|
||||
|
||||
OpenAiCompletionsInferenceStubs.stub_response(prompt, response_text, req_opts: req_opts)
|
||||
|
||||
completions =
|
||||
DiscourseAi::Inference::OpenAiCompletions.perform!(
|
||||
prompt,
|
||||
"gpt-3.5-turbo",
|
||||
temperature: 0.5,
|
||||
top_p: 0.8,
|
||||
max_tokens: 700,
|
||||
user_id: user_id,
|
||||
)
|
||||
|
||||
expect(completions.dig(:choices, 0, :message, :content)).to eq(response_text)
|
||||
|
||||
expect(AiApiAuditLog.count).to eq(1)
|
||||
log = AiApiAuditLog.first
|
||||
|
||||
body = { model: "gpt-3.5-turbo", messages: prompt }.merge(req_opts).to_json
|
||||
request_body = OpenAiCompletionsInferenceStubs.response(response_text).to_json
|
||||
|
||||
expect(log.provider_id).to eq(AiApiAuditLog::Provider::OpenAI)
|
||||
expect(log.request_tokens).to eq(337)
|
||||
expect(log.response_tokens).to eq(162)
|
||||
expect(log.raw_request_payload).to eq(body)
|
||||
expect(log.raw_response_payload).to eq(request_body)
|
||||
end
|
||||
|
||||
context "when Webmock has streaming support" do
|
||||
# See: https://github.com/bblimke/webmock/issues/629
|
||||
let(:mock_net_http) do
|
||||
Class.new(Net::HTTP) do
|
||||
def request(*)
|
||||
super do |response|
|
||||
response.instance_eval do
|
||||
def read_body(*, &)
|
||||
@body.each(&)
|
||||
end
|
||||
end
|
||||
|
||||
yield response if block_given?
|
||||
|
||||
response
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
let(:remove_original_net_http) { Net.send(:remove_const, :HTTP) }
|
||||
let(:original_http) { remove_original_net_http }
|
||||
let(:stub_net_http) { Net.send(:const_set, :HTTP, mock_net_http) }
|
||||
|
||||
let(:remove_stubbed_net_http) { Net.send(:remove_const, :HTTP) }
|
||||
let(:restore_net_http) { Net.send(:const_set, :HTTP, original_http) }
|
||||
|
||||
before do
|
||||
mock_net_http
|
||||
remove_original_net_http
|
||||
stub_net_http
|
||||
end
|
||||
|
||||
after do
|
||||
remove_stubbed_net_http
|
||||
restore_net_http
|
||||
end
|
||||
|
||||
it "recovers from chunked payload issues" do
|
||||
raw_data = <<~TEXT
|
||||
da|ta: |{"choices":[{"delta":{"content"|:"test"}}]}
|
||||
|
||||
data: {"choices":[{"delta":{"content":"test1"}}]}
|
||||
|
||||
data: {"choices":[{"delta":{"conte|nt":"test2"}}]|}
|
||||
|
||||
data: {"ch|oices":[{"delta|":{"content":"test3"}}]}
|
||||
|
||||
data: [DONE]
|
||||
TEXT
|
||||
|
||||
chunks = raw_data.split("|")
|
||||
|
||||
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
|
||||
status: 200,
|
||||
body: chunks,
|
||||
)
|
||||
|
||||
partials = []
|
||||
DiscourseAi::Inference::OpenAiCompletions.perform!([], "gpt-3.5-turbo") do |partial, cancel|
|
||||
partials << partial
|
||||
end
|
||||
|
||||
expect(partials.length).to eq(4)
|
||||
expect(partials).to eq(
|
||||
[
|
||||
{ choices: [{ delta: { content: "test" } }] },
|
||||
{ choices: [{ delta: { content: "test1" } }] },
|
||||
{ choices: [{ delta: { content: "test2" } }] },
|
||||
{ choices: [{ delta: { content: "test3" } }] },
|
||||
],
|
||||
)
|
||||
end
|
||||
|
||||
it "support extremely slow streaming" do
|
||||
raw_data = <<~TEXT
|
||||
data: {"choices":[{"delta":{"content":"test"}}]}
|
||||
|
||||
data: {"choices":[{"delta":{"content":"test1"}}]}
|
||||
|
||||
data: {"choices":[{"delta":{"content":"test2"}}]}
|
||||
|
||||
data: [DONE]
|
||||
TEXT
|
||||
|
||||
chunks = raw_data.split("")
|
||||
|
||||
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
|
||||
status: 200,
|
||||
body: chunks,
|
||||
)
|
||||
|
||||
partials = []
|
||||
DiscourseAi::Inference::OpenAiCompletions.perform!([], "gpt-3.5-turbo") do |partial, cancel|
|
||||
partials << partial
|
||||
end
|
||||
|
||||
expect(partials.length).to eq(3)
|
||||
expect(partials).to eq(
|
||||
[
|
||||
{ choices: [{ delta: { content: "test" } }] },
|
||||
{ choices: [{ delta: { content: "test1" } }] },
|
||||
{ choices: [{ delta: { content: "test2" } }] },
|
||||
],
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
it "can operate in streaming mode" do
|
||||
deltas = [
|
||||
{ role: "assistant" },
|
||||
{ content: "Mount" },
|
||||
{ content: "ain" },
|
||||
{ content: " " },
|
||||
{ content: "Tree " },
|
||||
{ content: "Frog" },
|
||||
]
|
||||
|
||||
prompt = [role: "user", content: "write 3 words"]
|
||||
content = +""
|
||||
|
||||
OpenAiCompletionsInferenceStubs.stub_streamed_response(
|
||||
prompt,
|
||||
deltas,
|
||||
req_opts: {
|
||||
stream: true,
|
||||
},
|
||||
)
|
||||
|
||||
DiscourseAi::Inference::OpenAiCompletions.perform!(prompt, "gpt-3.5-turbo") do |partial, cancel|
|
||||
data = partial.dig(:choices, 0, :delta, :content)
|
||||
content << data if data
|
||||
cancel.call if content.split(" ").length == 2
|
||||
end
|
||||
|
||||
expect(content).to eq("Mountain Tree ")
|
||||
|
||||
expect(AiApiAuditLog.count).to eq(1)
|
||||
log = AiApiAuditLog.first
|
||||
|
||||
request_body = { model: "gpt-3.5-turbo", messages: prompt, stream: true }.to_json
|
||||
|
||||
expect(log.provider_id).to eq(AiApiAuditLog::Provider::OpenAI)
|
||||
expect(log.request_tokens).to eq(4)
|
||||
expect(log.response_tokens).to eq(3)
|
||||
expect(log.raw_request_payload).to eq(request_body)
|
||||
expect(log.raw_response_payload).to be_present
|
||||
end
|
||||
end
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue