diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb index 322e142b..d935b336 100644 --- a/lib/ai_bot/bot.rb +++ b/lib/ai_bot/bot.rb @@ -68,13 +68,13 @@ module DiscourseAi result = llm.generate(prompt, **llm_kwargs) do |partial, cancel| - tools = persona.find_tools(partial) + tools = persona.find_tools(partial, bot_user: user, llm: llm, context: context) if (tools.present?) tool_found = true tools[0..MAX_TOOLS].each do |tool| - ongoing_chain &&= tool.chain_next_response? process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context) + ongoing_chain &&= tool.chain_next_response? end else update_blk.call(partial, cancel, nil) @@ -137,7 +137,7 @@ module DiscourseAi update_blk.call("", cancel, build_placeholder(tool.summary, "")) result = - tool.invoke(bot_user, llm) do |progress| + tool.invoke do |progress| placeholder = build_placeholder(tool.summary, progress) update_blk.call("", cancel, placeholder) end diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb index c0b83c41..24010190 100644 --- a/lib/ai_bot/personas/persona.rb +++ b/lib/ai_bot/personas/persona.rb @@ -174,16 +174,21 @@ module DiscourseAi prompt end - def find_tools(partial) + def find_tools(partial, bot_user:, llm:, context:) return [] if !partial.include?("") parsed_function = Nokogiri::HTML5.fragment(partial) - parsed_function.css("invoke").map { |fragment| find_tool(fragment) }.compact + parsed_function + .css("invoke") + .map do |fragment| + tool_instance(fragment, bot_user: bot_user, llm: llm, context: context) + end + .compact end protected - def find_tool(parsed_function) + def tool_instance(parsed_function, bot_user:, llm:, context:) function_id = parsed_function.at("tool_id")&.text function_name = parsed_function.at("tool_name")&.text return nil if function_name.nil? @@ -212,6 +217,9 @@ module DiscourseAi arguments, tool_call_id: function_id || function_name, persona_options: options[tool_klass].to_h, + bot_user: bot_user, + llm: llm, + context: context, ) end diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb index 260331f3..834a231a 100644 --- a/lib/ai_bot/playground.rb +++ b/lib/ai_bot/playground.rb @@ -350,6 +350,7 @@ module DiscourseAi ) context[:post_id] = post.id context[:topic_id] = post.topic_id + context[:private_message] = post.topic.private_message? reply_user = bot.bot_user if bot.persona.class.respond_to?(:user_id) diff --git a/lib/ai_bot/tools/dall_e.rb b/lib/ai_bot/tools/dall_e.rb index 96548f6b..3686ba97 100644 --- a/lib/ai_bot/tools/dall_e.rb +++ b/lib/ai_bot/tools/dall_e.rb @@ -33,7 +33,7 @@ module DiscourseAi false end - def invoke(bot_user, _llm) + def invoke # max 4 prompts max_prompts = prompts.take(4) progress = prompts.first @@ -88,7 +88,12 @@ module DiscourseAi file.rewind uploads << { prompt: image[:revised_prompt], - upload: UploadCreator.new(file, "image.png").create_for(bot_user.id), + upload: + UploadCreator.new( + file, + "image.png", + for_private_message: context[:private_message], + ).create_for(bot_user.id), } end end @@ -99,9 +104,7 @@ module DiscourseAi [grid] #{ uploads - .map do |item| - "![#{item[:prompt].gsub(/\|\'\"/, "")}|512x512, 50%](#{item[:upload].short_url})" - end + .map { |item| "![#{item[:prompt].gsub(/\|\'\"/, "")}](#{item[:upload].short_url})" } .join(" ") } [/grid] diff --git a/lib/ai_bot/tools/db_schema.rb b/lib/ai_bot/tools/db_schema.rb index 101ac4c3..1c11fb68 100644 --- a/lib/ai_bot/tools/db_schema.rb +++ b/lib/ai_bot/tools/db_schema.rb @@ -28,7 +28,7 @@ module DiscourseAi parameters[:tables] end - def invoke(_bot_user, _llm) + def invoke tables_arr = tables.split(",").map(&:strip) table_info = {} diff --git a/lib/ai_bot/tools/discourse_meta_search.rb b/lib/ai_bot/tools/discourse_meta_search.rb index bd4e9ebf..437cd936 100644 --- a/lib/ai_bot/tools/discourse_meta_search.rb +++ b/lib/ai_bot/tools/discourse_meta_search.rb @@ -78,7 +78,7 @@ module DiscourseAi parameters.slice(:category, :user, :order, :max_posts, :tags, :before, :after, :status) end - def invoke(bot_user, llm) + def invoke search_string = search_args.reduce(+parameters[:search_query].to_s) do |memo, (key, value)| return memo if value.blank? diff --git a/lib/ai_bot/tools/github_file_content.rb b/lib/ai_bot/tools/github_file_content.rb index 91fe9ab5..c3cde356 100644 --- a/lib/ai_bot/tools/github_file_content.rb +++ b/lib/ai_bot/tools/github_file_content.rb @@ -52,7 +52,7 @@ module DiscourseAi { repo_name: repo_name, file_paths: file_paths.join(", "), branch: branch } end - def invoke(_bot_user, llm) + def invoke owner, repo = repo_name.split("/") file_contents = {} missing_files = [] diff --git a/lib/ai_bot/tools/github_pull_request_diff.rb b/lib/ai_bot/tools/github_pull_request_diff.rb index 5fa8ca4b..90d43f82 100644 --- a/lib/ai_bot/tools/github_pull_request_diff.rb +++ b/lib/ai_bot/tools/github_pull_request_diff.rb @@ -43,7 +43,7 @@ module DiscourseAi @url end - def invoke(_bot_user, llm) + def invoke api_url = "https://api.github.com/repos/#{repo}/pulls/#{pull_id}" @url = "https://github.com/#{repo}/pull/#{pull_id}" diff --git a/lib/ai_bot/tools/github_search_code.rb b/lib/ai_bot/tools/github_search_code.rb index 77d1f005..a0e1acbd 100644 --- a/lib/ai_bot/tools/github_search_code.rb +++ b/lib/ai_bot/tools/github_search_code.rb @@ -41,7 +41,7 @@ module DiscourseAi { repo: repo, query: query } end - def invoke(_bot_user, llm) + def invoke api_url = "https://api.github.com/search/code?q=#{query}+repo:#{repo}" response_code = "unknown error" diff --git a/lib/ai_bot/tools/google.rb b/lib/ai_bot/tools/google.rb index a41a9a85..ba3ecb5d 100644 --- a/lib/ai_bot/tools/google.rb +++ b/lib/ai_bot/tools/google.rb @@ -27,7 +27,7 @@ module DiscourseAi parameters[:query].to_s.strip end - def invoke(bot_user, llm) + def invoke yield(query) api_key = SiteSetting.ai_google_custom_search_api_key diff --git a/lib/ai_bot/tools/image.rb b/lib/ai_bot/tools/image.rb index d25810c4..72e575c3 100644 --- a/lib/ai_bot/tools/image.rb +++ b/lib/ai_bot/tools/image.rb @@ -40,6 +40,11 @@ module DiscourseAi "image" end + def initialize(*args, **kwargs) + super + @chain_next_response = false + end + def prompts parameters[:prompts] end @@ -53,10 +58,10 @@ module DiscourseAi end def chain_next_response? - false + @chain_next_response end - def invoke(bot_user, _llm) + def invoke # max 4 prompts selected_prompts = prompts.take(4) seeds = seeds.take(4) if seeds @@ -101,7 +106,15 @@ module DiscourseAi results = threads.map(&:value).compact if !results.present? - return { prompts: prompts, error: "Something went wrong, could not generate image" } + @chain_next_response = true + return( + { + prompts: prompts, + error: + "Something went wrong inform user you could not generate image, check Discourse logs, give up don't try anymore", + give_up: true, + } + ) end uploads = [] @@ -114,7 +127,12 @@ module DiscourseAi file.rewind uploads << { prompt: prompts[index], - upload: UploadCreator.new(file, "image.png").create_for(bot_user.id), + upload: + UploadCreator.new( + file, + "image.png", + for_private_message: context[:private_message], + ).create_for(bot_user.id), seed: image[:seed], } end @@ -126,9 +144,7 @@ module DiscourseAi [grid] #{ uploads - .map do |item| - "![#{item[:prompt].gsub(/\|\'\"/, "")}|50%](#{item[:upload].short_url})" - end + .map { |item| "![#{item[:prompt].gsub(/\|\'\"/, "")}](#{item[:upload].short_url})" } .join(" ") } [/grid] diff --git a/lib/ai_bot/tools/list_categories.rb b/lib/ai_bot/tools/list_categories.rb index 52bf0cb6..4a32ce13 100644 --- a/lib/ai_bot/tools/list_categories.rb +++ b/lib/ai_bot/tools/list_categories.rb @@ -16,7 +16,7 @@ module DiscourseAi "categories" end - def invoke(_bot_user, _llm) + def invoke columns = { name: "Name", slug: "Slug", diff --git a/lib/ai_bot/tools/list_tags.rb b/lib/ai_bot/tools/list_tags.rb index e12c2491..6ff702bf 100644 --- a/lib/ai_bot/tools/list_tags.rb +++ b/lib/ai_bot/tools/list_tags.rb @@ -15,7 +15,7 @@ module DiscourseAi "tags" end - def invoke(_bot_user, _llm) + def invoke column_names = { name: "Name", public_topic_count: "Topic Count" } tags = diff --git a/lib/ai_bot/tools/random_picker.rb b/lib/ai_bot/tools/random_picker.rb index 60000bb6..c1ed16f9 100644 --- a/lib/ai_bot/tools/random_picker.rb +++ b/lib/ai_bot/tools/random_picker.rb @@ -30,7 +30,7 @@ module DiscourseAi parameters[:options] end - def invoke(_bot_user, _llm) + def invoke result = nil # can be a naive list of strings if options.none? { |option| option.match?(/\A\d+-\d+\z/) || option.include?(",") } diff --git a/lib/ai_bot/tools/read.rb b/lib/ai_bot/tools/read.rb index 12a522e6..1f8de7a4 100644 --- a/lib/ai_bot/tools/read.rb +++ b/lib/ai_bot/tools/read.rb @@ -39,7 +39,7 @@ module DiscourseAi parameters[:post_number] end - def invoke(_bot_user, llm) + def invoke not_found = { topic_id: topic_id, description: "Topic not found" } @title = "" diff --git a/lib/ai_bot/tools/search.rb b/lib/ai_bot/tools/search.rb index c2139407..d29f6dd7 100644 --- a/lib/ai_bot/tools/search.rb +++ b/lib/ai_bot/tools/search.rb @@ -91,7 +91,7 @@ module DiscourseAi parameters.slice(:category, :user, :order, :max_posts, :tags, :before, :after, :status) end - def invoke(bot_user, llm) + def invoke search_string = search_args.reduce(+parameters[:search_query].to_s) do |memo, (key, value)| return memo if value.blank? diff --git a/lib/ai_bot/tools/search_settings.rb b/lib/ai_bot/tools/search_settings.rb index 504b7f0d..c89acc0f 100644 --- a/lib/ai_bot/tools/search_settings.rb +++ b/lib/ai_bot/tools/search_settings.rb @@ -31,7 +31,7 @@ module DiscourseAi parameters[:query].to_s end - def invoke(_bot_user, _llm) + def invoke @last_num_results = 0 terms = query.split(",").map(&:strip).map(&:downcase).reject(&:blank?) diff --git a/lib/ai_bot/tools/setting_context.rb b/lib/ai_bot/tools/setting_context.rb index ef08d323..17996af0 100644 --- a/lib/ai_bot/tools/setting_context.rb +++ b/lib/ai_bot/tools/setting_context.rb @@ -47,7 +47,7 @@ module DiscourseAi parameters[:setting_name] end - def invoke(_bot_user, llm) + def invoke if !self.class.rg_installed? return( { diff --git a/lib/ai_bot/tools/summarize.rb b/lib/ai_bot/tools/summarize.rb index 113bd215..3a916301 100644 --- a/lib/ai_bot/tools/summarize.rb +++ b/lib/ai_bot/tools/summarize.rb @@ -48,7 +48,7 @@ module DiscourseAi @last_summary || I18n.t("discourse_ai.ai_bot.topic_not_found") end - def invoke(bot_user, llm, &progress_blk) + def invoke(&progress_blk) topic = nil if topic_id > 0 topic = Topic.find_by(id: topic_id) diff --git a/lib/ai_bot/tools/time.rb b/lib/ai_bot/tools/time.rb index 563f2ceb..bd88eeed 100644 --- a/lib/ai_bot/tools/time.rb +++ b/lib/ai_bot/tools/time.rb @@ -27,7 +27,7 @@ module DiscourseAi parameters[:timezone].to_s end - def invoke(_bot_user, _llm) + def invoke time = begin ::Time.now.in_time_zone(timezone) diff --git a/lib/ai_bot/tools/tool.rb b/lib/ai_bot/tools/tool.rb index e9c3bd49..9ccd2a1e 100644 --- a/lib/ai_bot/tools/tool.rb +++ b/lib/ai_bot/tools/tool.rb @@ -31,15 +31,24 @@ module DiscourseAi end attr_accessor :custom_raw + attr_reader :tool_call_id, :persona_options, :bot_user, :llm, :context, :parameters - def initialize(parameters, tool_call_id: "", persona_options: {}) + def initialize( + parameters, + tool_call_id: "", + persona_options: {}, + bot_user:, + llm:, + context: {} + ) @parameters = parameters @tool_call_id = tool_call_id @persona_options = persona_options + @bot_user = bot_user + @llm = llm + @context = context end - attr_reader :parameters, :tool_call_id - def name self.class.name end diff --git a/lib/ai_bot/tools/web_browser.rb b/lib/ai_bot/tools/web_browser.rb index 86f9d9aa..eb4cf020 100644 --- a/lib/ai_bot/tools/web_browser.rb +++ b/lib/ai_bot/tools/web_browser.rb @@ -32,7 +32,7 @@ module DiscourseAi @url end - def invoke(_bot_user, llm) + def invoke send_http_request(url, follow_redirects: true) do |response| if response.code == "200" html = read_response_body(response) diff --git a/plugin.rb b/plugin.rb index b90801d7..ad255b70 100644 --- a/plugin.rb +++ b/plugin.rb @@ -66,6 +66,12 @@ after_initialize do reloadable_patch { |plugin| Guardian.prepend DiscourseAi::GuardianExtensions } register_modifier(:post_should_secure_uploads?) do |_, _, topic| - false if topic.private_message? && SharedAiConversation.exists?(target: topic) + if topic.private_message? && SharedAiConversation.exists?(target: topic) + false + else + # revert to default behavior + # even though this can be shortened this is the clearest way to express it + nil + end end end diff --git a/spec/lib/modules/ai_bot/bot_spec.rb b/spec/lib/modules/ai_bot/bot_spec.rb index 7211b04d..d973317b 100644 --- a/spec/lib/modules/ai_bot/bot_spec.rb +++ b/spec/lib/modules/ai_bot/bot_spec.rb @@ -63,7 +63,7 @@ RSpec.describe DiscourseAi::AiBot::Bot do context "when using function chaining" do it "yields a loading placeholder while proceeds to invoke the command" do - tool = DiscourseAi::AiBot::Tools::ListCategories.new({}) + tool = DiscourseAi::AiBot::Tools::ListCategories.new({}, bot_user: nil, llm: nil) partial_placeholder = +(<<~HTML)
#{tool.summary} diff --git a/spec/lib/modules/ai_bot/personas/persona_spec.rb b/spec/lib/modules/ai_bot/personas/persona_spec.rb index a4e71339..b803297e 100644 --- a/spec/lib/modules/ai_bot/personas/persona_spec.rb +++ b/spec/lib/modules/ai_bot/personas/persona_spec.rb @@ -99,7 +99,14 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do XML - dall_e1, dall_e2 = tools = DiscourseAi::AiBot::Personas::DallE3.new.find_tools(xml) + dall_e1, dall_e2 = + tools = + DiscourseAi::AiBot::Personas::DallE3.new.find_tools( + xml, + bot_user: nil, + llm: nil, + context: nil, + ) expect(dall_e1.parameters[:prompts]).to eq(["cat oil painting", "big car"]) expect(dall_e2.parameters[:prompts]).to eq(["pic3"]) expect(tools.length).to eq(2) diff --git a/spec/lib/modules/ai_bot/tools/dall_e_spec.rb b/spec/lib/modules/ai_bot/tools/dall_e_spec.rb index 7bbde56b..3d5f5a13 100644 --- a/spec/lib/modules/ai_bot/tools/dall_e_spec.rb +++ b/spec/lib/modules/ai_bot/tools/dall_e_spec.rb @@ -1,14 +1,16 @@ #frozen_string_literal: true RSpec.describe DiscourseAi::AiBot::Tools::DallE do - subject(:dall_e) { described_class.new({ prompts: prompts }) } - let(:prompts) { ["a pink cow", "a red cow"] } let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) } let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") } let(:progress_blk) { Proc.new {} } + let(:dall_e) do + described_class.new({ prompts: prompts }, llm: llm, bot_user: bot_user, context: {}) + end + before { SiteSetting.ai_bot_enabled = true } describe "#process" do @@ -34,12 +36,12 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do end .to_return(status: 200, body: { data: data }.to_json) - info = dall_e.invoke(bot_user, llm, &progress_blk).to_json + info = dall_e.invoke(&progress_blk).to_json expect(JSON.parse(info)).to eq("prompts" => ["a pink cow 1", "a pink cow 1"]) - expect(subject.custom_raw).to include("upload://") - expect(subject.custom_raw).to include("[grid]") - expect(subject.custom_raw).to include("a pink cow 1") + expect(dall_e.custom_raw).to include("upload://") + expect(dall_e.custom_raw).to include("[grid]") + expect(dall_e.custom_raw).to include("a pink cow 1") end it "can generate correct info" do @@ -59,12 +61,12 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do end .to_return(status: 200, body: { data: data }.to_json) - info = dall_e.invoke(bot_user, llm, &progress_blk).to_json + info = dall_e.invoke(&progress_blk).to_json expect(JSON.parse(info)).to eq("prompts" => ["a pink cow 1", "a pink cow 1"]) - expect(subject.custom_raw).to include("upload://") - expect(subject.custom_raw).to include("[grid]") - expect(subject.custom_raw).to include("a pink cow 1") + expect(dall_e.custom_raw).to include("upload://") + expect(dall_e.custom_raw).to include("[grid]") + expect(dall_e.custom_raw).to include("a pink cow 1") end end end diff --git a/spec/lib/modules/ai_bot/tools/db_schema_spec.rb b/spec/lib/modules/ai_bot/tools/db_schema_spec.rb index f545477c..93647931 100644 --- a/spec/lib/modules/ai_bot/tools/db_schema_spec.rb +++ b/spec/lib/modules/ai_bot/tools/db_schema_spec.rb @@ -7,7 +7,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::DbSchema do before { SiteSetting.ai_bot_enabled = true } describe "#process" do it "returns rich schema for tables" do - result = described_class.new({ tables: "posts,topics" }).invoke(bot_user, llm) + result = described_class.new({ tables: "posts,topics" }, bot_user: bot_user, llm: llm).invoke expect(result[:schema_info]).to include("raw text") expect(result[:schema_info]).to include("views integer") diff --git a/spec/lib/modules/ai_bot/tools/discourse_meta_search_spec.rb b/spec/lib/modules/ai_bot/tools/discourse_meta_search_spec.rb index d2c47c19..f29d15ff 100644 --- a/spec/lib/modules/ai_bot/tools/discourse_meta_search_spec.rb +++ b/spec/lib/modules/ai_bot/tools/discourse_meta_search_spec.rb @@ -45,8 +45,8 @@ RSpec.describe DiscourseAi::AiBot::Tools::DiscourseMetaSearch do }, ) - search = described_class.new({ search_query: "test" }) - results = search.invoke(bot_user, llm, &progress_blk) + search = described_class.new({ search_query: "test" }, bot_user: bot_user, llm: llm) + results = search.invoke(&progress_blk) expect(results[:rows].length).to eq(20) expect(results[:rows].first[results[:column_names].index("category")]).to eq( @@ -71,8 +71,8 @@ RSpec.describe DiscourseAi::AiBot::Tools::DiscourseMetaSearch do .to_h .symbolize_keys - search = described_class.new(params) - results = search.invoke(bot_user, llm, &progress_blk) + search = described_class.new(params, bot_user: bot_user, llm: llm) + results = search.invoke(&progress_blk) expect(results[:args]).to eq(params) end diff --git a/spec/lib/modules/ai_bot/tools/github_file_content_spec.rb b/spec/lib/modules/ai_bot/tools/github_file_content_spec.rb index aa2263bd..b63aa1ec 100644 --- a/spec/lib/modules/ai_bot/tools/github_file_content_spec.rb +++ b/spec/lib/modules/ai_bot/tools/github_file_content_spec.rb @@ -3,6 +3,8 @@ require "rails_helper" RSpec.describe DiscourseAi::AiBot::Tools::GithubFileContent do + let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") } + let(:tool) do described_class.new( { @@ -10,11 +12,11 @@ RSpec.describe DiscourseAi::AiBot::Tools::GithubFileContent do file_paths: %w[lib/database/connection.rb lib/ai_bot/tools/github_pull_request_diff.rb], branch: "8b382d6098fde879d28bbee68d3cbe0a193e4ffc", }, + bot_user: nil, + llm: llm, ) end - let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") } - describe "#invoke" do before do stub_request( @@ -35,7 +37,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::GithubFileContent do end it "retrieves the content of the specified GitHub files" do - result = tool.invoke(nil, llm) + result = tool.invoke expected = { file_contents: "File Path: lib/database/connection.rb:\ncontent of connection.rb\nFile Path: lib/ai_bot/tools/github_pull_request_diff.rb:\ncontent of github_pull_request_diff.rb", diff --git a/spec/lib/modules/ai_bot/tools/github_pull_request_diff_spec.rb b/spec/lib/modules/ai_bot/tools/github_pull_request_diff_spec.rb index 0399108b..30489064 100644 --- a/spec/lib/modules/ai_bot/tools/github_pull_request_diff_spec.rb +++ b/spec/lib/modules/ai_bot/tools/github_pull_request_diff_spec.rb @@ -3,9 +3,9 @@ require "rails_helper" RSpec.describe DiscourseAi::AiBot::Tools::GithubPullRequestDiff do - let(:tool) { described_class.new({ repo: repo, pull_id: pull_id }) } let(:bot_user) { Fabricate(:user) } let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") } + let(:tool) { described_class.new({ repo: repo, pull_id: pull_id }, bot_user: bot_user, llm: llm) } context "with #sort_and_shorten_diff" do it "sorts and shortens the diff without dropping data" do @@ -64,7 +64,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::GithubPullRequestDiff do }, ).to_return(status: 200, body: diff) - result = tool.invoke(bot_user, llm) + result = tool.invoke expect(result[:diff]).to eq(diff) expect(result[:error]).to be_nil end @@ -80,7 +80,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::GithubPullRequestDiff do }, ).to_return(status: 200, body: diff) - result = tool.invoke(bot_user, llm) + result = tool.invoke expect(result[:diff]).to eq(diff) expect(result[:error]).to be_nil end @@ -98,7 +98,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::GithubPullRequestDiff do }, ).to_return(status: 404) - result = tool.invoke(bot_user, nil) + result = tool.invoke expect(result[:diff]).to be_nil expect(result[:error]).to include("Failed to retrieve the diff") end diff --git a/spec/lib/modules/ai_bot/tools/github_search_code_spec.rb b/spec/lib/modules/ai_bot/tools/github_search_code_spec.rb index 789ad8c2..e43ae91b 100644 --- a/spec/lib/modules/ai_bot/tools/github_search_code_spec.rb +++ b/spec/lib/modules/ai_bot/tools/github_search_code_spec.rb @@ -3,9 +3,9 @@ require "rails_helper" RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchCode do - let(:tool) { described_class.new({ repo: repo, query: query }) } let(:bot_user) { Fabricate(:user) } let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") } + let(:tool) { described_class.new({ repo: repo, query: query }, bot_user: bot_user, llm: llm) } context "with valid search results" do let(:repo) { "discourse/discourse" } @@ -34,7 +34,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchCode do }.to_json, ) - result = tool.invoke(bot_user, llm) + result = tool.invoke expect(result[:search_results]).to include("def hello\n puts 'hello'\nend") expect(result[:search_results]).to include("test/hello.rb") expect(result[:error]).to be_nil @@ -64,7 +64,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchCode do }, ).to_return(status: 200, body: { total_count: 0, items: [] }.to_json) - result = tool.invoke(bot_user, llm) + result = tool.invoke expect(result[:search_results]).to be_empty expect(result[:error]).to be_nil end @@ -85,7 +85,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchCode do }, ).to_return(status: 403) - result = tool.invoke(bot_user, llm) + result = tool.invoke expect(result[:search_results]).to be_nil expect(result[:error]).to include("Failed to perform code search") end diff --git a/spec/lib/modules/ai_bot/tools/google_spec.rb b/spec/lib/modules/ai_bot/tools/google_spec.rb index 4cb2d6d3..3e2dc96d 100644 --- a/spec/lib/modules/ai_bot/tools/google_spec.rb +++ b/spec/lib/modules/ai_bot/tools/google_spec.rb @@ -1,18 +1,15 @@ #frozen_string_literal: true RSpec.describe DiscourseAi::AiBot::Tools::Google do - subject(:search) { described_class.new({ query: "some search term" }) } - let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) } let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") } let(:progress_blk) { Proc.new {} } + let(:search) { described_class.new({ query: "some search term" }, bot_user: bot_user, llm: llm) } before { SiteSetting.ai_bot_enabled = true } describe "#process" do it "will not explode if there are no results" do - post = Fabricate(:post) - SiteSetting.ai_google_custom_search_api_key = "abc" SiteSetting.ai_google_custom_search_cx = "cx" @@ -23,15 +20,13 @@ RSpec.describe DiscourseAi::AiBot::Tools::Google do "https://www.googleapis.com/customsearch/v1?cx=cx&key=abc&num=10&q=some%20search%20term", ).to_return(status: 200, body: json_text, headers: {}) - info = search.invoke(bot_user, llm, &progress_blk).to_json + info = search.invoke(&progress_blk).to_json expect(search.results_count).to eq(0) expect(info).to_not include("oops") end it "can generate correct info" do - post = Fabricate(:post) - SiteSetting.ai_google_custom_search_api_key = "abc" SiteSetting.ai_google_custom_search_cx = "cx" @@ -63,7 +58,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Google do "https://www.googleapis.com/customsearch/v1?cx=cx&key=abc&num=10&q=some%20search%20term", ).to_return(status: 200, body: json_text, headers: {}) - info = search.invoke(bot_user, llm, &progress_blk).to_json + info = search.invoke(&progress_blk).to_json expect(search.results_count).to eq(2) expect(info).to include("title1") diff --git a/spec/lib/modules/ai_bot/tools/image_spec.rb b/spec/lib/modules/ai_bot/tools/image_spec.rb index 040c5230..7ddf6073 100644 --- a/spec/lib/modules/ai_bot/tools/image_spec.rb +++ b/spec/lib/modules/ai_bot/tools/image_spec.rb @@ -1,8 +1,6 @@ #frozen_string_literal: true RSpec.describe DiscourseAi::AiBot::Tools::Image do - subject(:tool) { described_class.new({ prompts: prompts, seeds: [99, 32] }) } - let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") } let(:progress_blk) { Proc.new {} } @@ -10,11 +8,21 @@ RSpec.describe DiscourseAi::AiBot::Tools::Image do let(:prompts) { ["a pink cow", "a red cow"] } + let(:tool) do + described_class.new( + { prompts: prompts, seeds: [99, 32] }, + bot_user: bot_user, + llm: llm, + context: { + }, + ) + end + before { SiteSetting.ai_bot_enabled = true } describe "#process" do it "can generate correct info" do - post = Fabricate(:post) + _post = Fabricate(:post) SiteSetting.ai_stability_api_url = "https://api.stability.dev" SiteSetting.ai_stability_api_key = "abc" @@ -36,7 +44,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Image do end .to_return(status: 200, body: { artifacts: artifacts }.to_json) - info = tool.invoke(bot_user, llm, &progress_blk).to_json + info = tool.invoke(&progress_blk).to_json expect(JSON.parse(info)).to eq("prompts" => ["a pink cow", "a red cow"], "seeds" => [99, 99]) expect(tool.custom_raw).to include("upload://") diff --git a/spec/lib/modules/ai_bot/tools/list_categories_spec.rb b/spec/lib/modules/ai_bot/tools/list_categories_spec.rb index 64844a44..1dd6276f 100644 --- a/spec/lib/modules/ai_bot/tools/list_categories_spec.rb +++ b/spec/lib/modules/ai_bot/tools/list_categories_spec.rb @@ -10,7 +10,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::ListCategories do it "list available categories" do Fabricate(:category, name: "america", posts_year: 999) - info = described_class.new({}).invoke(bot_user, llm).to_s + info = described_class.new({}, bot_user: bot_user, llm: llm).invoke.to_s expect(info).to include("america") expect(info).to include("999") diff --git a/spec/lib/modules/ai_bot/tools/list_tags_spec.rb b/spec/lib/modules/ai_bot/tools/list_tags_spec.rb index 9cddf419..f813a9ac 100644 --- a/spec/lib/modules/ai_bot/tools/list_tags_spec.rb +++ b/spec/lib/modules/ai_bot/tools/list_tags_spec.rb @@ -14,7 +14,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::ListTags do Fabricate(:tag, name: "america", public_topic_count: 100) Fabricate(:tag, name: "not_here", public_topic_count: 0) - info = described_class.new({}).invoke(bot_user, llm) + info = described_class.new({}, bot_user: bot_user, llm: llm).invoke expect(info.to_s).to include("america") expect(info.to_s).not_to include("not_here") diff --git a/spec/lib/modules/ai_bot/tools/random_picker_spec.rb b/spec/lib/modules/ai_bot/tools/random_picker_spec.rb index 80c24548..6e5930bf 100644 --- a/spec/lib/modules/ai_bot/tools/random_picker_spec.rb +++ b/spec/lib/modules/ai_bot/tools/random_picker_spec.rb @@ -4,7 +4,7 @@ require "rails_helper" RSpec.describe DiscourseAi::AiBot::Tools::RandomPicker do describe "#invoke" do - subject { described_class.new({ options: options }).invoke(nil, nil) } + subject { described_class.new({ options: options }, bot_user: nil, llm: nil).invoke } context "with options as simple list of strings" do let(:options) { %w[apple banana cherry] } diff --git a/spec/lib/modules/ai_bot/tools/read_spec.rb b/spec/lib/modules/ai_bot/tools/read_spec.rb index 309d3286..ffe45cac 100644 --- a/spec/lib/modules/ai_bot/tools/read_spec.rb +++ b/spec/lib/modules/ai_bot/tools/read_spec.rb @@ -1,10 +1,9 @@ #frozen_string_literal: true RSpec.describe DiscourseAi::AiBot::Tools::Read do - subject(:tool) { described_class.new({ topic_id: topic_with_tags.id }) } - let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) } let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") } + let(:tool) { described_class.new({ topic_id: topic_with_tags.id }, bot_user: bot_user, llm: llm) } fab!(:parent_category) { Fabricate(:category, name: "animals") } fab!(:category) { Fabricate(:category, parent_category: parent_category, name: "amazing-cat") } @@ -34,7 +33,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Read do Fabricate(:post, topic: topic_with_tags, raw: "hello there") Fabricate(:post, topic: topic_with_tags, raw: "mister sam") - results = tool.invoke(bot_user, llm) + results = tool.invoke expect(results[:topic_id]).to eq(topic_id) expect(results[:content]).to include("hello") diff --git a/spec/lib/modules/ai_bot/tools/search_settings_spec.rb b/spec/lib/modules/ai_bot/tools/search_settings_spec.rb index 94763ae9..9c878303 100644 --- a/spec/lib/modules/ai_bot/tools/search_settings_spec.rb +++ b/spec/lib/modules/ai_bot/tools/search_settings_spec.rb @@ -7,18 +7,18 @@ RSpec.describe DiscourseAi::AiBot::Tools::SearchSettings do before { SiteSetting.ai_bot_enabled = true } def search_settings(query) - described_class.new({ query: query }) + described_class.new({ query: query }, bot_user: bot_user, llm: llm) end describe "#process" do it "can handle no results" do - results = search_settings("this will not exist frogs").invoke(bot_user, llm) + results = search_settings("this will not exist frogs").invoke expect(results[:args]).to eq({ query: "this will not exist frogs" }) expect(results[:rows]).to eq([]) end it "can return more many settings with no descriptions if there are lots of hits" do - results = search_settings("a").invoke(bot_user, llm) + results = search_settings("a").invoke expect(results[:rows].length).to be > 30 expect(results[:rows][0].length).to eq(1) @@ -26,10 +26,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::SearchSettings do it "can return descriptions if there are few matches" do results = - search_settings("this will not be found!@,default_locale,ai_bot_enabled_chat_bots").invoke( - bot_user, - llm, - ) + search_settings("this will not be found!@,default_locale,ai_bot_enabled_chat_bots").invoke expect(results[:rows].length).to eq(2) diff --git a/spec/lib/modules/ai_bot/tools/search_spec.rb b/spec/lib/modules/ai_bot/tools/search_spec.rb index 060103e9..a33f2a14 100644 --- a/spec/lib/modules/ai_bot/tools/search_spec.rb +++ b/spec/lib/modules/ai_bot/tools/search_spec.rb @@ -39,24 +39,37 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do _bot_post = Fabricate(:post) - search = described_class.new({ order: "latest" }, persona_options: persona_options) + search = + described_class.new( + { order: "latest" }, + persona_options: persona_options, + bot_user: bot_user, + llm: llm, + context: { + }, + ) - results = search.invoke(bot_user, llm, &progress_blk) + results = search.invoke(&progress_blk) expect(results[:rows].length).to eq(1) search_post.topic.tags = [] search_post.topic.save! # no longer has the tag funny - results = search.invoke(bot_user, llm, &progress_blk) + results = search.invoke(&progress_blk) expect(results[:rows].length).to eq(0) end it "can handle no results" do _post1 = Fabricate(:post, topic: topic_with_tags) - search = described_class.new({ search_query: "ABDDCDCEDGDG", order: "fake" }) + search = + described_class.new( + { search_query: "ABDDCDCEDGDG", order: "fake" }, + bot_user: bot_user, + llm: llm, + ) - results = search.invoke(bot_user, llm, &progress_blk) + results = search.invoke(&progress_blk) expect(results[:args]).to eq({ search_query: "ABDDCDCEDGDG", order: "fake" }) expect(results[:rows]).to eq([]) @@ -79,7 +92,12 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do ) post1 = Fabricate(:post, topic: topic_with_tags) - search = described_class.new({ search_query: "hello world, sam", status: "public" }) + search = + described_class.new( + { search_query: "hello world, sam", status: "public" }, + llm: llm, + bot_user: bot_user, + ) DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn .any_instance @@ -88,7 +106,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do results = DiscourseAi::Completions::Llm.with_prepared_responses(["#{query}"]) do - search.invoke(bot_user, llm, &progress_blk) + search.invoke(&progress_blk) end expect(results[:args]).to eq({ search_query: "hello world, sam", status: "public" }) @@ -101,9 +119,10 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do post1 = Fabricate(:post, topic: topic_with_tags) - search = described_class.new({ limit: 1, user: post1.user.username }) + search = + described_class.new({ limit: 1, user: post1.user.username }, bot_user: bot_user, llm: llm) - results = search.invoke(bot_user, llm, &progress_blk) + results = search.invoke(&progress_blk) expect(results[:rows].to_s).to include("/subfolder" + post1.url) end @@ -120,18 +139,18 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do .to_h .symbolize_keys - search = described_class.new(params) - results = search.invoke(bot_user, llm, &progress_blk) + search = described_class.new(params, bot_user: bot_user, llm: llm) + results = search.invoke(&progress_blk) expect(results[:args]).to eq(params) end it "returns rich topic information" do post1 = Fabricate(:post, like_count: 1, topic: topic_with_tags) - search = described_class.new({ user: post1.user.username }) + search = described_class.new({ user: post1.user.username }, bot_user: bot_user, llm: llm) post1.topic.update!(views: 100, posts_count: 2, like_count: 10) - results = search.invoke(bot_user, llm, &progress_blk) + results = search.invoke(&progress_blk) row = results[:rows].first category = row[results[:column_names].index("category")] diff --git a/spec/lib/modules/ai_bot/tools/setting_context_spec.rb b/spec/lib/modules/ai_bot/tools/setting_context_spec.rb index 1953947d..d665988f 100644 --- a/spec/lib/modules/ai_bot/tools/setting_context_spec.rb +++ b/spec/lib/modules/ai_bot/tools/setting_context_spec.rb @@ -15,12 +15,12 @@ RSpec.describe DiscourseAi::AiBot::Tools::SettingContext, if: has_rg? do before { SiteSetting.ai_bot_enabled = true } def setting_context(setting_name) - described_class.new({ setting_name: setting_name }) + described_class.new({ setting_name: setting_name }, bot_user: bot_user, llm: llm) end describe "#execute" do it "returns the context for core setting" do - result = setting_context("moderators_view_emails").invoke(bot_user, llm) + result = setting_context("moderators_view_emails").invoke expect(result[:setting_name]).to eq("moderators_view_emails") @@ -29,7 +29,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::SettingContext, if: has_rg? do end it "returns the context for plugin setting" do - result = setting_context("ai_bot_enabled").invoke(bot_user, llm) + result = setting_context("ai_bot_enabled").invoke expect(result[:setting_name]).to eq("ai_bot_enabled") expect(result[:context]).to include("ai_bot_enabled:") @@ -37,7 +37,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::SettingContext, if: has_rg? do context "when the setting does not exist" do it "returns an error message" do - result = setting_context("this_setting_does_not_exist").invoke(bot_user, llm) + result = setting_context("this_setting_does_not_exist").invoke expect(result[:context]).to eq("This setting does not exist") end diff --git a/spec/lib/modules/ai_bot/tools/summarize_spec.rb b/spec/lib/modules/ai_bot/tools/summarize_spec.rb index 1af327ee..6c5f6d84 100644 --- a/spec/lib/modules/ai_bot/tools/summarize_spec.rb +++ b/spec/lib/modules/ai_bot/tools/summarize_spec.rb @@ -15,8 +15,12 @@ RSpec.describe DiscourseAi::AiBot::Tools::Summarize do DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do summarization = - described_class.new({ topic_id: post.topic_id, guidance: "why did it happen?" }) - info = summarization.invoke(bot_user, llm, &progress_blk) + described_class.new( + { topic_id: post.topic_id, guidance: "why did it happen?" }, + bot_user: bot_user, + llm: llm, + ) + info = summarization.invoke(&progress_blk) expect(info).to include("Topic summarized") expect(summarization.custom_raw).to include(summary) @@ -34,8 +38,12 @@ RSpec.describe DiscourseAi::AiBot::Tools::Summarize do DiscourseAi::Completions::Llm.with_prepared_responses([summary]) do summarization = - described_class.new({ topic_id: post.topic_id, guidance: "why did it happen?" }) - info = summarization.invoke(bot_user, llm, &progress_blk) + described_class.new( + { topic_id: post.topic_id, guidance: "why did it happen?" }, + bot_user: bot_user, + llm: llm, + ) + info = summarization.invoke(&progress_blk) expect(info).not_to include(post.raw) diff --git a/spec/lib/modules/ai_bot/tools/time_spec.rb b/spec/lib/modules/ai_bot/tools/time_spec.rb index 6e0ca6fd..a900c968 100644 --- a/spec/lib/modules/ai_bot/tools/time_spec.rb +++ b/spec/lib/modules/ai_bot/tools/time_spec.rb @@ -11,7 +11,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Time do freeze_time args = { timezone: "America/Los_Angeles" } - info = described_class.new(args).invoke(bot_user, llm) + info = described_class.new(args, bot_user: bot_user, llm: llm).invoke expect(info).to eq({ args: args, time: Time.now.in_time_zone("America/Los_Angeles").to_s }) expect(info.to_s).not_to include("not_here") diff --git a/spec/lib/modules/ai_bot/tools/web_browser_spec.rb b/spec/lib/modules/ai_bot/tools/web_browser_spec.rb index 85d3af9e..bb16afc3 100644 --- a/spec/lib/modules/ai_bot/tools/web_browser_spec.rb +++ b/spec/lib/modules/ai_bot/tools/web_browser_spec.rb @@ -21,8 +21,8 @@ RSpec.describe DiscourseAi::AiBot::Tools::WebBrowser do "Test

This is a simplified version of the webpage content.

", ) - tool = described_class.new({ url: url }) - result = tool.invoke(bot_user, llm) + tool = described_class.new({ url: url }, bot_user: bot_user, llm: llm) + result = tool.invoke expect(result).to have_key(:text) expect(result[:text]).to eq(processed_text) @@ -35,8 +35,8 @@ RSpec.describe DiscourseAi::AiBot::Tools::WebBrowser do # Simulating a failed request stub_request(:get, url).to_return(status: [500, "Internal Server Error"]) - tool = described_class.new({ url: url }) - result = tool.invoke(bot_user, llm) + tool = described_class.new({ url: url }, bot_user: bot_user, llm: llm) + result = tool.invoke expect(result).to have_key(:error) expect(result[:error]).to include("Failed to retrieve the web page") @@ -50,8 +50,8 @@ RSpec.describe DiscourseAi::AiBot::Tools::WebBrowser do simple_html = "

Simple content.

" stub_request(:get, url).to_return(status: 200, body: simple_html) - tool = described_class.new({ url: url }) - result = tool.invoke(bot_user, llm) + tool = described_class.new({ url: url }, bot_user: bot_user, llm: llm) + result = tool.invoke expect(result[:text]).to eq("Simple content.") end @@ -61,8 +61,8 @@ RSpec.describe DiscourseAi::AiBot::Tools::WebBrowser do "

Only relevant content here.

" stub_request(:get, url).to_return(status: 200, body: complex_html) - tool = described_class.new({ url: url }) - result = tool.invoke(bot_user, llm) + tool = described_class.new({ url: url }, bot_user: bot_user, llm: llm) + result = tool.invoke expect(result[:text]).to eq("Only relevant content here.") end @@ -72,8 +72,8 @@ RSpec.describe DiscourseAi::AiBot::Tools::WebBrowser do "

Nested paragraph 1.

Nested paragraph 2.

" stub_request(:get, url).to_return(status: 200, body: nested_html) - tool = described_class.new({ url: url }) - result = tool.invoke(bot_user, llm) + tool = described_class.new({ url: url }, bot_user: bot_user, llm: llm) + result = tool.invoke expect(result[:text]).to eq("Nested paragraph 1. Nested paragraph 2.") end @@ -88,8 +88,8 @@ RSpec.describe DiscourseAi::AiBot::Tools::WebBrowser do stub_request(:get, initial_url).to_return(status: 302, headers: { "Location" => final_url }) stub_request(:get, final_url).to_return(status: 200, body: redirect_html) - tool = described_class.new({ url: initial_url }) - result = tool.invoke(bot_user, llm) + tool = described_class.new({ url: initial_url }, bot_user: bot_user, llm: llm) + result = tool.invoke expect(result[:url]).to eq(final_url) expect(result[:text]).to eq("Redirected content.")