From 834fea672ff4930eae8a5a53d106158accbb074d Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 30 May 2024 06:33:50 +1000 Subject: [PATCH] FEATURE: improved tooling (#651) 1. New tool to easily find files (and default branch) in a Github repo 2. Improved read tool with clearer params and larger context * limit can totally mess up the richness semantic search adds, so include the results unconditionally. --- config/locales/server.en.yml | 3 + lib/ai_bot/personas/github_helper.rb | 7 +- lib/ai_bot/personas/persona.rb | 1 + lib/ai_bot/tools/github_search_files.rb | 101 ++++++++++++++++++ lib/ai_bot/tools/read.rb | 40 ++++--- lib/ai_bot/tools/search.rb | 4 +- lib/ai_bot/tools/tool.rb | 26 +++++ .../ai_bot/tools/github_search_files_spec.rb | 91 ++++++++++++++++ spec/lib/modules/ai_bot/tools/read_spec.rb | 19 +++- 9 files changed, 271 insertions(+), 21 deletions(-) create mode 100644 lib/ai_bot/tools/github_search_files.rb create mode 100644 spec/lib/modules/ai_bot/tools/github_search_files_spec.rb diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml index fa102cb7..c656283a 100644 --- a/config/locales/server.en.yml +++ b/config/locales/server.en.yml @@ -218,6 +218,7 @@ en: description: "Base query to use when searching. Example: '#urgent' will prepend '#urgent' to the search query and only include topics with the urgent category or tag." command_summary: web_browser: "Browse Web" + github_search_files: "GitHub search files" github_search_code: "GitHub code search" github_file_content: "GitHub file content" github_pull_request_diff: "GitHub pull request diff" @@ -239,6 +240,7 @@ en: command_help: web_browser: "Browse web page using the AI Bot" github_search_code: "Search for code in a GitHub repository" + github_search_files: "Search for files in a GitHub repository" github_file_content: "Retrieve content of files from a GitHub repository" github_pull_request_diff: "Retrieve a GitHub pull request diff" random_picker: "Pick a random number or a random element of a list" @@ -258,6 +260,7 @@ en: javascript_evaluator: "Evaluate JavaScript" command_description: web_browser: "Reading %{url}" + github_search_files: "Searched for '%{keywords}' in %{repo}/%{branch}" github_search_code: "Searched for '%{query}' in %{repo}" github_pull_request_diff: "%{repo} %{pull_id}" github_file_content: "Retrieved content of %{file_paths} from %{repo_name}@%{branch}" diff --git a/lib/ai_bot/personas/github_helper.rb b/lib/ai_bot/personas/github_helper.rb index 09e71b4b..a0a2e499 100644 --- a/lib/ai_bot/personas/github_helper.rb +++ b/lib/ai_bot/personas/github_helper.rb @@ -5,7 +5,12 @@ module DiscourseAi module Personas class GithubHelper < Persona def tools - [Tools::GithubFileContent, Tools::GithubPullRequestDiff, Tools::GithubSearchCode] + [ + Tools::GithubFileContent, + Tools::GithubPullRequestDiff, + Tools::GithubSearchCode, + Tools::GithubSearchFiles, + ] end def system_prompt diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb index 6720da7f..abde3dae 100644 --- a/lib/ai_bot/personas/persona.rb +++ b/lib/ai_bot/personas/persona.rb @@ -87,6 +87,7 @@ module DiscourseAi Tools::DiscourseMetaSearch, Tools::GithubFileContent, Tools::GithubPullRequestDiff, + Tools::GithubSearchFiles, Tools::WebBrowser, Tools::JavascriptEvaluator, ] diff --git a/lib/ai_bot/tools/github_search_files.rb b/lib/ai_bot/tools/github_search_files.rb new file mode 100644 index 00000000..97af3bdf --- /dev/null +++ b/lib/ai_bot/tools/github_search_files.rb @@ -0,0 +1,101 @@ +# frozen_string_literal: true + +module DiscourseAi + module AiBot + module Tools + class GithubSearchFiles < Tool + def self.signature + { + name: name, + description: + "Searches for files in a GitHub repository containing specific keywords in their paths or names", + parameters: [ + { + name: "repo", + description: "The repository name in the format 'owner/repo'", + type: "string", + required: true, + }, + { + name: "keywords", + description: "An array of keywords to match in file paths or names", + type: "array", + item_type: "string", + required: true, + }, + { + name: "branch", + description: + "The branch or commit SHA to search within (default: repository's default branch)", + type: "string", + required: false, + }, + ], + } + end + + def self.name + "github_search_files" + end + + def repo + parameters[:repo] + end + + def keywords + parameters[:keywords] + end + + def branch + parameters[:branch] + end + + def description_args + { repo: repo, keywords: keywords.join(", "), branch: @branch_name } + end + + def invoke + # Fetch the default branch if no branch is specified + branch_name = branch || fetch_default_branch(repo) + @branch_name = branch_name + + api_url = "https://api.github.com/repos/#{repo}/git/trees/#{branch_name}?recursive=1" + + response_code = "unknown error" + tree_data = nil + + send_http_request( + api_url, + headers: { + "Accept" => "application/vnd.github.v3+json", + }, + authenticate_github: true, + ) do |response| + response_code = response.code + if response_code == "200" + begin + tree_data = JSON.parse(read_response_body(response)) + rescue JSON::ParserError + response_code = "500 - JSON parse error" + end + end + end + + if response_code == "200" + matching_files = + tree_data["tree"] + .select do |item| + item["type"] == "blob" && + keywords.any? { |keyword| item["path"].include?(keyword) } + end + .map { |item| item["path"] } + + { matching_files: matching_files, branch: branch_name } + else + { error: "Failed to perform file search. Status code: #{response_code}" } + end + end + end + end + end +end diff --git a/lib/ai_bot/tools/read.rb b/lib/ai_bot/tools/read.rb index 1f8de7a4..3cc6e096 100644 --- a/lib/ai_bot/tools/read.rb +++ b/lib/ai_bot/tools/read.rb @@ -2,6 +2,8 @@ module DiscourseAi module AiBot + MAX_POSTS = 100 + module Tools class Read < Tool def self.signature @@ -16,10 +18,11 @@ module DiscourseAi required: true, }, { - name: "post_number", - description: "the post number to read", - type: "integer", - required: true, + name: "post_numbers", + description: "the post numbers to read (optional)", + type: "array", + item_type: "integer", + required: false, }, ], } @@ -35,8 +38,8 @@ module DiscourseAi parameters[:topic_id] end - def post_number - parameters[:post_number] + def post_numbers + parameters[:post_numbers] end def invoke @@ -49,10 +52,19 @@ module DiscourseAi @title = topic.title - posts = Post.secured(Guardian.new).where(topic_id: topic_id).order(:post_number).limit(40) + posts = + Post + .secured(Guardian.new) + .where(topic_id: topic_id) + .order(:post_number) + .limit(MAX_POSTS) + + post_number = 1 + post_number = post_numbers.first if post_numbers.present? + @url = topic.relative_url(post_number) - posts = posts.where("post_number = ?", post_number) if post_number + posts = posts.where("post_number in (?)", post_numbers) if post_numbers.present? content = +<<~TEXT.strip title: #{topic.title} @@ -69,13 +81,15 @@ module DiscourseAi content << "\ntags: #{tags.map(&:name).join(", ")}\n\n" if tags.length > 0 end - posts.each { |post| content << "\n\n#{post.username} said:\n\n#{post.raw}" } + posts.each do |post| + content << "\n\n#{post.user&.name}(#{post.username}) said:\n\n#{post.raw}" + end - # TODO: 16k or 100k models can handle a lot more tokens - content = llm.tokenizer.truncate(content, 1500).squish + truncated_content = + truncate(content, max_length: 20_000, percent_length: 0.3, llm: llm).squish - result = { topic_id: topic_id, content: content, complete: true } - result[:post_number] = post_number if post_number + result = { topic_id: topic_id, content: truncated_content } + result[:post_numbers] = post_numbers if post_numbers.present? result end diff --git a/lib/ai_bot/tools/search.rb b/lib/ai_bot/tools/search.rb index 3bc91bee..634693be 100644 --- a/lib/ai_bot/tools/search.rb +++ b/lib/ai_bot/tools/search.rb @@ -123,14 +123,12 @@ module DiscourseAi results = ::Search.execute(safe_search_string, search_type: :full_page, guardian: guardian) - # let's be frugal with tokens, 50 results is too much and stuff gets cut off max_results = calculate_max_results(llm) results_limit = parameters[:limit] || max_results results_limit = max_results if parameters[:limit].to_i > max_results should_try_semantic_search = - SiteSetting.ai_embeddings_semantic_search_enabled && results_limit == max_results && - parameters[:search_query].present? + SiteSetting.ai_embeddings_semantic_search_enabled && parameters[:search_query].present? max_semantic_results = max_results / 4 results_limit = results_limit - max_semantic_results if should_try_semantic_search diff --git a/lib/ai_bot/tools/tool.rb b/lib/ai_bot/tools/tool.rb index 71c92eb8..d4411e0b 100644 --- a/lib/ai_bot/tools/tool.rb +++ b/lib/ai_bot/tools/tool.rb @@ -92,6 +92,32 @@ module DiscourseAi protected + def fetch_default_branch(repo) + api_url = "https://api.github.com/repos/#{repo}" + + response_code = "unknown error" + repo_data = nil + + send_http_request( + api_url, + headers: { + "Accept" => "application/vnd.github.v3+json", + }, + authenticate_github: true, + ) do |response| + response_code = response.code + if response_code == "200" + begin + repo_data = JSON.parse(read_response_body(response)) + rescue JSON::ParserError + response_code = "500 - JSON parse error" + end + end + end + + response_code == "200" ? repo_data["default_branch"] : "main" + end + def send_http_request(url, headers: {}, authenticate_github: false, follow_redirects: false) raise "Expecting caller to use a block" if !block_given? diff --git a/spec/lib/modules/ai_bot/tools/github_search_files_spec.rb b/spec/lib/modules/ai_bot/tools/github_search_files_spec.rb new file mode 100644 index 00000000..b4f61b0d --- /dev/null +++ b/spec/lib/modules/ai_bot/tools/github_search_files_spec.rb @@ -0,0 +1,91 @@ +# frozen_string_literal: true + +require "rails_helper" + +RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchFiles do + let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") } + + let(:tool) do + described_class.new( + { + repo: "discourse/discourse-ai", + keywords: %w[search tool], + branch: nil, # Let it find the default branch + }, + bot_user: nil, + llm: llm, + ) + end + + describe "#invoke" do + let(:default_branch) { "main" } + + before do + # Stub request to get the default branch + stub_request(:get, "https://api.github.com/repos/discourse/discourse-ai").to_return( + status: 200, + body: { default_branch: default_branch }.to_json, + ) + + # Stub request to get the file tree + stub_request( + :get, + "https://api.github.com/repos/discourse/discourse-ai/git/trees/#{default_branch}?recursive=1", + ).to_return( + status: 200, + body: { + tree: [ + { path: "lib/modules/ai_bot/tools/github_search_code.rb", type: "blob" }, + { path: "lib/modules/ai_bot/tools/github_file_content.rb", type: "blob" }, + ], + }.to_json, + ) + end + + it "retrieves files matching the specified keywords" do + result = tool.invoke + expected = { + branch: "main", + matching_files: %w[ + lib/modules/ai_bot/tools/github_search_code.rb + lib/modules/ai_bot/tools/github_file_content.rb + ], + } + + expect(result).to eq(expected) + end + + it "handles missing branches gracefully" do + stub_request( + :get, + "https://api.github.com/repos/discourse/discourse-ai/git/trees/non_existing_branch?recursive=1", + ).to_return(status: 404, body: "", headers: { "Content-Type" => "application/json" }) + + tool_with_invalid_branch = + described_class.new( + { + repo: "discourse/discourse-ai", + keywords: %w[search tool], + branch: "non_existing_branch", + }, + bot_user: nil, + llm: llm, + ) + + result = tool_with_invalid_branch.invoke + expect(result[:matching_files]).to be_nil + expect(result[:error]).to eq("Failed to perform file search. Status code: 404") + end + + it "fetches the default branch if none is specified" do + result = tool.invoke + expect(result[:matching_files]).to match_array( + %w[ + lib/modules/ai_bot/tools/github_search_code.rb + lib/modules/ai_bot/tools/github_file_content.rb + ], + ) + expect(result[:error]).to be_nil + end + end +end diff --git a/spec/lib/modules/ai_bot/tools/read_spec.rb b/spec/lib/modules/ai_bot/tools/read_spec.rb index ffe45cac..c801e224 100644 --- a/spec/lib/modules/ai_bot/tools/read_spec.rb +++ b/spec/lib/modules/ai_bot/tools/read_spec.rb @@ -24,15 +24,26 @@ RSpec.describe DiscourseAi::AiBot::Tools::Read do Fabricate(:topic, category: category, tags: [tag_funny, tag_sad, tag_hidden]) end + fab!(:post1) { Fabricate(:post, topic: topic_with_tags, raw: "hello there") } + fab!(:post2) { Fabricate(:post, topic: topic_with_tags, raw: "mister sam") } + before { SiteSetting.ai_bot_enabled = true } describe "#process" do + it "can read specific posts" do + tool = + described_class.new( + { topic_id: topic_with_tags.id, post_numbers: [post1.post_number] }, + bot_user: bot_user, + llm: llm, + ) + results = tool.invoke + + expect(results[:content]).to include("hello there") + expect(results[:content]).not_to include("mister sam") + end it "can read a topic" do topic_id = topic_with_tags.id - - Fabricate(:post, topic: topic_with_tags, raw: "hello there") - Fabricate(:post, topic: topic_with_tags, raw: "mister sam") - results = tool.invoke expect(results[:topic_id]).to eq(topic_id)