FEATURE: improved tooling (#651)
1. New tool to easily find files (and default branch) in a Github repo 2. Improved read tool with clearer params and larger context * limit can totally mess up the richness semantic search adds, so include the results unconditionally.
This commit is contained in:
parent
d812ecf5da
commit
834fea672f
|
@ -218,6 +218,7 @@ en:
|
|||
description: "Base query to use when searching. Example: '#urgent' will prepend '#urgent' to the search query and only include topics with the urgent category or tag."
|
||||
command_summary:
|
||||
web_browser: "Browse Web"
|
||||
github_search_files: "GitHub search files"
|
||||
github_search_code: "GitHub code search"
|
||||
github_file_content: "GitHub file content"
|
||||
github_pull_request_diff: "GitHub pull request diff"
|
||||
|
@ -239,6 +240,7 @@ en:
|
|||
command_help:
|
||||
web_browser: "Browse web page using the AI Bot"
|
||||
github_search_code: "Search for code in a GitHub repository"
|
||||
github_search_files: "Search for files in a GitHub repository"
|
||||
github_file_content: "Retrieve content of files from a GitHub repository"
|
||||
github_pull_request_diff: "Retrieve a GitHub pull request diff"
|
||||
random_picker: "Pick a random number or a random element of a list"
|
||||
|
@ -258,6 +260,7 @@ en:
|
|||
javascript_evaluator: "Evaluate JavaScript"
|
||||
command_description:
|
||||
web_browser: "Reading <a href='%{url}'>%{url}</a>"
|
||||
github_search_files: "Searched for '%{keywords}' in %{repo}/%{branch}"
|
||||
github_search_code: "Searched for '%{query}' in %{repo}"
|
||||
github_pull_request_diff: "<a href='%{url}'>%{repo} %{pull_id}</a>"
|
||||
github_file_content: "Retrieved content of %{file_paths} from %{repo_name}@%{branch}"
|
||||
|
|
|
@ -5,7 +5,12 @@ module DiscourseAi
|
|||
module Personas
|
||||
class GithubHelper < Persona
|
||||
def tools
|
||||
[Tools::GithubFileContent, Tools::GithubPullRequestDiff, Tools::GithubSearchCode]
|
||||
[
|
||||
Tools::GithubFileContent,
|
||||
Tools::GithubPullRequestDiff,
|
||||
Tools::GithubSearchCode,
|
||||
Tools::GithubSearchFiles,
|
||||
]
|
||||
end
|
||||
|
||||
def system_prompt
|
||||
|
|
|
@ -87,6 +87,7 @@ module DiscourseAi
|
|||
Tools::DiscourseMetaSearch,
|
||||
Tools::GithubFileContent,
|
||||
Tools::GithubPullRequestDiff,
|
||||
Tools::GithubSearchFiles,
|
||||
Tools::WebBrowser,
|
||||
Tools::JavascriptEvaluator,
|
||||
]
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module AiBot
|
||||
module Tools
|
||||
class GithubSearchFiles < Tool
|
||||
def self.signature
|
||||
{
|
||||
name: name,
|
||||
description:
|
||||
"Searches for files in a GitHub repository containing specific keywords in their paths or names",
|
||||
parameters: [
|
||||
{
|
||||
name: "repo",
|
||||
description: "The repository name in the format 'owner/repo'",
|
||||
type: "string",
|
||||
required: true,
|
||||
},
|
||||
{
|
||||
name: "keywords",
|
||||
description: "An array of keywords to match in file paths or names",
|
||||
type: "array",
|
||||
item_type: "string",
|
||||
required: true,
|
||||
},
|
||||
{
|
||||
name: "branch",
|
||||
description:
|
||||
"The branch or commit SHA to search within (default: repository's default branch)",
|
||||
type: "string",
|
||||
required: false,
|
||||
},
|
||||
],
|
||||
}
|
||||
end
|
||||
|
||||
def self.name
|
||||
"github_search_files"
|
||||
end
|
||||
|
||||
def repo
|
||||
parameters[:repo]
|
||||
end
|
||||
|
||||
def keywords
|
||||
parameters[:keywords]
|
||||
end
|
||||
|
||||
def branch
|
||||
parameters[:branch]
|
||||
end
|
||||
|
||||
def description_args
|
||||
{ repo: repo, keywords: keywords.join(", "), branch: @branch_name }
|
||||
end
|
||||
|
||||
def invoke
|
||||
# Fetch the default branch if no branch is specified
|
||||
branch_name = branch || fetch_default_branch(repo)
|
||||
@branch_name = branch_name
|
||||
|
||||
api_url = "https://api.github.com/repos/#{repo}/git/trees/#{branch_name}?recursive=1"
|
||||
|
||||
response_code = "unknown error"
|
||||
tree_data = nil
|
||||
|
||||
send_http_request(
|
||||
api_url,
|
||||
headers: {
|
||||
"Accept" => "application/vnd.github.v3+json",
|
||||
},
|
||||
authenticate_github: true,
|
||||
) do |response|
|
||||
response_code = response.code
|
||||
if response_code == "200"
|
||||
begin
|
||||
tree_data = JSON.parse(read_response_body(response))
|
||||
rescue JSON::ParserError
|
||||
response_code = "500 - JSON parse error"
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
if response_code == "200"
|
||||
matching_files =
|
||||
tree_data["tree"]
|
||||
.select do |item|
|
||||
item["type"] == "blob" &&
|
||||
keywords.any? { |keyword| item["path"].include?(keyword) }
|
||||
end
|
||||
.map { |item| item["path"] }
|
||||
|
||||
{ matching_files: matching_files, branch: branch_name }
|
||||
else
|
||||
{ error: "Failed to perform file search. Status code: #{response_code}" }
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -2,6 +2,8 @@
|
|||
|
||||
module DiscourseAi
|
||||
module AiBot
|
||||
MAX_POSTS = 100
|
||||
|
||||
module Tools
|
||||
class Read < Tool
|
||||
def self.signature
|
||||
|
@ -16,10 +18,11 @@ module DiscourseAi
|
|||
required: true,
|
||||
},
|
||||
{
|
||||
name: "post_number",
|
||||
description: "the post number to read",
|
||||
type: "integer",
|
||||
required: true,
|
||||
name: "post_numbers",
|
||||
description: "the post numbers to read (optional)",
|
||||
type: "array",
|
||||
item_type: "integer",
|
||||
required: false,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
@ -35,8 +38,8 @@ module DiscourseAi
|
|||
parameters[:topic_id]
|
||||
end
|
||||
|
||||
def post_number
|
||||
parameters[:post_number]
|
||||
def post_numbers
|
||||
parameters[:post_numbers]
|
||||
end
|
||||
|
||||
def invoke
|
||||
|
@ -49,10 +52,19 @@ module DiscourseAi
|
|||
|
||||
@title = topic.title
|
||||
|
||||
posts = Post.secured(Guardian.new).where(topic_id: topic_id).order(:post_number).limit(40)
|
||||
posts =
|
||||
Post
|
||||
.secured(Guardian.new)
|
||||
.where(topic_id: topic_id)
|
||||
.order(:post_number)
|
||||
.limit(MAX_POSTS)
|
||||
|
||||
post_number = 1
|
||||
post_number = post_numbers.first if post_numbers.present?
|
||||
|
||||
@url = topic.relative_url(post_number)
|
||||
|
||||
posts = posts.where("post_number = ?", post_number) if post_number
|
||||
posts = posts.where("post_number in (?)", post_numbers) if post_numbers.present?
|
||||
|
||||
content = +<<~TEXT.strip
|
||||
title: #{topic.title}
|
||||
|
@ -69,13 +81,15 @@ module DiscourseAi
|
|||
content << "\ntags: #{tags.map(&:name).join(", ")}\n\n" if tags.length > 0
|
||||
end
|
||||
|
||||
posts.each { |post| content << "\n\n#{post.username} said:\n\n#{post.raw}" }
|
||||
posts.each do |post|
|
||||
content << "\n\n#{post.user&.name}(#{post.username}) said:\n\n#{post.raw}"
|
||||
end
|
||||
|
||||
# TODO: 16k or 100k models can handle a lot more tokens
|
||||
content = llm.tokenizer.truncate(content, 1500).squish
|
||||
truncated_content =
|
||||
truncate(content, max_length: 20_000, percent_length: 0.3, llm: llm).squish
|
||||
|
||||
result = { topic_id: topic_id, content: content, complete: true }
|
||||
result[:post_number] = post_number if post_number
|
||||
result = { topic_id: topic_id, content: truncated_content }
|
||||
result[:post_numbers] = post_numbers if post_numbers.present?
|
||||
result
|
||||
end
|
||||
|
||||
|
|
|
@ -123,14 +123,12 @@ module DiscourseAi
|
|||
results =
|
||||
::Search.execute(safe_search_string, search_type: :full_page, guardian: guardian)
|
||||
|
||||
# let's be frugal with tokens, 50 results is too much and stuff gets cut off
|
||||
max_results = calculate_max_results(llm)
|
||||
results_limit = parameters[:limit] || max_results
|
||||
results_limit = max_results if parameters[:limit].to_i > max_results
|
||||
|
||||
should_try_semantic_search =
|
||||
SiteSetting.ai_embeddings_semantic_search_enabled && results_limit == max_results &&
|
||||
parameters[:search_query].present?
|
||||
SiteSetting.ai_embeddings_semantic_search_enabled && parameters[:search_query].present?
|
||||
|
||||
max_semantic_results = max_results / 4
|
||||
results_limit = results_limit - max_semantic_results if should_try_semantic_search
|
||||
|
|
|
@ -92,6 +92,32 @@ module DiscourseAi
|
|||
|
||||
protected
|
||||
|
||||
def fetch_default_branch(repo)
|
||||
api_url = "https://api.github.com/repos/#{repo}"
|
||||
|
||||
response_code = "unknown error"
|
||||
repo_data = nil
|
||||
|
||||
send_http_request(
|
||||
api_url,
|
||||
headers: {
|
||||
"Accept" => "application/vnd.github.v3+json",
|
||||
},
|
||||
authenticate_github: true,
|
||||
) do |response|
|
||||
response_code = response.code
|
||||
if response_code == "200"
|
||||
begin
|
||||
repo_data = JSON.parse(read_response_body(response))
|
||||
rescue JSON::ParserError
|
||||
response_code = "500 - JSON parse error"
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
response_code == "200" ? repo_data["default_branch"] : "main"
|
||||
end
|
||||
|
||||
def send_http_request(url, headers: {}, authenticate_github: false, follow_redirects: false)
|
||||
raise "Expecting caller to use a block" if !block_given?
|
||||
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require "rails_helper"
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchFiles do
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }
|
||||
|
||||
let(:tool) do
|
||||
described_class.new(
|
||||
{
|
||||
repo: "discourse/discourse-ai",
|
||||
keywords: %w[search tool],
|
||||
branch: nil, # Let it find the default branch
|
||||
},
|
||||
bot_user: nil,
|
||||
llm: llm,
|
||||
)
|
||||
end
|
||||
|
||||
describe "#invoke" do
|
||||
let(:default_branch) { "main" }
|
||||
|
||||
before do
|
||||
# Stub request to get the default branch
|
||||
stub_request(:get, "https://api.github.com/repos/discourse/discourse-ai").to_return(
|
||||
status: 200,
|
||||
body: { default_branch: default_branch }.to_json,
|
||||
)
|
||||
|
||||
# Stub request to get the file tree
|
||||
stub_request(
|
||||
:get,
|
||||
"https://api.github.com/repos/discourse/discourse-ai/git/trees/#{default_branch}?recursive=1",
|
||||
).to_return(
|
||||
status: 200,
|
||||
body: {
|
||||
tree: [
|
||||
{ path: "lib/modules/ai_bot/tools/github_search_code.rb", type: "blob" },
|
||||
{ path: "lib/modules/ai_bot/tools/github_file_content.rb", type: "blob" },
|
||||
],
|
||||
}.to_json,
|
||||
)
|
||||
end
|
||||
|
||||
it "retrieves files matching the specified keywords" do
|
||||
result = tool.invoke
|
||||
expected = {
|
||||
branch: "main",
|
||||
matching_files: %w[
|
||||
lib/modules/ai_bot/tools/github_search_code.rb
|
||||
lib/modules/ai_bot/tools/github_file_content.rb
|
||||
],
|
||||
}
|
||||
|
||||
expect(result).to eq(expected)
|
||||
end
|
||||
|
||||
it "handles missing branches gracefully" do
|
||||
stub_request(
|
||||
:get,
|
||||
"https://api.github.com/repos/discourse/discourse-ai/git/trees/non_existing_branch?recursive=1",
|
||||
).to_return(status: 404, body: "", headers: { "Content-Type" => "application/json" })
|
||||
|
||||
tool_with_invalid_branch =
|
||||
described_class.new(
|
||||
{
|
||||
repo: "discourse/discourse-ai",
|
||||
keywords: %w[search tool],
|
||||
branch: "non_existing_branch",
|
||||
},
|
||||
bot_user: nil,
|
||||
llm: llm,
|
||||
)
|
||||
|
||||
result = tool_with_invalid_branch.invoke
|
||||
expect(result[:matching_files]).to be_nil
|
||||
expect(result[:error]).to eq("Failed to perform file search. Status code: 404")
|
||||
end
|
||||
|
||||
it "fetches the default branch if none is specified" do
|
||||
result = tool.invoke
|
||||
expect(result[:matching_files]).to match_array(
|
||||
%w[
|
||||
lib/modules/ai_bot/tools/github_search_code.rb
|
||||
lib/modules/ai_bot/tools/github_file_content.rb
|
||||
],
|
||||
)
|
||||
expect(result[:error]).to be_nil
|
||||
end
|
||||
end
|
||||
end
|
|
@ -24,15 +24,26 @@ RSpec.describe DiscourseAi::AiBot::Tools::Read do
|
|||
Fabricate(:topic, category: category, tags: [tag_funny, tag_sad, tag_hidden])
|
||||
end
|
||||
|
||||
fab!(:post1) { Fabricate(:post, topic: topic_with_tags, raw: "hello there") }
|
||||
fab!(:post2) { Fabricate(:post, topic: topic_with_tags, raw: "mister sam") }
|
||||
|
||||
before { SiteSetting.ai_bot_enabled = true }
|
||||
|
||||
describe "#process" do
|
||||
it "can read specific posts" do
|
||||
tool =
|
||||
described_class.new(
|
||||
{ topic_id: topic_with_tags.id, post_numbers: [post1.post_number] },
|
||||
bot_user: bot_user,
|
||||
llm: llm,
|
||||
)
|
||||
results = tool.invoke
|
||||
|
||||
expect(results[:content]).to include("hello there")
|
||||
expect(results[:content]).not_to include("mister sam")
|
||||
end
|
||||
it "can read a topic" do
|
||||
topic_id = topic_with_tags.id
|
||||
|
||||
Fabricate(:post, topic: topic_with_tags, raw: "hello there")
|
||||
Fabricate(:post, topic: topic_with_tags, raw: "mister sam")
|
||||
|
||||
results = tool.invoke
|
||||
|
||||
expect(results[:topic_id]).to eq(topic_id)
|
||||
|
|
Loading…
Reference in New Issue