FEATURE: improved tooling (#651)

1. New tool to easily find files (and default branch) in a Github repo
2. Improved read tool with clearer params and larger context

* limit can totally mess up the richness semantic search adds, so include the results unconditionally.
This commit is contained in:
Sam 2024-05-30 06:33:50 +10:00 committed by GitHub
parent d812ecf5da
commit 834fea672f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 271 additions and 21 deletions

View File

@ -218,6 +218,7 @@ en:
description: "Base query to use when searching. Example: '#urgent' will prepend '#urgent' to the search query and only include topics with the urgent category or tag."
command_summary:
web_browser: "Browse Web"
github_search_files: "GitHub search files"
github_search_code: "GitHub code search"
github_file_content: "GitHub file content"
github_pull_request_diff: "GitHub pull request diff"
@ -239,6 +240,7 @@ en:
command_help:
web_browser: "Browse web page using the AI Bot"
github_search_code: "Search for code in a GitHub repository"
github_search_files: "Search for files in a GitHub repository"
github_file_content: "Retrieve content of files from a GitHub repository"
github_pull_request_diff: "Retrieve a GitHub pull request diff"
random_picker: "Pick a random number or a random element of a list"
@ -258,6 +260,7 @@ en:
javascript_evaluator: "Evaluate JavaScript"
command_description:
web_browser: "Reading <a href='%{url}'>%{url}</a>"
github_search_files: "Searched for '%{keywords}' in %{repo}/%{branch}"
github_search_code: "Searched for '%{query}' in %{repo}"
github_pull_request_diff: "<a href='%{url}'>%{repo} %{pull_id}</a>"
github_file_content: "Retrieved content of %{file_paths} from %{repo_name}@%{branch}"

View File

@ -5,7 +5,12 @@ module DiscourseAi
module Personas
class GithubHelper < Persona
def tools
[Tools::GithubFileContent, Tools::GithubPullRequestDiff, Tools::GithubSearchCode]
[
Tools::GithubFileContent,
Tools::GithubPullRequestDiff,
Tools::GithubSearchCode,
Tools::GithubSearchFiles,
]
end
def system_prompt

View File

@ -87,6 +87,7 @@ module DiscourseAi
Tools::DiscourseMetaSearch,
Tools::GithubFileContent,
Tools::GithubPullRequestDiff,
Tools::GithubSearchFiles,
Tools::WebBrowser,
Tools::JavascriptEvaluator,
]

View File

@ -0,0 +1,101 @@
# frozen_string_literal: true
module DiscourseAi
module AiBot
module Tools
class GithubSearchFiles < Tool
def self.signature
{
name: name,
description:
"Searches for files in a GitHub repository containing specific keywords in their paths or names",
parameters: [
{
name: "repo",
description: "The repository name in the format 'owner/repo'",
type: "string",
required: true,
},
{
name: "keywords",
description: "An array of keywords to match in file paths or names",
type: "array",
item_type: "string",
required: true,
},
{
name: "branch",
description:
"The branch or commit SHA to search within (default: repository's default branch)",
type: "string",
required: false,
},
],
}
end
def self.name
"github_search_files"
end
def repo
parameters[:repo]
end
def keywords
parameters[:keywords]
end
def branch
parameters[:branch]
end
def description_args
{ repo: repo, keywords: keywords.join(", "), branch: @branch_name }
end
def invoke
# Fetch the default branch if no branch is specified
branch_name = branch || fetch_default_branch(repo)
@branch_name = branch_name
api_url = "https://api.github.com/repos/#{repo}/git/trees/#{branch_name}?recursive=1"
response_code = "unknown error"
tree_data = nil
send_http_request(
api_url,
headers: {
"Accept" => "application/vnd.github.v3+json",
},
authenticate_github: true,
) do |response|
response_code = response.code
if response_code == "200"
begin
tree_data = JSON.parse(read_response_body(response))
rescue JSON::ParserError
response_code = "500 - JSON parse error"
end
end
end
if response_code == "200"
matching_files =
tree_data["tree"]
.select do |item|
item["type"] == "blob" &&
keywords.any? { |keyword| item["path"].include?(keyword) }
end
.map { |item| item["path"] }
{ matching_files: matching_files, branch: branch_name }
else
{ error: "Failed to perform file search. Status code: #{response_code}" }
end
end
end
end
end
end

View File

@ -2,6 +2,8 @@
module DiscourseAi
module AiBot
MAX_POSTS = 100
module Tools
class Read < Tool
def self.signature
@ -16,10 +18,11 @@ module DiscourseAi
required: true,
},
{
name: "post_number",
description: "the post number to read",
type: "integer",
required: true,
name: "post_numbers",
description: "the post numbers to read (optional)",
type: "array",
item_type: "integer",
required: false,
},
],
}
@ -35,8 +38,8 @@ module DiscourseAi
parameters[:topic_id]
end
def post_number
parameters[:post_number]
def post_numbers
parameters[:post_numbers]
end
def invoke
@ -49,10 +52,19 @@ module DiscourseAi
@title = topic.title
posts = Post.secured(Guardian.new).where(topic_id: topic_id).order(:post_number).limit(40)
posts =
Post
.secured(Guardian.new)
.where(topic_id: topic_id)
.order(:post_number)
.limit(MAX_POSTS)
post_number = 1
post_number = post_numbers.first if post_numbers.present?
@url = topic.relative_url(post_number)
posts = posts.where("post_number = ?", post_number) if post_number
posts = posts.where("post_number in (?)", post_numbers) if post_numbers.present?
content = +<<~TEXT.strip
title: #{topic.title}
@ -69,13 +81,15 @@ module DiscourseAi
content << "\ntags: #{tags.map(&:name).join(", ")}\n\n" if tags.length > 0
end
posts.each { |post| content << "\n\n#{post.username} said:\n\n#{post.raw}" }
posts.each do |post|
content << "\n\n#{post.user&.name}(#{post.username}) said:\n\n#{post.raw}"
end
# TODO: 16k or 100k models can handle a lot more tokens
content = llm.tokenizer.truncate(content, 1500).squish
truncated_content =
truncate(content, max_length: 20_000, percent_length: 0.3, llm: llm).squish
result = { topic_id: topic_id, content: content, complete: true }
result[:post_number] = post_number if post_number
result = { topic_id: topic_id, content: truncated_content }
result[:post_numbers] = post_numbers if post_numbers.present?
result
end

View File

@ -123,14 +123,12 @@ module DiscourseAi
results =
::Search.execute(safe_search_string, search_type: :full_page, guardian: guardian)
# let's be frugal with tokens, 50 results is too much and stuff gets cut off
max_results = calculate_max_results(llm)
results_limit = parameters[:limit] || max_results
results_limit = max_results if parameters[:limit].to_i > max_results
should_try_semantic_search =
SiteSetting.ai_embeddings_semantic_search_enabled && results_limit == max_results &&
parameters[:search_query].present?
SiteSetting.ai_embeddings_semantic_search_enabled && parameters[:search_query].present?
max_semantic_results = max_results / 4
results_limit = results_limit - max_semantic_results if should_try_semantic_search

View File

@ -92,6 +92,32 @@ module DiscourseAi
protected
def fetch_default_branch(repo)
api_url = "https://api.github.com/repos/#{repo}"
response_code = "unknown error"
repo_data = nil
send_http_request(
api_url,
headers: {
"Accept" => "application/vnd.github.v3+json",
},
authenticate_github: true,
) do |response|
response_code = response.code
if response_code == "200"
begin
repo_data = JSON.parse(read_response_body(response))
rescue JSON::ParserError
response_code = "500 - JSON parse error"
end
end
end
response_code == "200" ? repo_data["default_branch"] : "main"
end
def send_http_request(url, headers: {}, authenticate_github: false, follow_redirects: false)
raise "Expecting caller to use a block" if !block_given?

View File

@ -0,0 +1,91 @@
# frozen_string_literal: true
require "rails_helper"
RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchFiles do
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }
let(:tool) do
described_class.new(
{
repo: "discourse/discourse-ai",
keywords: %w[search tool],
branch: nil, # Let it find the default branch
},
bot_user: nil,
llm: llm,
)
end
describe "#invoke" do
let(:default_branch) { "main" }
before do
# Stub request to get the default branch
stub_request(:get, "https://api.github.com/repos/discourse/discourse-ai").to_return(
status: 200,
body: { default_branch: default_branch }.to_json,
)
# Stub request to get the file tree
stub_request(
:get,
"https://api.github.com/repos/discourse/discourse-ai/git/trees/#{default_branch}?recursive=1",
).to_return(
status: 200,
body: {
tree: [
{ path: "lib/modules/ai_bot/tools/github_search_code.rb", type: "blob" },
{ path: "lib/modules/ai_bot/tools/github_file_content.rb", type: "blob" },
],
}.to_json,
)
end
it "retrieves files matching the specified keywords" do
result = tool.invoke
expected = {
branch: "main",
matching_files: %w[
lib/modules/ai_bot/tools/github_search_code.rb
lib/modules/ai_bot/tools/github_file_content.rb
],
}
expect(result).to eq(expected)
end
it "handles missing branches gracefully" do
stub_request(
:get,
"https://api.github.com/repos/discourse/discourse-ai/git/trees/non_existing_branch?recursive=1",
).to_return(status: 404, body: "", headers: { "Content-Type" => "application/json" })
tool_with_invalid_branch =
described_class.new(
{
repo: "discourse/discourse-ai",
keywords: %w[search tool],
branch: "non_existing_branch",
},
bot_user: nil,
llm: llm,
)
result = tool_with_invalid_branch.invoke
expect(result[:matching_files]).to be_nil
expect(result[:error]).to eq("Failed to perform file search. Status code: 404")
end
it "fetches the default branch if none is specified" do
result = tool.invoke
expect(result[:matching_files]).to match_array(
%w[
lib/modules/ai_bot/tools/github_search_code.rb
lib/modules/ai_bot/tools/github_file_content.rb
],
)
expect(result[:error]).to be_nil
end
end
end

View File

@ -24,15 +24,26 @@ RSpec.describe DiscourseAi::AiBot::Tools::Read do
Fabricate(:topic, category: category, tags: [tag_funny, tag_sad, tag_hidden])
end
fab!(:post1) { Fabricate(:post, topic: topic_with_tags, raw: "hello there") }
fab!(:post2) { Fabricate(:post, topic: topic_with_tags, raw: "mister sam") }
before { SiteSetting.ai_bot_enabled = true }
describe "#process" do
it "can read specific posts" do
tool =
described_class.new(
{ topic_id: topic_with_tags.id, post_numbers: [post1.post_number] },
bot_user: bot_user,
llm: llm,
)
results = tool.invoke
expect(results[:content]).to include("hello there")
expect(results[:content]).not_to include("mister sam")
end
it "can read a topic" do
topic_id = topic_with_tags.id
Fabricate(:post, topic: topic_with_tags, raw: "hello there")
Fabricate(:post, topic: topic_with_tags, raw: "mister sam")
results = tool.invoke
expect(results[:topic_id]).to eq(topic_id)