FEATURE: improved tooling (#651)
1. New tool to easily find files (and the default branch) in a GitHub repo. 2. Improved read tool with clearer params and larger context. Also: a result limit can totally mess up the richness that semantic search adds, so semantic search results are now included unconditionally.
This commit is contained in:
parent
d812ecf5da
commit
834fea672f
|
@ -218,6 +218,7 @@ en:
|
||||||
description: "Base query to use when searching. Example: '#urgent' will prepend '#urgent' to the search query and only include topics with the urgent category or tag."
|
description: "Base query to use when searching. Example: '#urgent' will prepend '#urgent' to the search query and only include topics with the urgent category or tag."
|
||||||
command_summary:
|
command_summary:
|
||||||
web_browser: "Browse Web"
|
web_browser: "Browse Web"
|
||||||
|
github_search_files: "GitHub search files"
|
||||||
github_search_code: "GitHub code search"
|
github_search_code: "GitHub code search"
|
||||||
github_file_content: "GitHub file content"
|
github_file_content: "GitHub file content"
|
||||||
github_pull_request_diff: "GitHub pull request diff"
|
github_pull_request_diff: "GitHub pull request diff"
|
||||||
|
@ -239,6 +240,7 @@ en:
|
||||||
command_help:
|
command_help:
|
||||||
web_browser: "Browse web page using the AI Bot"
|
web_browser: "Browse web page using the AI Bot"
|
||||||
github_search_code: "Search for code in a GitHub repository"
|
github_search_code: "Search for code in a GitHub repository"
|
||||||
|
github_search_files: "Search for files in a GitHub repository"
|
||||||
github_file_content: "Retrieve content of files from a GitHub repository"
|
github_file_content: "Retrieve content of files from a GitHub repository"
|
||||||
github_pull_request_diff: "Retrieve a GitHub pull request diff"
|
github_pull_request_diff: "Retrieve a GitHub pull request diff"
|
||||||
random_picker: "Pick a random number or a random element of a list"
|
random_picker: "Pick a random number or a random element of a list"
|
||||||
|
@ -258,6 +260,7 @@ en:
|
||||||
javascript_evaluator: "Evaluate JavaScript"
|
javascript_evaluator: "Evaluate JavaScript"
|
||||||
command_description:
|
command_description:
|
||||||
web_browser: "Reading <a href='%{url}'>%{url}</a>"
|
web_browser: "Reading <a href='%{url}'>%{url}</a>"
|
||||||
|
github_search_files: "Searched for '%{keywords}' in %{repo}/%{branch}"
|
||||||
github_search_code: "Searched for '%{query}' in %{repo}"
|
github_search_code: "Searched for '%{query}' in %{repo}"
|
||||||
github_pull_request_diff: "<a href='%{url}'>%{repo} %{pull_id}</a>"
|
github_pull_request_diff: "<a href='%{url}'>%{repo} %{pull_id}</a>"
|
||||||
github_file_content: "Retrieved content of %{file_paths} from %{repo_name}@%{branch}"
|
github_file_content: "Retrieved content of %{file_paths} from %{repo_name}@%{branch}"
|
||||||
|
|
|
@ -5,7 +5,12 @@ module DiscourseAi
|
||||||
module Personas
|
module Personas
|
||||||
class GithubHelper < Persona
|
class GithubHelper < Persona
|
||||||
def tools
|
def tools
|
||||||
[Tools::GithubFileContent, Tools::GithubPullRequestDiff, Tools::GithubSearchCode]
|
[
|
||||||
|
Tools::GithubFileContent,
|
||||||
|
Tools::GithubPullRequestDiff,
|
||||||
|
Tools::GithubSearchCode,
|
||||||
|
Tools::GithubSearchFiles,
|
||||||
|
]
|
||||||
end
|
end
|
||||||
|
|
||||||
def system_prompt
|
def system_prompt
|
||||||
|
|
|
@ -87,6 +87,7 @@ module DiscourseAi
|
||||||
Tools::DiscourseMetaSearch,
|
Tools::DiscourseMetaSearch,
|
||||||
Tools::GithubFileContent,
|
Tools::GithubFileContent,
|
||||||
Tools::GithubPullRequestDiff,
|
Tools::GithubPullRequestDiff,
|
||||||
|
Tools::GithubSearchFiles,
|
||||||
Tools::WebBrowser,
|
Tools::WebBrowser,
|
||||||
Tools::JavascriptEvaluator,
|
Tools::JavascriptEvaluator,
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,101 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module AiBot
|
||||||
|
module Tools
|
||||||
|
# Tool that lists the files of a GitHub repository whose paths contain at
# least one of the supplied keywords.
#
# The file list comes from the GitHub "git trees" endpoint, requested with
# `recursive=1`. When no branch is supplied, the repository's default branch
# is resolved through `fetch_default_branch`.
class GithubSearchFiles < Tool
  # Tool definition (name, description and parameter schema).
  #
  # @return [Hash]
  def self.signature
    {
      name: name,
      description:
        "Searches for files in a GitHub repository containing specific keywords in their paths or names",
      parameters: [
        {
          name: "repo",
          description: "The repository name in the format 'owner/repo'",
          type: "string",
          required: true,
        },
        {
          name: "keywords",
          description: "An array of keywords to match in file paths or names",
          type: "array",
          item_type: "string",
          required: true,
        },
        {
          name: "branch",
          description:
            "The branch or commit SHA to search within (default: repository's default branch)",
          type: "string",
          required: false,
        },
      ],
    }
  end

  # @return [String] identifier of this tool
  def self.name
    "github_search_files"
  end

  # @return [String, nil] repository in "owner/repo" form
  def repo
    parameters[:repo]
  end

  # Keywords to look for in file paths.
  #
  # Always an Array: a missing value becomes `[]` and a single scalar is
  # wrapped, so `description_args` and `invoke` never call Array methods on
  # nil or on a bare String.
  #
  # @return [Array]
  def keywords
    Array(parameters[:keywords])
  end

  # @return [String, nil] branch or commit SHA requested by the caller
  def branch
    parameters[:branch]
  end

  # Interpolation arguments for the tool description. `branch` is only
  # populated once `invoke` has resolved which branch was searched.
  #
  # @return [Hash]
  def description_args
    { repo: repo, keywords: keywords.join(", "), branch: @branch_name }
  end

  # Runs the search.
  #
  # @return [Hash] `{ matching_files: [...], branch: "..." }` on success, plus
  #   `truncated: true` when GitHub reports that the tree listing was cut off;
  #   `{ error: "..." }` when the request or the JSON parsing fails.
  def invoke
    # A blank branch is treated like a missing one: interpolating it would
    # produce a ".../git/trees/?recursive=1" URL.
    requested_branch = branch.to_s.strip
    branch_name = requested_branch.empty? ? fetch_default_branch(repo) : requested_branch
    @branch_name = branch_name

    api_url = "https://api.github.com/repos/#{repo}/git/trees/#{branch_name}?recursive=1"

    response_code = "unknown error"
    tree_data = nil

    send_http_request(
      api_url,
      headers: {
        "Accept" => "application/vnd.github.v3+json",
      },
      authenticate_github: true,
    ) do |response|
      response_code = response.code
      if response_code == "200"
        begin
          tree_data = JSON.parse(read_response_body(response))
        rescue JSON::ParserError
          response_code = "500 - JSON parse error"
        end
      end
    end

    if response_code != "200"
      return { error: "Failed to perform file search. Status code: #{response_code}" }
    end

    # A 200 response without a "tree" array (or a non-object body) yields no
    # matches instead of raising.
    entries = tree_data.is_a?(Hash) ? Array(tree_data["tree"]) : []
    terms = keywords

    matching_files =
      entries
        .select do |item|
          item["type"] == "blob" && terms.any? { |keyword| item["path"].to_s.include?(keyword) }
        end
        .map { |item| item["path"] }

    result = { matching_files: matching_files, branch: branch_name }
    # GitHub sets "truncated" when the tree exceeded its size limit, in which
    # case the list of matches may be incomplete; surface that to the caller.
    result[:truncated] = true if tree_data.is_a?(Hash) && tree_data["truncated"]
    result
  end
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -2,6 +2,8 @@
|
||||||
|
|
||||||
module DiscourseAi
|
module DiscourseAi
|
||||||
module AiBot
|
module AiBot
|
||||||
|
MAX_POSTS = 100
|
||||||
|
|
||||||
module Tools
|
module Tools
|
||||||
class Read < Tool
|
class Read < Tool
|
||||||
def self.signature
|
def self.signature
|
||||||
|
@ -16,10 +18,11 @@ module DiscourseAi
|
||||||
required: true,
|
required: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "post_number",
|
name: "post_numbers",
|
||||||
description: "the post number to read",
|
description: "the post numbers to read (optional)",
|
||||||
type: "integer",
|
type: "array",
|
||||||
required: true,
|
item_type: "integer",
|
||||||
|
required: false,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
@ -35,8 +38,8 @@ module DiscourseAi
|
||||||
parameters[:topic_id]
|
parameters[:topic_id]
|
||||||
end
|
end
|
||||||
|
|
||||||
def post_number
|
def post_numbers
|
||||||
parameters[:post_number]
|
parameters[:post_numbers]
|
||||||
end
|
end
|
||||||
|
|
||||||
def invoke
|
def invoke
|
||||||
|
@ -49,10 +52,19 @@ module DiscourseAi
|
||||||
|
|
||||||
@title = topic.title
|
@title = topic.title
|
||||||
|
|
||||||
posts = Post.secured(Guardian.new).where(topic_id: topic_id).order(:post_number).limit(40)
|
posts =
|
||||||
|
Post
|
||||||
|
.secured(Guardian.new)
|
||||||
|
.where(topic_id: topic_id)
|
||||||
|
.order(:post_number)
|
||||||
|
.limit(MAX_POSTS)
|
||||||
|
|
||||||
|
post_number = 1
|
||||||
|
post_number = post_numbers.first if post_numbers.present?
|
||||||
|
|
||||||
@url = topic.relative_url(post_number)
|
@url = topic.relative_url(post_number)
|
||||||
|
|
||||||
posts = posts.where("post_number = ?", post_number) if post_number
|
posts = posts.where("post_number in (?)", post_numbers) if post_numbers.present?
|
||||||
|
|
||||||
content = +<<~TEXT.strip
|
content = +<<~TEXT.strip
|
||||||
title: #{topic.title}
|
title: #{topic.title}
|
||||||
|
@ -69,13 +81,15 @@ module DiscourseAi
|
||||||
content << "\ntags: #{tags.map(&:name).join(", ")}\n\n" if tags.length > 0
|
content << "\ntags: #{tags.map(&:name).join(", ")}\n\n" if tags.length > 0
|
||||||
end
|
end
|
||||||
|
|
||||||
posts.each { |post| content << "\n\n#{post.username} said:\n\n#{post.raw}" }
|
posts.each do |post|
|
||||||
|
content << "\n\n#{post.user&.name}(#{post.username}) said:\n\n#{post.raw}"
|
||||||
|
end
|
||||||
|
|
||||||
# TODO: 16k or 100k models can handle a lot more tokens
|
truncated_content =
|
||||||
content = llm.tokenizer.truncate(content, 1500).squish
|
truncate(content, max_length: 20_000, percent_length: 0.3, llm: llm).squish
|
||||||
|
|
||||||
result = { topic_id: topic_id, content: content, complete: true }
|
result = { topic_id: topic_id, content: truncated_content }
|
||||||
result[:post_number] = post_number if post_number
|
result[:post_numbers] = post_numbers if post_numbers.present?
|
||||||
result
|
result
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -123,14 +123,12 @@ module DiscourseAi
|
||||||
results =
|
results =
|
||||||
::Search.execute(safe_search_string, search_type: :full_page, guardian: guardian)
|
::Search.execute(safe_search_string, search_type: :full_page, guardian: guardian)
|
||||||
|
|
||||||
# let's be frugal with tokens, 50 results is too much and stuff gets cut off
|
|
||||||
max_results = calculate_max_results(llm)
|
max_results = calculate_max_results(llm)
|
||||||
results_limit = parameters[:limit] || max_results
|
results_limit = parameters[:limit] || max_results
|
||||||
results_limit = max_results if parameters[:limit].to_i > max_results
|
results_limit = max_results if parameters[:limit].to_i > max_results
|
||||||
|
|
||||||
should_try_semantic_search =
|
should_try_semantic_search =
|
||||||
SiteSetting.ai_embeddings_semantic_search_enabled && results_limit == max_results &&
|
SiteSetting.ai_embeddings_semantic_search_enabled && parameters[:search_query].present?
|
||||||
parameters[:search_query].present?
|
|
||||||
|
|
||||||
max_semantic_results = max_results / 4
|
max_semantic_results = max_results / 4
|
||||||
results_limit = results_limit - max_semantic_results if should_try_semantic_search
|
results_limit = results_limit - max_semantic_results if should_try_semantic_search
|
||||||
|
|
|
@ -92,6 +92,32 @@ module DiscourseAi
|
||||||
|
|
||||||
protected
|
protected
|
||||||
|
|
||||||
|
# Looks up the default branch of a GitHub repository through the REST API.
#
# @param repo [String] repository in "owner/repo" form; interpolated into the
#   request URL as-is
# @return [String] the "default_branch" value from the response, or "main"
#   when the response status is not "200", the body is not valid JSON, or the
#   payload carries no usable default branch
def fetch_default_branch(repo)
  api_url = "https://api.github.com/repos/#{repo}"

  repo_data = nil

  send_http_request(
    api_url,
    headers: {
      "Accept" => "application/vnd.github.v3+json",
    },
    authenticate_github: true,
  ) do |response|
    next if response.code != "200"

    begin
      repo_data = JSON.parse(read_response_body(response))
    rescue JSON::ParserError
      # Best effort: an unparsable body is handled like a failed request.
      repo_data = nil
    end
  end

  # Only trust an object payload; `null`, arrays, or a missing/blank key all
  # fall back to "main" rather than returning nil to the caller.
  default_branch = repo_data["default_branch"] if repo_data.is_a?(Hash)
  default_branch.to_s.empty? ? "main" : default_branch
end
|
||||||
|
|
||||||
def send_http_request(url, headers: {}, authenticate_github: false, follow_redirects: false)
|
def send_http_request(url, headers: {}, authenticate_github: false, follow_redirects: false)
|
||||||
raise "Expecting caller to use a block" if !block_given?
|
raise "Expecting caller to use a block" if !block_given?
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,91 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
require "rails_helper"
|
||||||
|
|
||||||
|
RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchFiles do
  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }

  # Repository and keywords shared by every example.
  let(:repo_name) { "discourse/discourse-ai" }
  let(:search_terms) { %w[search tool] }

  # Paths served by the stubbed tree endpoint; both contain a search term.
  let(:tree_paths) do
    %w[
      lib/modules/ai_bot/tools/github_search_code.rb
      lib/modules/ai_bot/tools/github_file_content.rb
    ]
  end

  # Builds a tool instance for the shared repo/keywords and the given branch.
  def build_tool(branch_param)
    described_class.new(
      { repo: repo_name, keywords: search_terms, branch: branch_param },
      bot_user: nil,
      llm: llm,
    )
  end

  # No branch given, so the tool has to look up the default branch itself.
  let(:tool) { build_tool(nil) }

  describe "#invoke" do
    let(:default_branch) { "main" }

    before do
      # Repository metadata request (used to discover the default branch)
      stub_request(:get, "https://api.github.com/repos/#{repo_name}").to_return(
        status: 200,
        body: { default_branch: default_branch }.to_json,
      )

      # Recursive tree listing for the default branch
      tree_entries = tree_paths.map { |path| { path: path, type: "blob" } }
      stub_request(
        :get,
        "https://api.github.com/repos/#{repo_name}/git/trees/#{default_branch}?recursive=1",
      ).to_return(status: 200, body: { tree: tree_entries }.to_json)
    end

    it "retrieves files matching the specified keywords" do
      expect(tool.invoke).to eq({ branch: "main", matching_files: tree_paths })
    end

    it "handles missing branches gracefully" do
      stub_request(
        :get,
        "https://api.github.com/repos/#{repo_name}/git/trees/non_existing_branch?recursive=1",
      ).to_return(status: 404, body: "", headers: { "Content-Type" => "application/json" })

      outcome = build_tool("non_existing_branch").invoke

      expect(outcome[:matching_files]).to be_nil
      expect(outcome[:error]).to eq("Failed to perform file search. Status code: 404")
    end

    it "fetches the default branch if none is specified" do
      outcome = tool.invoke

      expect(outcome[:matching_files]).to match_array(tree_paths)
      expect(outcome[:error]).to be_nil
    end
  end
end
|
|
@ -24,15 +24,26 @@ RSpec.describe DiscourseAi::AiBot::Tools::Read do
|
||||||
Fabricate(:topic, category: category, tags: [tag_funny, tag_sad, tag_hidden])
|
Fabricate(:topic, category: category, tags: [tag_funny, tag_sad, tag_hidden])
|
||||||
end
|
end
|
||||||
|
|
||||||
|
fab!(:post1) { Fabricate(:post, topic: topic_with_tags, raw: "hello there") }
|
||||||
|
fab!(:post2) { Fabricate(:post, topic: topic_with_tags, raw: "mister sam") }
|
||||||
|
|
||||||
before { SiteSetting.ai_bot_enabled = true }
|
before { SiteSetting.ai_bot_enabled = true }
|
||||||
|
|
||||||
describe "#process" do
|
describe "#process" do
|
||||||
|
it "can read specific posts" do
|
||||||
|
tool =
|
||||||
|
described_class.new(
|
||||||
|
{ topic_id: topic_with_tags.id, post_numbers: [post1.post_number] },
|
||||||
|
bot_user: bot_user,
|
||||||
|
llm: llm,
|
||||||
|
)
|
||||||
|
results = tool.invoke
|
||||||
|
|
||||||
|
expect(results[:content]).to include("hello there")
|
||||||
|
expect(results[:content]).not_to include("mister sam")
|
||||||
|
end
|
||||||
it "can read a topic" do
|
it "can read a topic" do
|
||||||
topic_id = topic_with_tags.id
|
topic_id = topic_with_tags.id
|
||||||
|
|
||||||
Fabricate(:post, topic: topic_with_tags, raw: "hello there")
|
|
||||||
Fabricate(:post, topic: topic_with_tags, raw: "mister sam")
|
|
||||||
|
|
||||||
results = tool.invoke
|
results = tool.invoke
|
||||||
|
|
||||||
expect(results[:topic_id]).to eq(topic_id)
|
expect(results[:topic_id]).to eq(topic_id)
|
||||||
|
|
Loading…
Reference in New Issue