diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml
index 9f0660a9..4ecf5f2c 100644
--- a/config/locales/server.en.yml
+++ b/config/locales/server.en.yml
@@ -96,6 +96,7 @@ en:
ai_bot_allowed_groups: "When the GPT Bot has access to the PM, it will reply to members of these groups."
ai_bot_enabled_chat_bots: "Available models to act as an AI Bot"
ai_bot_add_to_header: "Display a button in the header to start a PM with a AI Bot"
+ ai_bot_github_access_token: "GitHub access token for use with GitHub AI tools (required for search support)"
ai_stability_api_key: "API key for the stability.ai API"
ai_stability_engine: "Image generation engine to use for the stability.ai API"
@@ -154,6 +155,9 @@ en:
personas:
cannot_delete_system_persona: "System personas cannot be deleted, please disable it instead"
cannot_edit_system_persona: "System personas can only be renamed, you may not edit commands or system prompt, instead disable and make a copy"
+ github_helper:
+ name: "GitHub Helper"
+ description: "AI Bot specialized in assisting with GitHub-related tasks and questions"
general:
name: Forum Helper
description: "General purpose AI Bot capable of performing various tasks"
@@ -190,6 +194,9 @@ en:
name: "Base Search Query"
description: "Base query to use when searching. Example: '#urgent' will prepend '#urgent' to the search query and only include topics with the urgent category or tag."
command_summary:
+ github_search_code: "GitHub code search"
+ github_file_content: "GitHub file content"
+ github_pull_request_diff: "GitHub pull request diff"
random_picker: "Random Picker"
categories: "List categories"
search: "Search"
@@ -205,6 +212,9 @@ en:
dall_e: "Generate image"
search_meta_discourse: "Search Meta Discourse"
command_help:
+ github_search_code: "Search for code in a GitHub repository"
+ github_file_content: "Retrieve content of files from a GitHub repository"
+ github_pull_request_diff: "Retrieve a GitHub pull request diff"
random_picker: "Pick a random number or a random element of a list"
categories: "List all publicly visible categories on the forum"
search: "Search all public topics on the forum"
@@ -220,6 +230,9 @@ en:
dall_e: "Generate image using DALL-E 3"
search_meta_discourse: "Search Meta Discourse"
command_description:
+ github_search_code: "Searched for '%{query}' in %{repo}"
+ github_pull_request_diff: "%{repo} %{pull_id}"
+ github_file_content: "Retrieved content of %{file_paths} from %{repo_name}@%{branch}"
random_picker: "Picking from %{options}, picked: %{result}"
read: "Reading: %{title}"
time: "Time in %{timezone} is %{time}"
diff --git a/config/settings.yml b/config/settings.yml
index e3ff88b9..90d47d2b 100644
--- a/config/settings.yml
+++ b/config/settings.yml
@@ -318,3 +318,6 @@ discourse_ai:
ai_bot_add_to_header:
default: true
client: true
+ ai_bot_github_access_token:
+ default: ""
+ secret: true
diff --git a/lib/ai_bot/entry_point.rb b/lib/ai_bot/entry_point.rb
index 29cfe063..7ac9e355 100644
--- a/lib/ai_bot/entry_point.rb
+++ b/lib/ai_bot/entry_point.rb
@@ -2,6 +2,8 @@
module DiscourseAi
module AiBot
+ USER_AGENT = "Discourse AI Bot 1.0 (https://www.discourse.org)"
+
class EntryPoint
REQUIRE_TITLE_UPDATE = "discourse-ai-title-update"
diff --git a/lib/ai_bot/personas/github_helper.rb b/lib/ai_bot/personas/github_helper.rb
new file mode 100644
index 00000000..09e71b4b
--- /dev/null
+++ b/lib/ai_bot/personas/github_helper.rb
@@ -0,0 +1,24 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module AiBot
+ module Personas
+ class GithubHelper < Persona
+ def tools
+ [Tools::GithubFileContent, Tools::GithubPullRequestDiff, Tools::GithubSearchCode]
+ end
+
+ def system_prompt
+ <<~PROMPT
+ You are a helpful GitHub assistant.
+ You _understand_ and **generate** Discourse Flavored Markdown.
+ You live in a Discourse Forum Message.
+
+ Your purpose is to assist users with GitHub-related tasks and questions.
+ When asked about a specific repository, pull request, or file, try to use the available tools to provide accurate and helpful information.
+ PROMPT
+ end
+ end
+ end
+ end
+end
diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb
index 81e8d7a1..f29cc6b3 100644
--- a/lib/ai_bot/personas/persona.rb
+++ b/lib/ai_bot/personas/persona.rb
@@ -15,6 +15,7 @@ module DiscourseAi
Personas::Creative => -6,
Personas::DallE3 => -7,
Personas::DiscourseHelper => -8,
+ Personas::GithubHelper => -9,
}
end
@@ -64,8 +65,12 @@ module DiscourseAi
Tools::SettingContext,
Tools::RandomPicker,
Tools::DiscourseMetaSearch,
+ Tools::GithubFileContent,
+ Tools::GithubPullRequestDiff,
]
+ tools << Tools::GithubSearchCode if SiteSetting.ai_bot_github_access_token.present?
+
tools << Tools::ListTags if SiteSetting.tagging_enabled
tools << Tools::Image if SiteSetting.ai_stability_api_key.present?
@@ -162,7 +167,7 @@ module DiscourseAi
tool_klass.new(
arguments,
- tool_call_id: function_id,
+ tool_call_id: function_id || function_name,
persona_options: options[tool_klass].to_h,
)
end
diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb
index 96ba9b0e..e546f328 100644
--- a/lib/ai_bot/playground.rb
+++ b/lib/ai_bot/playground.rb
@@ -276,6 +276,14 @@ module DiscourseAi
publish_final_update(reply_post) if stream_reply
end
+ def available_bot_usernames
+ @bot_usernames ||=
+ AiPersona
+ .joins(:user)
+ .pluck(:username)
+ .concat(DiscourseAi::AiBot::EntryPoint::BOTS.map(&:second))
+ end
+
private
def publish_final_update(reply_post)
@@ -348,10 +356,6 @@ module DiscourseAi
max_backlog_age: 60,
)
end
-
- def available_bot_usernames
- @bot_usernames ||= DiscourseAi::AiBot::EntryPoint::BOTS.map(&:second)
- end
end
end
end
diff --git a/lib/ai_bot/tools/github_file_content.rb b/lib/ai_bot/tools/github_file_content.rb
new file mode 100644
index 00000000..e5ed76ff
--- /dev/null
+++ b/lib/ai_bot/tools/github_file_content.rb
@@ -0,0 +1,97 @@
+# frozen_string_literal: true
+module DiscourseAi
+ module AiBot
+ module Tools
+ class GithubFileContent < Tool
+ def self.signature
+ {
+ name: name,
+ description: "Retrieves the content of specified GitHub files",
+ parameters: [
+ {
+ name: "repo_name",
+ description: "The name of the GitHub repository (e.g., 'discourse/discourse')",
+ type: "string",
+ required: true,
+ },
+ {
+ name: "file_paths",
+ description: "The paths of the files to retrieve within the repository",
+ type: "array",
+ item_type: "string",
+ required: true,
+ },
+ {
+ name: "branch",
+ description:
+ "The branch or commit SHA to retrieve the files from (default: 'main')",
+ type: "string",
+ required: false,
+ },
+ ],
+ }
+ end
+
+ def self.name
+ "github_file_content"
+ end
+
+ def repo_name
+ parameters[:repo_name]
+ end
+
+ def file_paths
+ parameters[:file_paths]
+ end
+
+ def branch
+ parameters[:branch] || "main"
+ end
+
+ def description_args
+ { repo_name: repo_name, file_paths: file_paths.join(", "), branch: branch }
+ end
+
+ def invoke(_bot_user, llm)
+ owner, repo = repo_name.split("/")
+ file_contents = {}
+ missing_files = []
+
+ file_paths.each do |file_path|
+ api_url =
+ "https://api.github.com/repos/#{owner}/#{repo}/contents/#{file_path}?ref=#{branch}"
+
+ response =
+ send_http_request(
+ api_url,
+ headers: {
+ "Accept" => "application/vnd.github.v3+json",
+ },
+ authenticate_github: true,
+ )
+
+ if response.code == "200"
+ file_data = JSON.parse(response.body)
+ content = Base64.decode64(file_data["content"])
+ file_contents[file_path] = content
+ else
+ missing_files << file_path
+ end
+ end
+
+ result = {}
+ unless file_contents.empty?
+ blob =
+ file_contents.map { |path, content| "File Path: #{path}:\n#{content}" }.join("\n")
+ truncated_blob = truncate(blob, max_length: 20_000, percent_length: 0.3, llm: llm)
+ result[:file_contents] = truncated_blob
+ end
+
+ result[:missing_files] = missing_files unless missing_files.empty?
+
+ result.empty? ? { error: "No files found or retrieved." } : result
+ end
+ end
+ end
+ end
+end
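
For reference, a minimal usage sketch of the new tool outside the bot pipeline; the repo, file paths, and LLM identifier are illustrative, and the tool takes its arguments as a plain hash, exactly as the specs further down do:

    # Sketch only: exercising GithubFileContent directly (e.g. from a Rails console).
    tool =
      DiscourseAi::AiBot::Tools::GithubFileContent.new(
        { repo_name: "discourse/discourse-ai", file_paths: ["plugin.rb"], branch: "main" },
      )

    llm = DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4")
    tool.invoke(nil, llm)
    # => { file_contents: "File Path: plugin.rb:\n..." }
    #    or { error: "No files found or retrieved." } when nothing could be fetched
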
diff --git a/lib/ai_bot/tools/github_pull_request_diff.rb b/lib/ai_bot/tools/github_pull_request_diff.rb
new file mode 100644
index 00000000..6fb57e03
--- /dev/null
+++ b/lib/ai_bot/tools/github_pull_request_diff.rb
@@ -0,0 +1,72 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module AiBot
+ module Tools
+ class GithubPullRequestDiff < Tool
+ def self.signature
+ {
+ name: name,
+ description: "Retrieves the diff for a GitHub pull request",
+ parameters: [
+ {
+ name: "repo",
+ description: "The repository name in the format 'owner/repo'",
+ type: "string",
+ required: true,
+ },
+ {
+ name: "pull_id",
+ description: "The ID of the pull request",
+ type: "integer",
+ required: true,
+ },
+ ],
+ }
+ end
+
+ def self.name
+ "github_pull_request_diff"
+ end
+
+ def repo
+ parameters[:repo]
+ end
+
+ def pull_id
+ parameters[:pull_id]
+ end
+
+ def url
+ @url
+ end
+
+ def invoke(_bot_user, llm)
+ api_url = "https://api.github.com/repos/#{repo}/pulls/#{pull_id}"
+ @url = "https://github.com/#{repo}/pull/#{pull_id}"
+
+ response =
+ send_http_request(
+ api_url,
+ headers: {
+ "Accept" => "application/vnd.github.v3.diff",
+ },
+ authenticate_github: true,
+ )
+
+ if response.code == "200"
+ diff = response.body
+ diff = truncate(diff, max_length: 20_000, percent_length: 0.3, llm: llm)
+ { diff: diff }
+ else
+ { error: "Failed to retrieve the diff. Status code: #{response.code}" }
+ end
+ end
+
+ def description_args
+ { repo: repo, pull_id: pull_id, url: url }
+ end
+ end
+ end
+ end
+end
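
The Accept header application/vnd.github.v3.diff asks GitHub to return the raw unified diff instead of the JSON pull request resource, so the body can be truncated and handed to the model as-is. A minimal usage sketch (repo and pull_id are placeholders):

    # Sketch only; repo and pull_id are illustrative.
    llm = DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4")
    tool =
      DiscourseAi::AiBot::Tools::GithubPullRequestDiff.new(
        { repo: "discourse/discourse-ai", pull_id: 123 },
      )

    tool.invoke(nil, llm)
    # => { diff: "diff --git ..." } on success
    #    or { error: "Failed to retrieve the diff. Status code: 404" }
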
diff --git a/lib/ai_bot/tools/github_search_code.rb b/lib/ai_bot/tools/github_search_code.rb
new file mode 100644
index 00000000..6133fa45
--- /dev/null
+++ b/lib/ai_bot/tools/github_search_code.rb
@@ -0,0 +1,72 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module AiBot
+ module Tools
+ class GithubSearchCode < Tool
+ def self.signature
+ {
+ name: name,
+ description: "Searches for code in a GitHub repository",
+ parameters: [
+ {
+ name: "repo",
+ description: "The repository name in the format 'owner/repo'",
+ type: "string",
+ required: true,
+ },
+ {
+ name: "query",
+ description: "The search query (e.g., a function name, variable, or code snippet)",
+ type: "string",
+ required: true,
+ },
+ ],
+ }
+ end
+
+ def self.name
+ "github_search_code"
+ end
+
+ def repo
+ parameters[:repo]
+ end
+
+ def query
+ parameters[:query]
+ end
+
+ def description_args
+ { repo: repo, query: query }
+ end
+
+ def invoke(_bot_user, llm)
+ api_url = "https://api.github.com/search/code?q=#{query}+repo:#{repo}"
+
+ response =
+ send_http_request(
+ api_url,
+ headers: {
+ "Accept" => "application/vnd.github.v3.text-match+json",
+ },
+ authenticate_github: true,
+ )
+
+ if response.code == "200"
+ search_data = JSON.parse(response.body)
+ results =
+ search_data["items"]
+ .map { |item| "#{item["path"]}:\n#{item["text_matches"][0]["fragment"]}" }
+ .join("\n---\n")
+
+ results = truncate(results, max_length: 20_000, percent_length: 0.3, llm: llm)
+ { search_results: results }
+ else
+ { error: "Failed to perform code search. Status code: #{response.code}" }
+ end
+ end
+ end
+ end
+ end
+end
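
GitHub's code search endpoint requires an authenticated request, which is why persona.rb above only registers this tool when ai_bot_github_access_token is present, while the file content and pull request diff tools also work anonymously against public repositories. A minimal usage sketch (token, repo, and query are placeholders):

    # Sketch only; the token, repo, and query are illustrative.
    SiteSetting.ai_bot_github_access_token = "ghp_example"

    llm = DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4")
    tool =
      DiscourseAi::AiBot::Tools::GithubSearchCode.new(
        { repo: "discourse/discourse", query: "def hello" },
      )

    tool.invoke(nil, llm)
    # => { search_results: "path/to/file.rb:\n<matching fragment>" }
    #    or { error: "Failed to perform code search. Status code: 403" }
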
diff --git a/lib/ai_bot/tools/tool.rb b/lib/ai_bot/tools/tool.rb
index e223c9a2..1f1d42dd 100644
--- a/lib/ai_bot/tools/tool.rb
+++ b/lib/ai_bot/tools/tool.rb
@@ -77,6 +77,35 @@ module DiscourseAi
protected
+ def send_http_request(url, headers: {}, authenticate_github: false)
+ uri = URI(url)
+ request = FinalDestination::HTTP::Get.new(uri)
+ request["User-Agent"] = DiscourseAi::AiBot::USER_AGENT
+ headers.each { |k, v| request[k] = v }
+          if authenticate_github && SiteSetting.ai_bot_github_access_token.present?
+ request["Authorization"] = "Bearer #{SiteSetting.ai_bot_github_access_token}"
+ end
+
+ FinalDestination::HTTP.start(uri.hostname, uri.port, use_ssl: uri.port != 80) do |http|
+ http.request(request)
+ end
+ end
+
+ def truncate(text, llm:, percent_length: nil, max_length: nil)
+ if !percent_length && !max_length
+ raise ArgumentError, "You must provide either percent_length or max_length"
+ end
+
+ target = llm.max_prompt_tokens
+ target = (target * percent_length).to_i if percent_length
+
+ if max_length
+ target = max_length if target > max_length
+ end
+
+ llm.tokenizer.truncate(text, target)
+ end
+
def accepted_options
[]
end
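
To make the truncate sizing concrete: the target is percent_length of the model's maximum prompt tokens, additionally capped by max_length when both are given. A small worked sketch (the 32,000-token budget is illustrative):

    # Sketch: truncate(text, max_length: 20_000, percent_length: 0.3, llm: llm)
    # with an LLM allowing 32_000 prompt tokens (illustrative figure) resolves to:
    target = (32_000 * 0.3).to_i        # => 9_600
    target = 20_000 if target > 20_000  # cap does not apply here
    # llm.tokenizer.truncate(text, target) then trims the text to at most 9_600 tokens.
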
diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb
index e4ab940f..a9d6a957 100644
--- a/lib/completions/dialects/dialect.rb
+++ b/lib/completions/dialects/dialect.rb
@@ -36,7 +36,8 @@ module DiscourseAi
def tool_preamble
<<~TEXT
In this environment you have access to a set of tools you can use to answer the user's question.
- You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
+ You may call them like this.
+
         <function_calls>
         <invoke>
         <tool_name>$TOOL_NAME</tool_name>
@@ -47,9 +48,12 @@ module DiscourseAi
- if a parameter type is an array, return a JSON array of values. For example:
+ If a parameter type is an array, return a JSON array of values. For example:
[1,"two",3.0]
+          Always wrap calls in <function_calls> tags.
+          You may call multiple functions via <invoke> in a single <function_calls> block.
+
Here are the tools available:
TEXT
end
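
For context, a call that follows the preamble's template would look like the snippet below; the tool and parameter names are borrowed from the GitHub search tool added in this change, and the tool_id element is filled in by normalize_function_ids! when the model omits it:

    # Sketch of a model response that the endpoints parse into a tool invocation.
    example_call = <<~XML
      <function_calls>
      <invoke>
      <tool_name>github_search_code</tool_name>
      <parameters>
      <repo>discourse/discourse</repo>
      <query>def hello</query>
      </parameters>
      </invoke>
      </function_calls>
    XML
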
diff --git a/lib/completions/endpoints/anthropic_messages.rb b/lib/completions/endpoints/anthropic_messages.rb
index 74e71e1f..7485e0e6 100644
--- a/lib/completions/endpoints/anthropic_messages.rb
+++ b/lib/completions/endpoints/anthropic_messages.rb
@@ -27,8 +27,11 @@ module DiscourseAi
model_params
end
- def default_options
-        { model: model + "-20240229", max_tokens: 3_000, stop_sequences: ["</function_calls>"] }
+ def default_options(dialect)
+ options = { model: model + "-20240229", max_tokens: 3_000 }
+
+          options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
+ options
end
def provider_id
@@ -46,8 +49,8 @@ module DiscourseAi
@uri ||= URI("https://api.anthropic.com/v1/messages")
end
- def prepare_payload(prompt, model_params, _dialect)
- payload = default_options.merge(model_params).merge(messages: prompt.messages)
+ def prepare_payload(prompt, model_params, dialect)
+ payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
payload[:stream] = true if @streaming_mode
diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index 39b80d01..ec813cab 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -64,6 +64,7 @@ module DiscourseAi
end
def perform_completion!(dialect, user, model_params = {})
+ allow_tools = dialect.prompt.has_tools?
model_params = normalize_model_params(model_params)
@streaming_mode = block_given?
@@ -110,9 +111,11 @@ module DiscourseAi
response_data = extract_completion_from(response_raw)
partials_raw = response_data.to_s
- if has_tool?(response_data)
+ if allow_tools && has_tool?(response_data)
function_buffer = build_buffer # Nokogiri document
- function_buffer = add_to_buffer(function_buffer, "", response_data)
+ function_buffer = add_to_function_buffer(function_buffer, payload: response_data)
+
+ normalize_function_ids!(function_buffer)
response_data = +function_buffer.at("function_calls").to_s
response_data << "\n"
@@ -156,6 +159,7 @@ module DiscourseAi
end
json_error = false
+ buffered_partials = []
raw_partials.each do |raw_partial|
json_error = false
@@ -173,14 +177,30 @@ module DiscourseAi
# Stop streaming the response as soon as you find a tool.
# We'll buffer and yield it later.
- has_tool = true if has_tool?(partials_raw)
+ has_tool = true if allow_tools && has_tool?(partials_raw)
if has_tool
- function_buffer = add_to_buffer(function_buffer, partials_raw, partial)
+ if buffered_partials.present?
+ joined = buffered_partials.join
+ joined = joined.gsub(/<.+/, "")
+ yield joined, cancel if joined.present?
+ buffered_partials = []
+ end
+ function_buffer = add_to_function_buffer(function_buffer, partial: partial)
else
- response_data << partial
-
- yield partial, cancel if partial
+ if maybe_has_tool?(partials_raw)
+ buffered_partials << partial
+ else
+ if buffered_partials.present?
+ buffered_partials.each do |buffered_partial|
+ response_data << buffered_partial
+ yield buffered_partial, cancel
+ end
+ buffered_partials = []
+ end
+ response_data << partial
+ yield partial, cancel if partial
+ end
end
rescue JSON::ParserError
leftover = redo_chunk
@@ -201,7 +221,11 @@ module DiscourseAi
# Once we have the full response, try to return the tool as a XML doc.
if has_tool
+ function_buffer = add_to_function_buffer(function_buffer, payload: partials_raw)
+
if function_buffer.at("tool_name").text.present?
+ normalize_function_ids!(function_buffer)
+
invocation = +function_buffer.at("function_calls").to_s
invocation << "\n"
@@ -226,6 +250,21 @@ module DiscourseAi
end
end
+ def normalize_function_ids!(function_buffer)
+ function_buffer
+ .css("invoke")
+ .each_with_index do |invoke, index|
+ if invoke.at("tool_id")
+ invoke.at("tool_id").content = "tool_#{index}" if invoke
+ .at("tool_id")
+ .content
+ .blank?
+ else
+            invoke.add_child("<tool_id>tool_#{index}</tool_id>\n") if !invoke.at("tool_id")
+ end
+ end
+ end
+
def final_log_update(log)
# for people that need to override
end
@@ -291,48 +330,36 @@ module DiscourseAi
(<<~TEXT).strip
-
+
TEXT
end
def has_tool?(response)
- response.include?("")
end
- def add_to_buffer(function_buffer, response_data, partial)
- raw_data = (response_data + partial)
+ def maybe_has_tool?(response)
+          # 16 is the length of <function_calls>
+ substring = response[-16..-1] || response
+ split = substring.split("<")
-            # recover stop word potentially
-            raw_data =
-              raw_data.split("</function_calls>").first + "</function_calls>\n" if raw_data.split(
-                "</function_calls>",
-              ).length > 1
-
-            return function_buffer unless raw_data.include?("</invoke>")
-
- read_function = Nokogiri::HTML5.fragment(raw_data)
-
- if tool_name = read_function.at("tool_name")&.text
- function_buffer.at("tool_name").inner_html = tool_name
- function_buffer.at("tool_id").inner_html = tool_name
+ if split.length > 1
+ match = "<" + split.last
+            "<function_calls>".start_with?(match)
+ else
+ false
end
+ end
- read_function
- .at("parameters")
- &.elements
- .to_a
- .each do |elem|
- if parameter = function_buffer.at(elem.name)&.text
- function_buffer.at(elem.name).inner_html = parameter
- else
- param_node = read_function.at(elem.name)
- function_buffer.at("parameters").add_child(param_node)
- function_buffer.at("parameters").add_child("\n")
- end
- end
+ def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
+          if payload&.include?("</function_calls>")
+            matches = payload.match(%r{<function_calls>.*</function_calls>}m)
+ function_buffer =
+ Nokogiri::HTML5.fragment(matches[0] + "\n") if matches
+ end
function_buffer
end
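
A sketch of what normalize_function_ids! does to a parsed buffer: invocations that arrive without a tool_id (or with a blank one) get sequential ids, while ids supplied by the model are kept. Calling the helper directly here is for illustration only; in the endpoint it runs on the buffer built by add_to_function_buffer:

    # Sketch only; the buffer content is illustrative.
    buffer = Nokogiri::HTML5.fragment(<<~XML)
      <function_calls>
      <invoke><tool_name>echo</tool_name><parameters><text>a</text></parameters></invoke>
      <invoke><tool_name>echo</tool_name><parameters><text>b</text></parameters><tool_id>my_id</tool_id></invoke>
      </function_calls>
    XML

    normalize_function_ids!(buffer)
    buffer.css("invoke tool_id").map(&:text)
    # => ["tool_0", "my_id"]
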
diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb
index 8668c658..86e189a9 100644
--- a/lib/completions/endpoints/gemini.rb
+++ b/lib/completions/endpoints/gemini.rb
@@ -104,12 +104,20 @@ module DiscourseAi
@has_function_call
end
- def add_to_buffer(function_buffer, _response_data, partial)
- if partial[:name].present?
- function_buffer.at("tool_name").content = partial[:name]
- function_buffer.at("tool_id").content = partial[:name]
+ def maybe_has_tool?(_partial_raw)
+ # we always get a full partial
+ false
+ end
+
+ def add_to_function_buffer(function_buffer, payload: nil, partial: nil)
+ if @streaming_mode
+ return function_buffer if !partial
+ else
+ partial = payload
end
+ function_buffer.at("tool_name").content = partial[:name] if partial[:name].present?
+
if partial[:args]
argument_fragments =
partial[:args].reduce(+"") do |memo, (arg_name, value)|
diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb
index e0027324..029384bb 100644
--- a/lib/completions/endpoints/open_ai.rb
+++ b/lib/completions/endpoints/open_ai.rb
@@ -163,7 +163,18 @@ module DiscourseAi
@has_function_call
end
- def add_to_buffer(function_buffer, _response_data, partial)
+ def maybe_has_tool?(_partial_raw)
+ # we always get a full partial
+ false
+ end
+
+ def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
+ if @streaming_mode
+ return function_buffer if !partial
+ else
+ partial = payload
+ end
+
@args_buffer ||= +""
f_name = partial.dig(:function, :name)
diff --git a/lib/completions/prompt.rb b/lib/completions/prompt.rb
index d356a238..959b3d14 100644
--- a/lib/completions/prompt.rb
+++ b/lib/completions/prompt.rb
@@ -49,6 +49,10 @@ module DiscourseAi
messages << new_message
end
+ def has_tools?
+ tools.present?
+ end
+
private
def validate_message(message)
diff --git a/spec/lib/completions/endpoints/anthropic_messages_spec.rb b/spec/lib/completions/endpoints/anthropic_messages_spec.rb
index 8f17a85f..1d52e46c 100644
--- a/spec/lib/completions/endpoints/anthropic_messages_spec.rb
+++ b/spec/lib/completions/endpoints/anthropic_messages_spec.rb
@@ -10,6 +10,36 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AnthropicMessages do
)
end
+ let(:echo_tool) do
+ {
+ name: "echo",
+ description: "echo something",
+ parameters: [{ name: "text", type: "string", description: "text to echo", required: true }],
+ }
+ end
+
+ let(:google_tool) do
+ {
+ name: "google",
+ description: "google something",
+ parameters: [
+ { name: "query", type: "string", description: "text to google", required: true },
+ ],
+ }
+ end
+
+ let(:prompt_with_echo_tool) do
+ prompt_with_tools = prompt
+ prompt.tools = [echo_tool]
+ prompt_with_tools
+ end
+
+ let(:prompt_with_google_tool) do
+ prompt_with_tools = prompt
+    prompt.tools = [google_tool]
+ prompt_with_tools
+ end
+
before { SiteSetting.ai_anthropic_api_key = "123" }
it "does not eat spaces with tool calls" do
@@ -165,16 +195,18 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AnthropicMessages do
stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(status: 200, body: body)
result = +""
- llm.generate(prompt, user: Discourse.system_user) { |partial| result << partial }
+ llm.generate(prompt_with_google_tool, user: Discourse.system_user) do |partial|
+ result << partial
+ end
       expected = (<<~TEXT).strip
         <function_calls>
         <invoke>
         <tool_name>google</tool_name>
-        <tool_id>google</tool_id>
         <parameters>
         <query>top 10 things to do in japan for tourists</query>
         </parameters>
+        <tool_id>tool_0</tool_id>
         </invoke>
         </function_calls>
       TEXT
@@ -232,7 +264,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AnthropicMessages do
expected_body = {
model: "claude-3-opus-20240229",
max_tokens: 3000,
-          stop_sequences: ["</function_calls>"],
messages: [{ role: "user", content: "user1: hello" }],
system: "You are hello bot",
stream: true,
@@ -245,6 +276,70 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AnthropicMessages do
expect(log.response_tokens).to eq(15)
end
+ it "can return multiple function calls" do
+      functions = <<~FUNCTIONS
+        <function_calls>
+        <invoke>
+        <tool_name>echo</tool_name>
+        <parameters>
+        <text>something</text>
+        </parameters>
+        </invoke>
+        <invoke>
+        <tool_name>echo</tool_name>
+        <parameters>
+        <text>something else</text>
+        </parameters>
+        </invoke>
+        </function_calls>
+      FUNCTIONS
+
+ body = <<~STRING
+ {
+ "content": [
+ {
+ "text": "Hello!\n\n#{functions}\njunk",
+ "type": "text"
+ }
+ ],
+ "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
+ "model": "claude-3-opus-20240229",
+ "role": "assistant",
+ "stop_reason": "end_turn",
+ "stop_sequence": null,
+ "type": "message",
+ "usage": {
+ "input_tokens": 10,
+ "output_tokens": 25
+ }
+ }
+ STRING
+
+ stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(status: 200, body: body)
+
+ result = llm.generate(prompt_with_echo_tool, user: Discourse.system_user)
+
+    expected = (<<~EXPECTED).strip
+      <function_calls>
+      <invoke>
+      <tool_name>echo</tool_name>
+      <parameters>
+      <text>something</text>
+      </parameters>
+      <tool_id>tool_0</tool_id>
+      </invoke>
+      <invoke>
+      <tool_name>echo</tool_name>
+      <parameters>
+      <text>something else</text>
+      </parameters>
+      <tool_id>tool_1</tool_id>
+      </invoke>
+      </function_calls>
+    EXPECTED
+
+ expect(result.strip).to eq(expected)
+ end
+
it "can operate in regular mode" do
body = <<~STRING
{
@@ -287,7 +382,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AnthropicMessages do
expected_body = {
model: "claude-3-opus-20240229",
max_tokens: 3000,
-          stop_sequences: ["</function_calls>"],
messages: [{ role: "user", content: "user1: hello" }],
system: "You are hello bot",
}
diff --git a/spec/lib/completions/endpoints/endpoint_compliance.rb b/spec/lib/completions/endpoints/endpoint_compliance.rb
index fe0e44fd..31f5e86e 100644
--- a/spec/lib/completions/endpoints/endpoint_compliance.rb
+++ b/spec/lib/completions/endpoints/endpoint_compliance.rb
@@ -66,11 +66,11 @@ class EndpointMock
          <function_calls>
          <invoke>
          <tool_name>get_weather</tool_name>
-          <tool_id>#{tool_id}</tool_id>
          <parameters>
          <location>Sydney</location>
          <unit>c</unit>
          </parameters>
+          <tool_id>tool_0</tool_id>
          </invoke>
          </function_calls>
        TEXT
diff --git a/spec/lib/completions/endpoints/gemini_spec.rb b/spec/lib/completions/endpoints/gemini_spec.rb
index 8f9cb367..dbc5fe31 100644
--- a/spec/lib/completions/endpoints/gemini_spec.rb
+++ b/spec/lib/completions/endpoints/gemini_spec.rb
@@ -132,7 +132,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
fab!(:user)
- let(:bedrock_mock) { GeminiMock.new(endpoint) }
+ let(:gemini_mock) { GeminiMock.new(endpoint) }
let(:compliance) do
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Gemini, user)
@@ -142,13 +142,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
context "when using regular mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
- compliance.regular_mode_simple_prompt(bedrock_mock)
+ compliance.regular_mode_simple_prompt(gemini_mock)
end
end
context "with tools" do
it "returns a function invocation" do
- compliance.regular_mode_tools(bedrock_mock)
+ compliance.regular_mode_tools(gemini_mock)
end
end
end
@@ -156,13 +156,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
describe "when using streaming mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
- compliance.streaming_mode_simple_prompt(bedrock_mock)
+ compliance.streaming_mode_simple_prompt(gemini_mock)
end
end
context "with tools" do
it "returns a function invocation" do
- compliance.streaming_mode_tools(bedrock_mock)
+ compliance.streaming_mode_tools(gemini_mock)
end
end
end
diff --git a/spec/lib/completions/endpoints/open_ai_spec.rb b/spec/lib/completions/endpoints/open_ai_spec.rb
index a83ed0a6..fe662ae2 100644
--- a/spec/lib/completions/endpoints/open_ai_spec.rb
+++ b/spec/lib/completions/endpoints/open_ai_spec.rb
@@ -102,7 +102,7 @@ class OpenAiMock < EndpointMock
end
def tool_id
- "eujbuebfe"
+ "tool_0"
end
def tool_payload
@@ -149,6 +149,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
fab!(:user)
+ let(:echo_tool) do
+ {
+ name: "echo",
+ description: "echo something",
+ parameters: [{ name: "text", type: "string", description: "text to echo", required: true }],
+ }
+ end
+
+ let(:tools) { [echo_tool] }
+
let(:open_ai_mock) { OpenAiMock.new(endpoint) }
let(:compliance) do
@@ -260,23 +270,25 @@ TEXT
open_ai_mock.stub_raw(raw_data)
content = +""
- endpoint.perform_completion!(compliance.dialect, user) { |partial| content << partial }
+ dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))
+
+ endpoint.perform_completion!(dialect, user) { |partial| content << partial }
       expected = <<~TEXT
         <function_calls>
         <invoke>
         <tool_name>search</tool_name>
-        <tool_id>call_3Gyr3HylFJwfrtKrL6NaIit1</tool_id>
         <parameters>
         <search_query>Discourse AI bot</search_query>
         </parameters>
+        <tool_id>call_3Gyr3HylFJwfrtKrL6NaIit1</tool_id>
         </invoke>
         <invoke>
         <tool_name>search</tool_name>
-        <tool_id>call_H7YkbgYurHpyJqzwUN4bghwN</tool_id>
         <parameters>
         <search_query>Discourse AI bot</search_query>
         </parameters>
+        <tool_id>call_H7YkbgYurHpyJqzwUN4bghwN</tool_id>
         </invoke>
         </function_calls>
       TEXT
@@ -321,7 +333,8 @@ TEXT
open_ai_mock.stub_raw(chunks)
partials = []
- endpoint.perform_completion!(compliance.dialect, user) { |partial| partials << partial }
+ dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))
+ endpoint.perform_completion!(dialect, user) { |partial| partials << partial }
expect(partials.length).to eq(1)
@@ -329,10 +342,10 @@ TEXT
         <function_calls>
         <invoke>
         <tool_name>google</tool_name>
-        <tool_id>func_id</tool_id>
         <parameters>
         <query>Adabas 9.1</query>
         </parameters>
+        <tool_id>func_id</tool_id>
         </invoke>
         </function_calls>
       TXT
diff --git a/spec/lib/modules/ai_bot/personas/persona_spec.rb b/spec/lib/modules/ai_bot/personas/persona_spec.rb
index 0e2cf504..a60a6b16 100644
--- a/spec/lib/modules/ai_bot/personas/persona_spec.rb
+++ b/spec/lib/modules/ai_bot/personas/persona_spec.rb
@@ -153,6 +153,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
DiscourseAi::AiBot::Personas::Artist,
DiscourseAi::AiBot::Personas::Creative,
DiscourseAi::AiBot::Personas::DiscourseHelper,
+ DiscourseAi::AiBot::Personas::GithubHelper,
DiscourseAi::AiBot::Personas::Researcher,
DiscourseAi::AiBot::Personas::SettingsExplorer,
DiscourseAi::AiBot::Personas::SqlHelper,
@@ -169,6 +170,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
DiscourseAi::AiBot::Personas::SettingsExplorer,
DiscourseAi::AiBot::Personas::Creative,
DiscourseAi::AiBot::Personas::DiscourseHelper,
+ DiscourseAi::AiBot::Personas::GithubHelper,
)
AiPersona.find(
@@ -182,6 +184,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
DiscourseAi::AiBot::Personas::SettingsExplorer,
DiscourseAi::AiBot::Personas::Creative,
DiscourseAi::AiBot::Personas::DiscourseHelper,
+ DiscourseAi::AiBot::Personas::GithubHelper,
)
end
end
diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb
index f9d444de..d6524108 100644
--- a/spec/lib/modules/ai_bot/playground_spec.rb
+++ b/spec/lib/modules/ai_bot/playground_spec.rb
@@ -337,6 +337,15 @@ RSpec.describe DiscourseAi::AiBot::Playground do
end
end
+ describe "#available_bot_usernames" do
+ it "includes persona users" do
+ persona = Fabricate(:ai_persona)
+ persona.create_user!
+
+ expect(playground.available_bot_usernames).to include(persona.user.username)
+ end
+ end
+
describe "#conversation_context" do
context "with limited context" do
before do
diff --git a/spec/lib/modules/ai_bot/tools/github_file_content_spec.rb b/spec/lib/modules/ai_bot/tools/github_file_content_spec.rb
new file mode 100644
index 00000000..aa2263bd
--- /dev/null
+++ b/spec/lib/modules/ai_bot/tools/github_file_content_spec.rb
@@ -0,0 +1,78 @@
+# frozen_string_literal: true
+
+require "rails_helper"
+
+RSpec.describe DiscourseAi::AiBot::Tools::GithubFileContent do
+ let(:tool) do
+ described_class.new(
+ {
+ repo_name: "discourse/discourse-ai",
+ file_paths: %w[lib/database/connection.rb lib/ai_bot/tools/github_pull_request_diff.rb],
+ branch: "8b382d6098fde879d28bbee68d3cbe0a193e4ffc",
+ },
+ )
+ end
+
+ let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }
+
+ describe "#invoke" do
+ before do
+ stub_request(
+ :get,
+ "https://api.github.com/repos/discourse/discourse-ai/contents/lib/database/connection.rb?ref=8b382d6098fde879d28bbee68d3cbe0a193e4ffc",
+ ).to_return(
+ status: 200,
+ body: { content: Base64.encode64("content of connection.rb") }.to_json,
+ )
+
+ stub_request(
+ :get,
+ "https://api.github.com/repos/discourse/discourse-ai/contents/lib/ai_bot/tools/github_pull_request_diff.rb?ref=8b382d6098fde879d28bbee68d3cbe0a193e4ffc",
+ ).to_return(
+ status: 200,
+ body: { content: Base64.encode64("content of github_pull_request_diff.rb") }.to_json,
+ )
+ end
+
+ it "retrieves the content of the specified GitHub files" do
+ result = tool.invoke(nil, llm)
+ expected = {
+ file_contents:
+ "File Path: lib/database/connection.rb:\ncontent of connection.rb\nFile Path: lib/ai_bot/tools/github_pull_request_diff.rb:\ncontent of github_pull_request_diff.rb",
+ }
+
+ expect(result).to eq(expected)
+ end
+ end
+
+ describe ".signature" do
+ it "returns the tool signature" do
+ signature = described_class.signature
+ expect(signature[:name]).to eq("github_file_content")
+ expect(signature[:description]).to eq("Retrieves the content of specified GitHub files")
+ expect(signature[:parameters]).to eq(
+ [
+ {
+ name: "repo_name",
+ description: "The name of the GitHub repository (e.g., 'discourse/discourse')",
+ type: "string",
+ required: true,
+ },
+ {
+ name: "file_paths",
+ description: "The paths of the files to retrieve within the repository",
+ type: "array",
+ item_type: "string",
+ required: true,
+ },
+ {
+ name: "branch",
+ description: "The branch or commit SHA to retrieve the files from (default: 'main')",
+ type: "string",
+ required: false,
+ },
+ ],
+ )
+ end
+ end
+end
diff --git a/spec/lib/modules/ai_bot/tools/github_pull_request_diff_spec.rb b/spec/lib/modules/ai_bot/tools/github_pull_request_diff_spec.rb
new file mode 100644
index 00000000..4260643c
--- /dev/null
+++ b/spec/lib/modules/ai_bot/tools/github_pull_request_diff_spec.rb
@@ -0,0 +1,61 @@
+# frozen_string_literal: true
+
+require "rails_helper"
+
+RSpec.describe DiscourseAi::AiBot::Tools::GithubPullRequestDiff do
+ let(:tool) { described_class.new({ repo: repo, pull_id: pull_id }) }
+ let(:bot_user) { Fabricate(:user) }
+ let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }
+
+ context "with a valid pull request" do
+ let(:repo) { "discourse/discourse-automation" }
+ let(:pull_id) { 253 }
+
+ it "retrieves the diff for the pull request" do
+ stub_request(:get, "https://api.github.com/repos/#{repo}/pulls/#{pull_id}").with(
+ headers: {
+ "Accept" => "application/vnd.github.v3.diff",
+ "User-Agent" => DiscourseAi::AiBot::USER_AGENT,
+ },
+ ).to_return(status: 200, body: "sample diff")
+
+ result = tool.invoke(bot_user, llm)
+ expect(result[:diff]).to eq("sample diff")
+ expect(result[:error]).to be_nil
+ end
+
+ it "uses the github access token if present" do
+ SiteSetting.ai_bot_github_access_token = "ABC"
+
+ stub_request(:get, "https://api.github.com/repos/#{repo}/pulls/#{pull_id}").with(
+ headers: {
+ "Accept" => "application/vnd.github.v3.diff",
+ "User-Agent" => DiscourseAi::AiBot::USER_AGENT,
+ "Authorization" => "Bearer ABC",
+ },
+ ).to_return(status: 200, body: "sample diff")
+
+ result = tool.invoke(bot_user, llm)
+ expect(result[:diff]).to eq("sample diff")
+ expect(result[:error]).to be_nil
+ end
+ end
+
+ context "with an invalid pull request" do
+ let(:repo) { "invalid/repo" }
+ let(:pull_id) { 999 }
+
+ it "returns an error message" do
+ stub_request(:get, "https://api.github.com/repos/#{repo}/pulls/#{pull_id}").with(
+ headers: {
+ "Accept" => "application/vnd.github.v3.diff",
+ "User-Agent" => DiscourseAi::AiBot::USER_AGENT,
+ },
+ ).to_return(status: 404)
+
+ result = tool.invoke(bot_user, nil)
+ expect(result[:diff]).to be_nil
+ expect(result[:error]).to include("Failed to retrieve the diff")
+ end
+ end
+end
diff --git a/spec/lib/modules/ai_bot/tools/github_search_code_spec.rb b/spec/lib/modules/ai_bot/tools/github_search_code_spec.rb
new file mode 100644
index 00000000..789ad8c2
--- /dev/null
+++ b/spec/lib/modules/ai_bot/tools/github_search_code_spec.rb
@@ -0,0 +1,93 @@
+# frozen_string_literal: true
+
+require "rails_helper"
+
+RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchCode do
+ let(:tool) { described_class.new({ repo: repo, query: query }) }
+ let(:bot_user) { Fabricate(:user) }
+ let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }
+
+ context "with valid search results" do
+ let(:repo) { "discourse/discourse" }
+ let(:query) { "def hello" }
+
+ it "searches for code in the repository" do
+ stub_request(
+ :get,
+ "https://api.github.com/search/code?q=def%20hello+repo:discourse/discourse",
+ ).with(
+ headers: {
+ "Accept" => "application/vnd.github.v3.text-match+json",
+ "User-Agent" => DiscourseAi::AiBot::USER_AGENT,
+ },
+ ).to_return(
+ status: 200,
+ body: {
+ total_count: 1,
+ items: [
+ {
+ path: "test/hello.rb",
+ name: "hello.rb",
+ text_matches: [{ fragment: "def hello\n puts 'hello'\nend" }],
+ },
+ ],
+ }.to_json,
+ )
+
+ result = tool.invoke(bot_user, llm)
+ expect(result[:search_results]).to include("def hello\n puts 'hello'\nend")
+ expect(result[:search_results]).to include("test/hello.rb")
+ expect(result[:error]).to be_nil
+ end
+ end
+
+ context "with an empty search result" do
+ let(:repo) { "discourse/discourse" }
+ let(:query) { "nonexistent_method" }
+
+ describe "#description_args" do
+ it "returns the repo and query" do
+ expect(tool.description_args).to eq(repo: repo, query: query)
+ end
+ end
+
+ it "returns an empty result" do
+ SiteSetting.ai_bot_github_access_token = "ABC"
+ stub_request(
+ :get,
+ "https://api.github.com/search/code?q=nonexistent_method+repo:discourse/discourse",
+ ).with(
+ headers: {
+ "Accept" => "application/vnd.github.v3.text-match+json",
+ "User-Agent" => DiscourseAi::AiBot::USER_AGENT,
+ "Authorization" => "Bearer ABC",
+ },
+ ).to_return(status: 200, body: { total_count: 0, items: [] }.to_json)
+
+ result = tool.invoke(bot_user, llm)
+ expect(result[:search_results]).to be_empty
+ expect(result[:error]).to be_nil
+ end
+ end
+
+ context "with an error response" do
+ let(:repo) { "discourse/discourse" }
+ let(:query) { "def hello" }
+
+ it "returns an error message" do
+ stub_request(
+ :get,
+ "https://api.github.com/search/code?q=def%20hello+repo:discourse/discourse",
+ ).with(
+ headers: {
+ "Accept" => "application/vnd.github.v3.text-match+json",
+ "User-Agent" => DiscourseAi::AiBot::USER_AGENT,
+ },
+ ).to_return(status: 403)
+
+ result = tool.invoke(bot_user, llm)
+ expect(result[:search_results]).to be_nil
+ expect(result[:error]).to include("Failed to perform code search")
+ end
+ end
+end