REFACTOR: Simplify tool invocation by removing bot_user and llm parameters (#603)
* Tools now carry a "context" object, which can be critical for what they generate; the change was made so Dall E and Artist generate images in the correct context (for example, uploads created inside a private message are marked accordingly).
* FIX: improve error handling around image generation; this also corrects the generated image markdown and clarifies the code.
* Fix spec.
parent 88c7427fab · commit ab78d9b597
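The shape of the refactor is easiest to see at a call site: a tool now receives bot_user, llm, and context once, when it is constructed, and #invoke no longer takes positional arguments (only an optional progress block where a tool reports progress). A minimal sketch of the new calling convention, with illustrative values that are not taken from the diff itself:

  # Hypothetical call site using the new tool API introduced by this commit.
  llm = DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo")
  bot_user = User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID)

  tool =
    DiscourseAi::AiBot::Tools::Time.new(
      { timezone: "America/Los_Angeles" }, # tool arguments, unchanged
      bot_user: bot_user,                  # previously the first argument to #invoke
      llm: llm,                            # previously the second argument to #invoke
      context: { private_message: false }, # new: ambient context (post_id, topic_id, private_message, ...)
    )

  tool.invoke # => { args: ..., time: ... }

The hunks below apply this pattern mechanically to every tool and to their specs.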
@@ -68,13 +68,13 @@ module DiscourseAi
       result =
         llm.generate(prompt, **llm_kwargs) do |partial, cancel|
-          tools = persona.find_tools(partial)
+          tools = persona.find_tools(partial, bot_user: user, llm: llm, context: context)

           if (tools.present?)
             tool_found = true
             tools[0..MAX_TOOLS].each do |tool|
-              ongoing_chain &&= tool.chain_next_response?
               process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
+              ongoing_chain &&= tool.chain_next_response?
             end
           else
             update_blk.call(partial, cancel, nil)

@@ -137,7 +137,7 @@ module DiscourseAi
       update_blk.call("", cancel, build_placeholder(tool.summary, ""))

       result =
-        tool.invoke(bot_user, llm) do |progress|
+        tool.invoke do |progress|
           placeholder = build_placeholder(tool.summary, progress)
           update_blk.call("", cancel, placeholder)
         end
@@ -174,16 +174,21 @@ module DiscourseAi
        prompt
      end

-     def find_tools(partial)
+     def find_tools(partial, bot_user:, llm:, context:)
        return [] if !partial.include?("</invoke>")

        parsed_function = Nokogiri::HTML5.fragment(partial)
-       parsed_function.css("invoke").map { |fragment| find_tool(fragment) }.compact
+       parsed_function
+         .css("invoke")
+         .map do |fragment|
+           tool_instance(fragment, bot_user: bot_user, llm: llm, context: context)
+         end
+         .compact
      end

      protected

-     def find_tool(parsed_function)
+     def tool_instance(parsed_function, bot_user:, llm:, context:)
        function_id = parsed_function.at("tool_id")&.text
        function_name = parsed_function.at("tool_name")&.text
        return nil if function_name.nil?

@@ -212,6 +217,9 @@ module DiscourseAi
          arguments,
          tool_call_id: function_id || function_name,
          persona_options: options[tool_klass].to_h,
+         bot_user: bot_user,
+         llm: llm,
+         context: context,
        )
      end
@@ -350,6 +350,7 @@ module DiscourseAi
      )
      context[:post_id] = post.id
      context[:topic_id] = post.topic_id
+     context[:private_message] = post.topic.private_message?

      reply_user = bot.bot_user
      if bot.persona.class.respond_to?(:user_id)
@@ -33,7 +33,7 @@ module DiscourseAi
        false
      end

-     def invoke(bot_user, _llm)
+     def invoke
        # max 4 prompts
        max_prompts = prompts.take(4)
        progress = prompts.first

@@ -88,7 +88,12 @@ module DiscourseAi
          file.rewind
          uploads << {
            prompt: image[:revised_prompt],
-           upload: UploadCreator.new(file, "image.png").create_for(bot_user.id),
+           upload:
+             UploadCreator.new(
+               file,
+               "image.png",
+               for_private_message: context[:private_message],
+             ).create_for(bot_user.id),
          }
        end
      end

@@ -99,9 +104,7 @@ module DiscourseAi
        [grid]
        #{
          uploads
-           .map do |item|
-             "![#{item[:prompt].gsub(/\|\'\"/, "")}|512x512, 50%](#{item[:upload].short_url})"
-           end
+           .map { |item| "![#{item[:prompt].gsub(/\|\'\"/, "")}](#{item[:upload].short_url})" }
            .join(" ")
        }
        [/grid]
@@ -28,7 +28,7 @@ module DiscourseAi
        parameters[:tables]
      end

-     def invoke(_bot_user, _llm)
+     def invoke
        tables_arr = tables.split(",").map(&:strip)

        table_info = {}

@@ -78,7 +78,7 @@ module DiscourseAi
        parameters.slice(:category, :user, :order, :max_posts, :tags, :before, :after, :status)
      end

-     def invoke(bot_user, llm)
+     def invoke
        search_string =
          search_args.reduce(+parameters[:search_query].to_s) do |memo, (key, value)|
            return memo if value.blank?
@@ -52,7 +52,7 @@ module DiscourseAi
        { repo_name: repo_name, file_paths: file_paths.join(", "), branch: branch }
      end

-     def invoke(_bot_user, llm)
+     def invoke
        owner, repo = repo_name.split("/")
        file_contents = {}
        missing_files = []

@@ -43,7 +43,7 @@ module DiscourseAi
        @url
      end

-     def invoke(_bot_user, llm)
+     def invoke
        api_url = "https://api.github.com/repos/#{repo}/pulls/#{pull_id}"
        @url = "https://github.com/#{repo}/pull/#{pull_id}"

@@ -41,7 +41,7 @@ module DiscourseAi
        { repo: repo, query: query }
      end

-     def invoke(_bot_user, llm)
+     def invoke
        api_url = "https://api.github.com/search/code?q=#{query}+repo:#{repo}"

        response_code = "unknown error"

@@ -27,7 +27,7 @@ module DiscourseAi
        parameters[:query].to_s.strip
      end

-     def invoke(bot_user, llm)
+     def invoke
        yield(query)

        api_key = SiteSetting.ai_google_custom_search_api_key
@@ -40,6 +40,11 @@ module DiscourseAi
        "image"
      end

+     def initialize(*args, **kwargs)
+       super
+       @chain_next_response = false
+     end
+
      def prompts
        parameters[:prompts]
      end

@@ -53,10 +58,10 @@ module DiscourseAi
      end

      def chain_next_response?
-       false
+       @chain_next_response
      end

-     def invoke(bot_user, _llm)
+     def invoke
        # max 4 prompts
        selected_prompts = prompts.take(4)
        seeds = seeds.take(4) if seeds

@@ -101,7 +106,15 @@ module DiscourseAi
        results = threads.map(&:value).compact

        if !results.present?
-         return { prompts: prompts, error: "Something went wrong, could not generate image" }
+         @chain_next_response = true
+         return(
+           {
+             prompts: prompts,
+             error:
+               "Something went wrong inform user you could not generate image, check Discourse logs, give up don't try anymore",
+             give_up: true,
+           }
+         )
        end

        uploads = []

@@ -114,7 +127,12 @@ module DiscourseAi
          file.rewind
          uploads << {
            prompt: prompts[index],
-           upload: UploadCreator.new(file, "image.png").create_for(bot_user.id),
+           upload:
+             UploadCreator.new(
+               file,
+               "image.png",
+               for_private_message: context[:private_message],
+             ).create_for(bot_user.id),
            seed: image[:seed],
          }
        end

@@ -126,9 +144,7 @@ module DiscourseAi
        [grid]
        #{
          uploads
-           .map do |item|
-             "![#{item[:prompt].gsub(/\|\'\"/, "")}|50%](#{item[:upload].short_url})"
-           end
+           .map { |item| "![#{item[:prompt].gsub(/\|\'\"/, "")}](#{item[:upload].short_url})" }
            .join(" ")
        }
        [/grid]
@@ -16,7 +16,7 @@ module DiscourseAi
        "categories"
      end

-     def invoke(_bot_user, _llm)
+     def invoke
        columns = {
          name: "Name",
          slug: "Slug",

@@ -15,7 +15,7 @@ module DiscourseAi
        "tags"
      end

-     def invoke(_bot_user, _llm)
+     def invoke
        column_names = { name: "Name", public_topic_count: "Topic Count" }

        tags =

@@ -30,7 +30,7 @@ module DiscourseAi
        parameters[:options]
      end

-     def invoke(_bot_user, _llm)
+     def invoke
        result = nil
        # can be a naive list of strings
        if options.none? { |option| option.match?(/\A\d+-\d+\z/) || option.include?(",") }

@@ -39,7 +39,7 @@ module DiscourseAi
        parameters[:post_number]
      end

-     def invoke(_bot_user, llm)
+     def invoke
        not_found = { topic_id: topic_id, description: "Topic not found" }

        @title = ""

@@ -91,7 +91,7 @@ module DiscourseAi
        parameters.slice(:category, :user, :order, :max_posts, :tags, :before, :after, :status)
      end

-     def invoke(bot_user, llm)
+     def invoke
        search_string =
          search_args.reduce(+parameters[:search_query].to_s) do |memo, (key, value)|
            return memo if value.blank?
@@ -31,7 +31,7 @@ module DiscourseAi
        parameters[:query].to_s
      end

-     def invoke(_bot_user, _llm)
+     def invoke
        @last_num_results = 0

        terms = query.split(",").map(&:strip).map(&:downcase).reject(&:blank?)

@@ -47,7 +47,7 @@ module DiscourseAi
        parameters[:setting_name]
      end

-     def invoke(_bot_user, llm)
+     def invoke
        if !self.class.rg_installed?
          return(
            {

@@ -48,7 +48,7 @@ module DiscourseAi
        @last_summary || I18n.t("discourse_ai.ai_bot.topic_not_found")
      end

-     def invoke(bot_user, llm, &progress_blk)
+     def invoke(&progress_blk)
        topic = nil
        if topic_id > 0
          topic = Topic.find_by(id: topic_id)

@@ -27,7 +27,7 @@ module DiscourseAi
        parameters[:timezone].to_s
      end

-     def invoke(_bot_user, _llm)
+     def invoke
        time =
          begin
            ::Time.now.in_time_zone(timezone)
@@ -31,15 +31,24 @@ module DiscourseAi
      end

      attr_accessor :custom_raw
+     attr_reader :tool_call_id, :persona_options, :bot_user, :llm, :context, :parameters

-     def initialize(parameters, tool_call_id: "", persona_options: {})
+     def initialize(
+       parameters,
+       tool_call_id: "",
+       persona_options: {},
+       bot_user:,
+       llm:,
+       context: {}
+     )
        @parameters = parameters
        @tool_call_id = tool_call_id
        @persona_options = persona_options
+       @bot_user = bot_user
+       @llm = llm
+       @context = context
      end

-     attr_reader :parameters, :tool_call_id
-
      def name
        self.class.name
      end
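For tool authors, the base-class change above is the whole contract: a subclass now defines invoke with no parameters and reaches the bot user, LLM, and context through the readers added to the tool base class. The class below is a hypothetical illustration, not part of this commit; only the initializer keywords and the attr_reader come from the diff.

  # Hypothetical tool written against the new base class (assumed to sit in
  # DiscourseAi::AiBot::Tools next to the tools changed in this commit).
  class EchoTime < Tool
    def invoke
      # bot_user, llm and context come from the attr_reader added above;
      # nothing is passed in at invocation time any more.
      {
        timezone: parameters[:timezone],
        asked_by: bot_user&.username,
        private_message: context[:private_message],
      }
    end
  end

A real tool also declares metadata such as its name and signature for the persona; that part of the interface is untouched by this commit.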
@@ -32,7 +32,7 @@ module DiscourseAi
        @url
      end

-     def invoke(_bot_user, llm)
+     def invoke
        send_http_request(url, follow_redirects: true) do |response|
          if response.code == "200"
            html = read_response_body(response)
@@ -66,6 +66,12 @@ after_initialize do
   reloadable_patch { |plugin| Guardian.prepend DiscourseAi::GuardianExtensions }

   register_modifier(:post_should_secure_uploads?) do |_, _, topic|
-    false if topic.private_message? && SharedAiConversation.exists?(target: topic)
+    if topic.private_message? && SharedAiConversation.exists?(target: topic)
+      false
+    else
+      # revert to default behavior
+      # even though this can be shortened this is the clearest way to express it
+      nil
+    end
   end
 end
@@ -63,7 +63,7 @@ RSpec.describe DiscourseAi::AiBot::Bot do

   context "when using function chaining" do
     it "yields a loading placeholder while proceeds to invoke the command" do
-      tool = DiscourseAi::AiBot::Tools::ListCategories.new({})
+      tool = DiscourseAi::AiBot::Tools::ListCategories.new({}, bot_user: nil, llm: nil)
       partial_placeholder = +(<<~HTML)
         <details>
           <summary>#{tool.summary}</summary>

@@ -99,7 +99,14 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
        </invoke>
      </function_calls>
    XML
-    dall_e1, dall_e2 = tools = DiscourseAi::AiBot::Personas::DallE3.new.find_tools(xml)
+    dall_e1, dall_e2 =
+      tools =
+        DiscourseAi::AiBot::Personas::DallE3.new.find_tools(
+          xml,
+          bot_user: nil,
+          llm: nil,
+          context: nil,
+        )
     expect(dall_e1.parameters[:prompts]).to eq(["cat oil painting", "big car"])
     expect(dall_e2.parameters[:prompts]).to eq(["pic3"])
     expect(tools.length).to eq(2)
@@ -1,14 +1,16 @@
 #frozen_string_literal: true

 RSpec.describe DiscourseAi::AiBot::Tools::DallE do
-  subject(:dall_e) { described_class.new({ prompts: prompts }) }
-
   let(:prompts) { ["a pink cow", "a red cow"] }

   let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
   let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
   let(:progress_blk) { Proc.new {} }

+  let(:dall_e) do
+    described_class.new({ prompts: prompts }, llm: llm, bot_user: bot_user, context: {})
+  end
+
   before { SiteSetting.ai_bot_enabled = true }

   describe "#process" do

@@ -34,12 +36,12 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do
         end
         .to_return(status: 200, body: { data: data }.to_json)

-      info = dall_e.invoke(bot_user, llm, &progress_blk).to_json
+      info = dall_e.invoke(&progress_blk).to_json

       expect(JSON.parse(info)).to eq("prompts" => ["a pink cow 1", "a pink cow 1"])
-      expect(subject.custom_raw).to include("upload://")
-      expect(subject.custom_raw).to include("[grid]")
-      expect(subject.custom_raw).to include("a pink cow 1")
+      expect(dall_e.custom_raw).to include("upload://")
+      expect(dall_e.custom_raw).to include("[grid]")
+      expect(dall_e.custom_raw).to include("a pink cow 1")
     end

     it "can generate correct info" do

@@ -59,12 +61,12 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do
         end
         .to_return(status: 200, body: { data: data }.to_json)

-      info = dall_e.invoke(bot_user, llm, &progress_blk).to_json
+      info = dall_e.invoke(&progress_blk).to_json

       expect(JSON.parse(info)).to eq("prompts" => ["a pink cow 1", "a pink cow 1"])
-      expect(subject.custom_raw).to include("upload://")
-      expect(subject.custom_raw).to include("[grid]")
-      expect(subject.custom_raw).to include("a pink cow 1")
+      expect(dall_e.custom_raw).to include("upload://")
+      expect(dall_e.custom_raw).to include("[grid]")
+      expect(dall_e.custom_raw).to include("a pink cow 1")
     end
   end
 end
@@ -7,7 +7,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::DbSchema do
   before { SiteSetting.ai_bot_enabled = true }
   describe "#process" do
     it "returns rich schema for tables" do
-      result = described_class.new({ tables: "posts,topics" }).invoke(bot_user, llm)
+      result = described_class.new({ tables: "posts,topics" }, bot_user: bot_user, llm: llm).invoke

       expect(result[:schema_info]).to include("raw text")
       expect(result[:schema_info]).to include("views integer")

@@ -45,8 +45,8 @@ RSpec.describe DiscourseAi::AiBot::Tools::DiscourseMetaSearch do
       },
     )

-    search = described_class.new({ search_query: "test" })
-    results = search.invoke(bot_user, llm, &progress_blk)
+    search = described_class.new({ search_query: "test" }, bot_user: bot_user, llm: llm)
+    results = search.invoke(&progress_blk)
     expect(results[:rows].length).to eq(20)

     expect(results[:rows].first[results[:column_names].index("category")]).to eq(

@@ -71,8 +71,8 @@ RSpec.describe DiscourseAi::AiBot::Tools::DiscourseMetaSearch do
       .to_h
       .symbolize_keys

-    search = described_class.new(params)
-    results = search.invoke(bot_user, llm, &progress_blk)
+    search = described_class.new(params, bot_user: bot_user, llm: llm)
+    results = search.invoke(&progress_blk)

     expect(results[:args]).to eq(params)
   end
@@ -3,6 +3,8 @@
 require "rails_helper"

 RSpec.describe DiscourseAi::AiBot::Tools::GithubFileContent do
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }
+
   let(:tool) do
     described_class.new(
       {

@@ -10,11 +12,11 @@ RSpec.describe DiscourseAi::AiBot::Tools::GithubFileContent do
         file_paths: %w[lib/database/connection.rb lib/ai_bot/tools/github_pull_request_diff.rb],
         branch: "8b382d6098fde879d28bbee68d3cbe0a193e4ffc",
       },
+      bot_user: nil,
+      llm: llm,
     )
   end

-  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }
-
   describe "#invoke" do
     before do
       stub_request(

@@ -35,7 +37,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::GithubFileContent do
     end

     it "retrieves the content of the specified GitHub files" do
-      result = tool.invoke(nil, llm)
+      result = tool.invoke
       expected = {
         file_contents:
           "File Path: lib/database/connection.rb:\ncontent of connection.rb\nFile Path: lib/ai_bot/tools/github_pull_request_diff.rb:\ncontent of github_pull_request_diff.rb",
@@ -3,9 +3,9 @@
 require "rails_helper"

 RSpec.describe DiscourseAi::AiBot::Tools::GithubPullRequestDiff do
-  let(:tool) { described_class.new({ repo: repo, pull_id: pull_id }) }
   let(:bot_user) { Fabricate(:user) }
   let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }
+  let(:tool) { described_class.new({ repo: repo, pull_id: pull_id }, bot_user: bot_user, llm: llm) }

   context "with #sort_and_shorten_diff" do
     it "sorts and shortens the diff without dropping data" do

@@ -64,7 +64,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::GithubPullRequestDiff do
        },
      ).to_return(status: 200, body: diff)

-     result = tool.invoke(bot_user, llm)
+     result = tool.invoke
      expect(result[:diff]).to eq(diff)
      expect(result[:error]).to be_nil
    end

@@ -80,7 +80,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::GithubPullRequestDiff do
        },
      ).to_return(status: 200, body: diff)

-     result = tool.invoke(bot_user, llm)
+     result = tool.invoke
      expect(result[:diff]).to eq(diff)
      expect(result[:error]).to be_nil
    end

@@ -98,7 +98,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::GithubPullRequestDiff do
        },
      ).to_return(status: 404)

-     result = tool.invoke(bot_user, nil)
+     result = tool.invoke
      expect(result[:diff]).to be_nil
      expect(result[:error]).to include("Failed to retrieve the diff")
    end
@@ -3,9 +3,9 @@
 require "rails_helper"

 RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchCode do
-  let(:tool) { described_class.new({ repo: repo, query: query }) }
   let(:bot_user) { Fabricate(:user) }
   let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }
+  let(:tool) { described_class.new({ repo: repo, query: query }, bot_user: bot_user, llm: llm) }

   context "with valid search results" do
     let(:repo) { "discourse/discourse" }

@@ -34,7 +34,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchCode do
        }.to_json,
      )

-     result = tool.invoke(bot_user, llm)
+     result = tool.invoke
      expect(result[:search_results]).to include("def hello\n puts 'hello'\nend")
      expect(result[:search_results]).to include("test/hello.rb")
      expect(result[:error]).to be_nil

@@ -64,7 +64,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchCode do
        },
      ).to_return(status: 200, body: { total_count: 0, items: [] }.to_json)

-     result = tool.invoke(bot_user, llm)
+     result = tool.invoke
      expect(result[:search_results]).to be_empty
      expect(result[:error]).to be_nil
    end

@@ -85,7 +85,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchCode do
        },
      ).to_return(status: 403)

-     result = tool.invoke(bot_user, llm)
+     result = tool.invoke
      expect(result[:search_results]).to be_nil
      expect(result[:error]).to include("Failed to perform code search")
    end
@@ -1,18 +1,15 @@
 #frozen_string_literal: true

 RSpec.describe DiscourseAi::AiBot::Tools::Google do
-  subject(:search) { described_class.new({ query: "some search term" }) }
-
   let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
   let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
   let(:progress_blk) { Proc.new {} }
+  let(:search) { described_class.new({ query: "some search term" }, bot_user: bot_user, llm: llm) }

   before { SiteSetting.ai_bot_enabled = true }

   describe "#process" do
     it "will not explode if there are no results" do
-      post = Fabricate(:post)
-
       SiteSetting.ai_google_custom_search_api_key = "abc"
       SiteSetting.ai_google_custom_search_cx = "cx"

@@ -23,15 +20,13 @@ RSpec.describe DiscourseAi::AiBot::Tools::Google do
         "https://www.googleapis.com/customsearch/v1?cx=cx&key=abc&num=10&q=some%20search%20term",
       ).to_return(status: 200, body: json_text, headers: {})

-      info = search.invoke(bot_user, llm, &progress_blk).to_json
+      info = search.invoke(&progress_blk).to_json

       expect(search.results_count).to eq(0)
       expect(info).to_not include("oops")
     end

     it "can generate correct info" do
-      post = Fabricate(:post)
-
       SiteSetting.ai_google_custom_search_api_key = "abc"
       SiteSetting.ai_google_custom_search_cx = "cx"

@@ -63,7 +58,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Google do
         "https://www.googleapis.com/customsearch/v1?cx=cx&key=abc&num=10&q=some%20search%20term",
       ).to_return(status: 200, body: json_text, headers: {})

-      info = search.invoke(bot_user, llm, &progress_blk).to_json
+      info = search.invoke(&progress_blk).to_json

       expect(search.results_count).to eq(2)
       expect(info).to include("title1")
@@ -1,8 +1,6 @@
 #frozen_string_literal: true

 RSpec.describe DiscourseAi::AiBot::Tools::Image do
-  subject(:tool) { described_class.new({ prompts: prompts, seeds: [99, 32] }) }
-
   let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
   let(:progress_blk) { Proc.new {} }

@@ -10,11 +8,21 @@ RSpec.describe DiscourseAi::AiBot::Tools::Image do

   let(:prompts) { ["a pink cow", "a red cow"] }

+  let(:tool) do
+    described_class.new(
+      { prompts: prompts, seeds: [99, 32] },
+      bot_user: bot_user,
+      llm: llm,
+      context: {
+      },
+    )
+  end
+
   before { SiteSetting.ai_bot_enabled = true }

   describe "#process" do
     it "can generate correct info" do
-      post = Fabricate(:post)
+      _post = Fabricate(:post)

       SiteSetting.ai_stability_api_url = "https://api.stability.dev"
       SiteSetting.ai_stability_api_key = "abc"

@@ -36,7 +44,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Image do
         end
         .to_return(status: 200, body: { artifacts: artifacts }.to_json)

-      info = tool.invoke(bot_user, llm, &progress_blk).to_json
+      info = tool.invoke(&progress_blk).to_json

       expect(JSON.parse(info)).to eq("prompts" => ["a pink cow", "a red cow"], "seeds" => [99, 99])
       expect(tool.custom_raw).to include("upload://")
@@ -10,7 +10,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::ListCategories do
   it "list available categories" do
     Fabricate(:category, name: "america", posts_year: 999)

-    info = described_class.new({}).invoke(bot_user, llm).to_s
+    info = described_class.new({}, bot_user: bot_user, llm: llm).invoke.to_s

     expect(info).to include("america")
     expect(info).to include("999")

@@ -14,7 +14,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::ListTags do
     Fabricate(:tag, name: "america", public_topic_count: 100)
     Fabricate(:tag, name: "not_here", public_topic_count: 0)

-    info = described_class.new({}).invoke(bot_user, llm)
+    info = described_class.new({}, bot_user: bot_user, llm: llm).invoke

     expect(info.to_s).to include("america")
     expect(info.to_s).not_to include("not_here")
@@ -4,7 +4,7 @@ require "rails_helper"

 RSpec.describe DiscourseAi::AiBot::Tools::RandomPicker do
   describe "#invoke" do
-    subject { described_class.new({ options: options }).invoke(nil, nil) }
+    subject { described_class.new({ options: options }, bot_user: nil, llm: nil).invoke }

     context "with options as simple list of strings" do
       let(:options) { %w[apple banana cherry] }
@@ -1,10 +1,9 @@
 #frozen_string_literal: true

 RSpec.describe DiscourseAi::AiBot::Tools::Read do
-  subject(:tool) { described_class.new({ topic_id: topic_with_tags.id }) }
-
   let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
   let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
+  let(:tool) { described_class.new({ topic_id: topic_with_tags.id }, bot_user: bot_user, llm: llm) }

   fab!(:parent_category) { Fabricate(:category, name: "animals") }
   fab!(:category) { Fabricate(:category, parent_category: parent_category, name: "amazing-cat") }

@@ -34,7 +33,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Read do
       Fabricate(:post, topic: topic_with_tags, raw: "hello there")
       Fabricate(:post, topic: topic_with_tags, raw: "mister sam")

-      results = tool.invoke(bot_user, llm)
+      results = tool.invoke

       expect(results[:topic_id]).to eq(topic_id)
       expect(results[:content]).to include("hello")
@@ -7,18 +7,18 @@ RSpec.describe DiscourseAi::AiBot::Tools::SearchSettings do
   before { SiteSetting.ai_bot_enabled = true }

   def search_settings(query)
-    described_class.new({ query: query })
+    described_class.new({ query: query }, bot_user: bot_user, llm: llm)
   end

   describe "#process" do
     it "can handle no results" do
-      results = search_settings("this will not exist frogs").invoke(bot_user, llm)
+      results = search_settings("this will not exist frogs").invoke
       expect(results[:args]).to eq({ query: "this will not exist frogs" })
       expect(results[:rows]).to eq([])
     end

     it "can return more many settings with no descriptions if there are lots of hits" do
-      results = search_settings("a").invoke(bot_user, llm)
+      results = search_settings("a").invoke

       expect(results[:rows].length).to be > 30
       expect(results[:rows][0].length).to eq(1)

@@ -26,10 +26,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::SearchSettings do

     it "can return descriptions if there are few matches" do
       results =
-        search_settings("this will not be found!@,default_locale,ai_bot_enabled_chat_bots").invoke(
-          bot_user,
-          llm,
-        )
+        search_settings("this will not be found!@,default_locale,ai_bot_enabled_chat_bots").invoke

       expect(results[:rows].length).to eq(2)

@@ -39,24 +39,37 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do

      _bot_post = Fabricate(:post)

-     search = described_class.new({ order: "latest" }, persona_options: persona_options)
+     search =
+       described_class.new(
+         { order: "latest" },
+         persona_options: persona_options,
+         bot_user: bot_user,
+         llm: llm,
+         context: {
+         },
+       )

-     results = search.invoke(bot_user, llm, &progress_blk)
+     results = search.invoke(&progress_blk)
      expect(results[:rows].length).to eq(1)

      search_post.topic.tags = []
      search_post.topic.save!

      # no longer has the tag funny
-     results = search.invoke(bot_user, llm, &progress_blk)
+     results = search.invoke(&progress_blk)
      expect(results[:rows].length).to eq(0)
    end

    it "can handle no results" do
      _post1 = Fabricate(:post, topic: topic_with_tags)
-     search = described_class.new({ search_query: "ABDDCDCEDGDG", order: "fake" })
+     search =
+       described_class.new(
+         { search_query: "ABDDCDCEDGDG", order: "fake" },
+         bot_user: bot_user,
+         llm: llm,
+       )

-     results = search.invoke(bot_user, llm, &progress_blk)
+     results = search.invoke(&progress_blk)

      expect(results[:args]).to eq({ search_query: "ABDDCDCEDGDG", order: "fake" })
      expect(results[:rows]).to eq([])

@@ -79,7 +92,12 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
      )

      post1 = Fabricate(:post, topic: topic_with_tags)
-     search = described_class.new({ search_query: "hello world, sam", status: "public" })
+     search =
+       described_class.new(
+         { search_query: "hello world, sam", status: "public" },
+         llm: llm,
+         bot_user: bot_user,
+       )

      DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn
        .any_instance

@@ -88,7 +106,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do

      results =
        DiscourseAi::Completions::Llm.with_prepared_responses(["<ai>#{query}</ai>"]) do
-         search.invoke(bot_user, llm, &progress_blk)
+         search.invoke(&progress_blk)
        end

      expect(results[:args]).to eq({ search_query: "hello world, sam", status: "public" })

@@ -101,9 +119,10 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do

      post1 = Fabricate(:post, topic: topic_with_tags)

-     search = described_class.new({ limit: 1, user: post1.user.username })
+     search =
+       described_class.new({ limit: 1, user: post1.user.username }, bot_user: bot_user, llm: llm)

-     results = search.invoke(bot_user, llm, &progress_blk)
+     results = search.invoke(&progress_blk)
      expect(results[:rows].to_s).to include("/subfolder" + post1.url)
    end

@@ -120,18 +139,18 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
      .to_h
      .symbolize_keys

-    search = described_class.new(params)
-    results = search.invoke(bot_user, llm, &progress_blk)
+    search = described_class.new(params, bot_user: bot_user, llm: llm)
+    results = search.invoke(&progress_blk)

     expect(results[:args]).to eq(params)
   end

   it "returns rich topic information" do
     post1 = Fabricate(:post, like_count: 1, topic: topic_with_tags)
-    search = described_class.new({ user: post1.user.username })
+    search = described_class.new({ user: post1.user.username }, bot_user: bot_user, llm: llm)
     post1.topic.update!(views: 100, posts_count: 2, like_count: 10)

-    results = search.invoke(bot_user, llm, &progress_blk)
+    results = search.invoke(&progress_blk)

     row = results[:rows].first
     category = row[results[:column_names].index("category")]
@@ -15,12 +15,12 @@ RSpec.describe DiscourseAi::AiBot::Tools::SettingContext, if: has_rg? do
   before { SiteSetting.ai_bot_enabled = true }

   def setting_context(setting_name)
-    described_class.new({ setting_name: setting_name })
+    described_class.new({ setting_name: setting_name }, bot_user: bot_user, llm: llm)
   end

   describe "#execute" do
     it "returns the context for core setting" do
-      result = setting_context("moderators_view_emails").invoke(bot_user, llm)
+      result = setting_context("moderators_view_emails").invoke

       expect(result[:setting_name]).to eq("moderators_view_emails")

@@ -29,7 +29,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::SettingContext, if: has_rg? do
     end

     it "returns the context for plugin setting" do
-      result = setting_context("ai_bot_enabled").invoke(bot_user, llm)
+      result = setting_context("ai_bot_enabled").invoke

       expect(result[:setting_name]).to eq("ai_bot_enabled")
       expect(result[:context]).to include("ai_bot_enabled:")

@@ -37,7 +37,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::SettingContext, if: has_rg? do

     context "when the setting does not exist" do
       it "returns an error message" do
-        result = setting_context("this_setting_does_not_exist").invoke(bot_user, llm)
+        result = setting_context("this_setting_does_not_exist").invoke

         expect(result[:context]).to eq("This setting does not exist")
       end
@@ -15,8 +15,12 @@ RSpec.describe DiscourseAi::AiBot::Tools::Summarize do

       DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do
         summarization =
-          described_class.new({ topic_id: post.topic_id, guidance: "why did it happen?" })
-        info = summarization.invoke(bot_user, llm, &progress_blk)
+          described_class.new(
+            { topic_id: post.topic_id, guidance: "why did it happen?" },
+            bot_user: bot_user,
+            llm: llm,
+          )
+        info = summarization.invoke(&progress_blk)

         expect(info).to include("Topic summarized")
         expect(summarization.custom_raw).to include(summary)

@@ -34,8 +38,12 @@ RSpec.describe DiscourseAi::AiBot::Tools::Summarize do

       DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do
         summarization =
-          described_class.new({ topic_id: post.topic_id, guidance: "why did it happen?" })
-        info = summarization.invoke(bot_user, llm, &progress_blk)
+          described_class.new(
+            { topic_id: post.topic_id, guidance: "why did it happen?" },
+            bot_user: bot_user,
+            llm: llm,
+          )
+        info = summarization.invoke(&progress_blk)

         expect(info).not_to include(post.raw)

@@ -11,7 +11,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Time do
     freeze_time

     args = { timezone: "America/Los_Angeles" }
-    info = described_class.new(args).invoke(bot_user, llm)
+    info = described_class.new(args, bot_user: bot_user, llm: llm).invoke

     expect(info).to eq({ args: args, time: Time.now.in_time_zone("America/Los_Angeles").to_s })
     expect(info.to_s).not_to include("not_here")
@@ -21,8 +21,8 @@ RSpec.describe DiscourseAi::AiBot::Tools::WebBrowser do
        "<html><head><title>Test</title></head><body><p>This is a simplified version of the webpage content.</p></body></html>",
      )

-    tool = described_class.new({ url: url })
-    result = tool.invoke(bot_user, llm)
+    tool = described_class.new({ url: url }, bot_user: bot_user, llm: llm)
+    result = tool.invoke

     expect(result).to have_key(:text)
     expect(result[:text]).to eq(processed_text)

@@ -35,8 +35,8 @@ RSpec.describe DiscourseAi::AiBot::Tools::WebBrowser do
     # Simulating a failed request
     stub_request(:get, url).to_return(status: [500, "Internal Server Error"])

-    tool = described_class.new({ url: url })
-    result = tool.invoke(bot_user, llm)
+    tool = described_class.new({ url: url }, bot_user: bot_user, llm: llm)
+    result = tool.invoke

     expect(result).to have_key(:error)
     expect(result[:error]).to include("Failed to retrieve the web page")

@@ -50,8 +50,8 @@ RSpec.describe DiscourseAi::AiBot::Tools::WebBrowser do
     simple_html = "<html><body><p>Simple content.</p></body></html>"
     stub_request(:get, url).to_return(status: 200, body: simple_html)

-    tool = described_class.new({ url: url })
-    result = tool.invoke(bot_user, llm)
+    tool = described_class.new({ url: url }, bot_user: bot_user, llm: llm)
+    result = tool.invoke

     expect(result[:text]).to eq("Simple content.")
   end

@@ -61,8 +61,8 @@ RSpec.describe DiscourseAi::AiBot::Tools::WebBrowser do
      "<html><head><script>console.log('Ignore me')</script></head><body><style>body { background-color: #000; }</style><p>Only relevant content here.</p></body></html>"
     stub_request(:get, url).to_return(status: 200, body: complex_html)

-    tool = described_class.new({ url: url })
-    result = tool.invoke(bot_user, llm)
+    tool = described_class.new({ url: url }, bot_user: bot_user, llm: llm)
+    result = tool.invoke

     expect(result[:text]).to eq("Only relevant content here.")
   end

@@ -72,8 +72,8 @@ RSpec.describe DiscourseAi::AiBot::Tools::WebBrowser do
      "<html><body><div><section><p>Nested paragraph 1.</p></section><section><p>Nested paragraph 2.</p></section></div></body></html>"
     stub_request(:get, url).to_return(status: 200, body: nested_html)

-    tool = described_class.new({ url: url })
-    result = tool.invoke(bot_user, llm)
+    tool = described_class.new({ url: url }, bot_user: bot_user, llm: llm)
+    result = tool.invoke

     expect(result[:text]).to eq("Nested paragraph 1. Nested paragraph 2.")
   end
@@ -88,8 +88,8 @@ RSpec.describe DiscourseAi::AiBot::Tools::WebBrowser do
     stub_request(:get, initial_url).to_return(status: 302, headers: { "Location" => final_url })
     stub_request(:get, final_url).to_return(status: 200, body: redirect_html)

-    tool = described_class.new({ url: initial_url })
-    result = tool.invoke(bot_user, llm)
+    tool = described_class.new({ url: initial_url }, bot_user: bot_user, llm: llm)
+    result = tool.invoke

     expect(result[:url]).to eq(final_url)
     expect(result[:text]).to eq("Redirected content.")