diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml index 9f0660a9..4ecf5f2c 100644 --- a/config/locales/server.en.yml +++ b/config/locales/server.en.yml @@ -96,6 +96,7 @@ en: ai_bot_allowed_groups: "When the GPT Bot has access to the PM, it will reply to members of these groups." ai_bot_enabled_chat_bots: "Available models to act as an AI Bot" ai_bot_add_to_header: "Display a button in the header to start a PM with a AI Bot" + ai_bot_github_access_token: "GitHub access token for use with GitHub AI tools (required for search support)" ai_stability_api_key: "API key for the stability.ai API" ai_stability_engine: "Image generation engine to use for the stability.ai API" @@ -154,6 +155,9 @@ en: personas: cannot_delete_system_persona: "System personas cannot be deleted, please disable it instead" cannot_edit_system_persona: "System personas can only be renamed, you may not edit commands or system prompt, instead disable and make a copy" + github_helper: + name: "GitHub Helper" + description: "AI Bot specialized in assisting with GitHub-related tasks and questions" general: name: Forum Helper description: "General purpose AI Bot capable of performing various tasks" @@ -190,6 +194,9 @@ en: name: "Base Search Query" description: "Base query to use when searching. Example: '#urgent' will prepend '#urgent' to the search query and only include topics with the urgent category or tag." command_summary: + github_search_code: "GitHub code search" + github_file_content: "GitHub file content" + github_pull_request_diff: "GitHub pull request diff" random_picker: "Random Picker" categories: "List categories" search: "Search" @@ -205,6 +212,9 @@ en: dall_e: "Generate image" search_meta_discourse: "Search Meta Discourse" command_help: + github_search_code: "Search for code in a GitHub repository" + github_file_content: "Retrieve content of files from a GitHub repository" + github_pull_request_diff: "Retrieve a GitHub pull request diff" random_picker: "Pick a random number or a random element of a list" categories: "List all publicly visible categories on the forum" search: "Search all public topics on the forum" @@ -220,6 +230,9 @@ en: dall_e: "Generate image using DALL-E 3" search_meta_discourse: "Search Meta Discourse" command_description: + github_search_code: "Searched for '%{query}' in %{repo}" + github_pull_request_diff: "%{repo} %{pull_id}" + github_file_content: "Retrieved content of %{file_paths} from %{repo_name}@%{branch}" random_picker: "Picking from %{options}, picked: %{result}" read: "Reading: %{title}" time: "Time in %{timezone} is %{time}" diff --git a/config/settings.yml b/config/settings.yml index e3ff88b9..90d47d2b 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -318,3 +318,6 @@ discourse_ai: ai_bot_add_to_header: default: true client: true + ai_bot_github_access_token: + default: "" + secret: true diff --git a/lib/ai_bot/entry_point.rb b/lib/ai_bot/entry_point.rb index 29cfe063..7ac9e355 100644 --- a/lib/ai_bot/entry_point.rb +++ b/lib/ai_bot/entry_point.rb @@ -2,6 +2,8 @@ module DiscourseAi module AiBot + USER_AGENT = "Discourse AI Bot 1.0 (https://www.discourse.org)" + class EntryPoint REQUIRE_TITLE_UPDATE = "discourse-ai-title-update" diff --git a/lib/ai_bot/personas/github_helper.rb b/lib/ai_bot/personas/github_helper.rb new file mode 100644 index 00000000..09e71b4b --- /dev/null +++ b/lib/ai_bot/personas/github_helper.rb @@ -0,0 +1,24 @@ +# frozen_string_literal: true + +module DiscourseAi + module AiBot + module Personas + class GithubHelper < 
Persona + def tools + [Tools::GithubFileContent, Tools::GithubPullRequestDiff, Tools::GithubSearchCode] + end + + def system_prompt + <<~PROMPT + You are a helpful GitHub assistant. + You _understand_ and **generate** Discourse Flavored Markdown. + You live in a Discourse Forum Message. + + Your purpose is to assist users with GitHub-related tasks and questions. + When asked about a specific repository, pull request, or file, try to use the available tools to provide accurate and helpful information. + PROMPT + end + end + end + end +end diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb index 81e8d7a1..f29cc6b3 100644 --- a/lib/ai_bot/personas/persona.rb +++ b/lib/ai_bot/personas/persona.rb @@ -15,6 +15,7 @@ module DiscourseAi Personas::Creative => -6, Personas::DallE3 => -7, Personas::DiscourseHelper => -8, + Personas::GithubHelper => -9, } end @@ -64,8 +65,12 @@ module DiscourseAi Tools::SettingContext, Tools::RandomPicker, Tools::DiscourseMetaSearch, + Tools::GithubFileContent, + Tools::GithubPullRequestDiff, ] + tools << Tools::GithubSearchCode if SiteSetting.ai_bot_github_access_token.present? + tools << Tools::ListTags if SiteSetting.tagging_enabled tools << Tools::Image if SiteSetting.ai_stability_api_key.present? @@ -162,7 +167,7 @@ module DiscourseAi tool_klass.new( arguments, - tool_call_id: function_id, + tool_call_id: function_id || function_name, persona_options: options[tool_klass].to_h, ) end diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb index 96ba9b0e..e546f328 100644 --- a/lib/ai_bot/playground.rb +++ b/lib/ai_bot/playground.rb @@ -276,6 +276,14 @@ module DiscourseAi publish_final_update(reply_post) if stream_reply end + def available_bot_usernames + @bot_usernames ||= + AiPersona + .joins(:user) + .pluck(:username) + .concat(DiscourseAi::AiBot::EntryPoint::BOTS.map(&:second)) + end + private def publish_final_update(reply_post) @@ -348,10 +356,6 @@ module DiscourseAi max_backlog_age: 60, ) end - - def available_bot_usernames - @bot_usernames ||= DiscourseAi::AiBot::EntryPoint::BOTS.map(&:second) - end end end end diff --git a/lib/ai_bot/tools/github_file_content.rb b/lib/ai_bot/tools/github_file_content.rb new file mode 100644 index 00000000..e5ed76ff --- /dev/null +++ b/lib/ai_bot/tools/github_file_content.rb @@ -0,0 +1,97 @@ +# frozen_string_literal: true +module DiscourseAi + module AiBot + module Tools + class GithubFileContent < Tool + def self.signature + { + name: name, + description: "Retrieves the content of specified GitHub files", + parameters: [ + { + name: "repo_name", + description: "The name of the GitHub repository (e.g., 'discourse/discourse')", + type: "string", + required: true, + }, + { + name: "file_paths", + description: "The paths of the files to retrieve within the repository", + type: "array", + item_type: "string", + required: true, + }, + { + name: "branch", + description: + "The branch or commit SHA to retrieve the files from (default: 'main')", + type: "string", + required: false, + }, + ], + } + end + + def self.name + "github_file_content" + end + + def repo_name + parameters[:repo_name] + end + + def file_paths + parameters[:file_paths] + end + + def branch + parameters[:branch] || "main" + end + + def description_args + { repo_name: repo_name, file_paths: file_paths.join(", "), branch: branch } + end + + def invoke(_bot_user, llm) + owner, repo = repo_name.split("/") + file_contents = {} + missing_files = [] + + file_paths.each do |file_path| + api_url = + 
"https://api.github.com/repos/#{owner}/#{repo}/contents/#{file_path}?ref=#{branch}" + + response = + send_http_request( + api_url, + headers: { + "Accept" => "application/vnd.github.v3+json", + }, + authenticate_github: true, + ) + + if response.code == "200" + file_data = JSON.parse(response.body) + content = Base64.decode64(file_data["content"]) + file_contents[file_path] = content + else + missing_files << file_path + end + end + + result = {} + unless file_contents.empty? + blob = + file_contents.map { |path, content| "File Path: #{path}:\n#{content}" }.join("\n") + truncated_blob = truncate(blob, max_length: 20_000, percent_length: 0.3, llm: llm) + result[:file_contents] = truncated_blob + end + + result[:missing_files] = missing_files unless missing_files.empty? + + result.empty? ? { error: "No files found or retrieved." } : result + end + end + end + end +end diff --git a/lib/ai_bot/tools/github_pull_request_diff.rb b/lib/ai_bot/tools/github_pull_request_diff.rb new file mode 100644 index 00000000..6fb57e03 --- /dev/null +++ b/lib/ai_bot/tools/github_pull_request_diff.rb @@ -0,0 +1,72 @@ +# frozen_string_literal: true + +module DiscourseAi + module AiBot + module Tools + class GithubPullRequestDiff < Tool + def self.signature + { + name: name, + description: "Retrieves the diff for a GitHub pull request", + parameters: [ + { + name: "repo", + description: "The repository name in the format 'owner/repo'", + type: "string", + required: true, + }, + { + name: "pull_id", + description: "The ID of the pull request", + type: "integer", + required: true, + }, + ], + } + end + + def self.name + "github_pull_request_diff" + end + + def repo + parameters[:repo] + end + + def pull_id + parameters[:pull_id] + end + + def url + @url + end + + def invoke(_bot_user, llm) + api_url = "https://api.github.com/repos/#{repo}/pulls/#{pull_id}" + @url = "https://github.com/#{repo}/pull/#{pull_id}" + + response = + send_http_request( + api_url, + headers: { + "Accept" => "application/vnd.github.v3.diff", + }, + authenticate_github: true, + ) + + if response.code == "200" + diff = response.body + diff = truncate(diff, max_length: 20_000, percent_length: 0.3, llm: llm) + { diff: diff } + else + { error: "Failed to retrieve the diff. 
Status code: #{response.code}" } + end + end + + def description_args + { repo: repo, pull_id: pull_id, url: url } + end + end + end + end +end diff --git a/lib/ai_bot/tools/github_search_code.rb b/lib/ai_bot/tools/github_search_code.rb new file mode 100644 index 00000000..6133fa45 --- /dev/null +++ b/lib/ai_bot/tools/github_search_code.rb @@ -0,0 +1,72 @@ +# frozen_string_literal: true + +module DiscourseAi + module AiBot + module Tools + class GithubSearchCode < Tool + def self.signature + { + name: name, + description: "Searches for code in a GitHub repository", + parameters: [ + { + name: "repo", + description: "The repository name in the format 'owner/repo'", + type: "string", + required: true, + }, + { + name: "query", + description: "The search query (e.g., a function name, variable, or code snippet)", + type: "string", + required: true, + }, + ], + } + end + + def self.name + "github_search_code" + end + + def repo + parameters[:repo] + end + + def query + parameters[:query] + end + + def description_args + { repo: repo, query: query } + end + + def invoke(_bot_user, llm) + api_url = "https://api.github.com/search/code?q=#{query}+repo:#{repo}" + + response = + send_http_request( + api_url, + headers: { + "Accept" => "application/vnd.github.v3.text-match+json", + }, + authenticate_github: true, + ) + + if response.code == "200" + search_data = JSON.parse(response.body) + results = + search_data["items"] + .map { |item| "#{item["path"]}:\n#{item["text_matches"][0]["fragment"]}" } + .join("\n---\n") + + results = truncate(results, max_length: 20_000, percent_length: 0.3, llm: llm) + { search_results: results } + else + { error: "Failed to perform code search. Status code: #{response.code}" } + end + end + end + end + end +end diff --git a/lib/ai_bot/tools/tool.rb b/lib/ai_bot/tools/tool.rb index e223c9a2..1f1d42dd 100644 --- a/lib/ai_bot/tools/tool.rb +++ b/lib/ai_bot/tools/tool.rb @@ -77,6 +77,35 @@ module DiscourseAi protected + def send_http_request(url, headers: {}, authenticate_github: false) + uri = URI(url) + request = FinalDestination::HTTP::Get.new(uri) + request["User-Agent"] = DiscourseAi::AiBot::USER_AGENT + headers.each { |k, v| request[k] = v } + if authenticate_github + request["Authorization"] = "Bearer #{SiteSetting.ai_bot_github_access_token}" + end + + FinalDestination::HTTP.start(uri.hostname, uri.port, use_ssl: uri.port != 80) do |http| + http.request(request) + end + end + + def truncate(text, llm:, percent_length: nil, max_length: nil) + if !percent_length && !max_length + raise ArgumentError, "You must provide either percent_length or max_length" + end + + target = llm.max_prompt_tokens + target = (target * percent_length).to_i if percent_length + + if max_length + target = max_length if target > max_length + end + + llm.tokenizer.truncate(text, target) + end + def accepted_options [] end diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb index e4ab940f..a9d6a957 100644 --- a/lib/completions/dialects/dialect.rb +++ b/lib/completions/dialects/dialect.rb @@ -36,7 +36,8 @@ module DiscourseAi def tool_preamble <<~TEXT In this environment you have access to a set of tools you can use to answer the user's question. - You may call them like this. Only invoke one function at a time and wait for the results before invoking another function: + You may call them like this. + $TOOL_NAME @@ -47,9 +48,12 @@ module DiscourseAi - if a parameter type is an array, return a JSON array of values. 
For example: + If a parameter type is an array, return a JSON array of values. For example: [1,"two",3.0] + Always wrap calls in tags. + You may call multiple function via in a single block. + Here are the tools available: TEXT end diff --git a/lib/completions/endpoints/anthropic_messages.rb b/lib/completions/endpoints/anthropic_messages.rb index 74e71e1f..7485e0e6 100644 --- a/lib/completions/endpoints/anthropic_messages.rb +++ b/lib/completions/endpoints/anthropic_messages.rb @@ -27,8 +27,11 @@ module DiscourseAi model_params end - def default_options - { model: model + "-20240229", max_tokens: 3_000, stop_sequences: [""] } + def default_options(dialect) + options = { model: model + "-20240229", max_tokens: 3_000 } + + options[:stop_sequences] = [""] if dialect.prompt.has_tools? + options end def provider_id @@ -46,8 +49,8 @@ module DiscourseAi @uri ||= URI("https://api.anthropic.com/v1/messages") end - def prepare_payload(prompt, model_params, _dialect) - payload = default_options.merge(model_params).merge(messages: prompt.messages) + def prepare_payload(prompt, model_params, dialect) + payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages) payload[:system] = prompt.system_prompt if prompt.system_prompt.present? payload[:stream] = true if @streaming_mode diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb index 39b80d01..ec813cab 100644 --- a/lib/completions/endpoints/base.rb +++ b/lib/completions/endpoints/base.rb @@ -64,6 +64,7 @@ module DiscourseAi end def perform_completion!(dialect, user, model_params = {}) + allow_tools = dialect.prompt.has_tools? model_params = normalize_model_params(model_params) @streaming_mode = block_given? @@ -110,9 +111,11 @@ module DiscourseAi response_data = extract_completion_from(response_raw) partials_raw = response_data.to_s - if has_tool?(response_data) + if allow_tools && has_tool?(response_data) function_buffer = build_buffer # Nokogiri document - function_buffer = add_to_buffer(function_buffer, "", response_data) + function_buffer = add_to_function_buffer(function_buffer, payload: response_data) + + normalize_function_ids!(function_buffer) response_data = +function_buffer.at("function_calls").to_s response_data << "\n" @@ -156,6 +159,7 @@ module DiscourseAi end json_error = false + buffered_partials = [] raw_partials.each do |raw_partial| json_error = false @@ -173,14 +177,30 @@ module DiscourseAi # Stop streaming the response as soon as you find a tool. # We'll buffer and yield it later. - has_tool = true if has_tool?(partials_raw) + has_tool = true if allow_tools && has_tool?(partials_raw) if has_tool - function_buffer = add_to_buffer(function_buffer, partials_raw, partial) + if buffered_partials.present? + joined = buffered_partials.join + joined = joined.gsub(/<.+/, "") + yield joined, cancel if joined.present? + buffered_partials = [] + end + function_buffer = add_to_function_buffer(function_buffer, partial: partial) else - response_data << partial - - yield partial, cancel if partial + if maybe_has_tool?(partials_raw) + buffered_partials << partial + else + if buffered_partials.present? + buffered_partials.each do |buffered_partial| + response_data << buffered_partial + yield buffered_partial, cancel + end + buffered_partials = [] + end + response_data << partial + yield partial, cancel if partial + end end rescue JSON::ParserError leftover = redo_chunk @@ -201,7 +221,11 @@ module DiscourseAi # Once we have the full response, try to return the tool as a XML doc. 
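# --- illustrative aside (not part of the diff) ---------------------------------
# A standalone sketch of what normalize_function_ids! (called above and defined
# further down) does to the parsed function-call buffer: every <invoke> that
# arrives without a usable <tool_id> gets a generated one (tool_0, tool_1, ...).
# The element names match the buffer the endpoint builds; the sample XML is
# hypothetical, and `strip.empty?` stands in for ActiveSupport's `blank?` so the
# sketch runs outside Rails.
require "nokogiri"

buffer = Nokogiri::HTML5.fragment(<<~XML)
  <function_calls>
    <invoke><tool_name>echo</tool_name><parameters></parameters></invoke>
    <invoke><tool_name>echo</tool_name><parameters></parameters><tool_id></tool_id></invoke>
  </function_calls>
XML

buffer.css("invoke").each_with_index do |invoke, index|
  if (tool_id = invoke.at("tool_id"))
    tool_id.content = "tool_#{index}" if tool_id.content.strip.empty?
  else
    invoke.add_child("<tool_id>tool_#{index}</tool_id>")
  end
end

puts buffer.at("function_calls").to_s
# => the first invoke gains <tool_id>tool_0</tool_id>, the empty one becomes tool_1
# --------------------------------------------------------------------------------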
if has_tool + function_buffer = add_to_function_buffer(function_buffer, payload: partials_raw) + if function_buffer.at("tool_name").text.present? + normalize_function_ids!(function_buffer) + invocation = +function_buffer.at("function_calls").to_s invocation << "\n" @@ -226,6 +250,21 @@ module DiscourseAi end end + def normalize_function_ids!(function_buffer) + function_buffer + .css("invoke") + .each_with_index do |invoke, index| + if invoke.at("tool_id") + invoke.at("tool_id").content = "tool_#{index}" if invoke + .at("tool_id") + .content + .blank? + else + invoke.add_child("tool_#{index}\n") if !invoke.at("tool_id") + end + end + end + def final_log_update(log) # for people that need to override end @@ -291,48 +330,36 @@ module DiscourseAi (<<~TEXT).strip - + TEXT end def has_tool?(response) - response.include?("") end - def add_to_buffer(function_buffer, response_data, partial) - raw_data = (response_data + partial) + def maybe_has_tool?(response) + # 16 is the length of function calls + substring = response[-16..-1] || response + split = substring.split("<") - # recover stop word potentially - raw_data = - raw_data.split("").first + "\n" if raw_data.split( - "", - ).length > 1 - - return function_buffer unless raw_data.include?("") - - read_function = Nokogiri::HTML5.fragment(raw_data) - - if tool_name = read_function.at("tool_name")&.text - function_buffer.at("tool_name").inner_html = tool_name - function_buffer.at("tool_id").inner_html = tool_name + if split.length > 1 + match = "<" + split.last + "".start_with?(match) + else + false end + end - read_function - .at("parameters") - &.elements - .to_a - .each do |elem| - if parameter = function_buffer.at(elem.name)&.text - function_buffer.at(elem.name).inner_html = parameter - else - param_node = read_function.at(elem.name) - function_buffer.at("parameters").add_child(param_node) - function_buffer.at("parameters").add_child("\n") - end - end + def add_to_function_buffer(function_buffer, partial: nil, payload: nil) + if payload&.include?("") + matches = payload.match(%r{.*}m) + function_buffer = + Nokogiri::HTML5.fragment(matches[0] + "\n") if matches + end function_buffer end diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb index 8668c658..86e189a9 100644 --- a/lib/completions/endpoints/gemini.rb +++ b/lib/completions/endpoints/gemini.rb @@ -104,12 +104,20 @@ module DiscourseAi @has_function_call end - def add_to_buffer(function_buffer, _response_data, partial) - if partial[:name].present? - function_buffer.at("tool_name").content = partial[:name] - function_buffer.at("tool_id").content = partial[:name] + def maybe_has_tool?(_partial_raw) + # we always get a full partial + false + end + + def add_to_function_buffer(function_buffer, payload: nil, partial: nil) + if @streaming_mode + return function_buffer if !partial + else + partial = payload end + function_buffer.at("tool_name").content = partial[:name] if partial[:name].present? 
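# --- illustrative aside (not part of the diff) ---------------------------------
# A self-contained sketch of the streaming hold-back idea used in the loop above:
# while the tail of the raw response could still turn out to be the start of the
# 16-character "<function_calls>" marker (hence the [-16..] window), partials are
# buffered instead of being yielded to the user, and flushed once the ambiguity
# resolves. The helper below is a simplified stand-in, not the plugin's
# maybe_has_tool?.
def maybe_tool_start?(raw)
  tail = raw[-16..] || raw
  last_open = tail.split("<").last
  tail.include?("<") && "<function_calls>".start_with?("<#{last_open}")
end

raw = +""
held_back = []
["Hello ", "<fun", "ny story>", " the end"].each do |partial|
  raw << partial
  if maybe_tool_start?(raw)
    held_back << partial        # might be a tool call, don't emit yet
  else
    print held_back.join        # false alarm, flush whatever was held back
    held_back.clear
    print partial
  end
end
# => prints "Hello <funny story> the end"; a genuine "<function_calls>" prefix
#    would instead stay buffered until the full tool call is parsed.
# --------------------------------------------------------------------------------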
+ if partial[:args] argument_fragments = partial[:args].reduce(+"") do |memo, (arg_name, value)| diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb index e0027324..029384bb 100644 --- a/lib/completions/endpoints/open_ai.rb +++ b/lib/completions/endpoints/open_ai.rb @@ -163,7 +163,18 @@ module DiscourseAi @has_function_call end - def add_to_buffer(function_buffer, _response_data, partial) + def maybe_has_tool?(_partial_raw) + # we always get a full partial + false + end + + def add_to_function_buffer(function_buffer, partial: nil, payload: nil) + if @streaming_mode + return function_buffer if !partial + else + partial = payload + end + @args_buffer ||= +"" f_name = partial.dig(:function, :name) diff --git a/lib/completions/prompt.rb b/lib/completions/prompt.rb index d356a238..959b3d14 100644 --- a/lib/completions/prompt.rb +++ b/lib/completions/prompt.rb @@ -49,6 +49,10 @@ module DiscourseAi messages << new_message end + def has_tools? + tools.present? + end + private def validate_message(message) diff --git a/spec/lib/completions/endpoints/anthropic_messages_spec.rb b/spec/lib/completions/endpoints/anthropic_messages_spec.rb index 8f17a85f..1d52e46c 100644 --- a/spec/lib/completions/endpoints/anthropic_messages_spec.rb +++ b/spec/lib/completions/endpoints/anthropic_messages_spec.rb @@ -10,6 +10,36 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AnthropicMessages do ) end + let(:echo_tool) do + { + name: "echo", + description: "echo something", + parameters: [{ name: "text", type: "string", description: "text to echo", required: true }], + } + end + + let(:google_tool) do + { + name: "google", + description: "google something", + parameters: [ + { name: "query", type: "string", description: "text to google", required: true }, + ], + } + end + + let(:prompt_with_echo_tool) do + prompt_with_tools = prompt + prompt.tools = [echo_tool] + prompt_with_tools + end + + let(:prompt_with_google_tool) do + prompt_with_tools = prompt + prompt.tools = [echo_tool] + prompt_with_tools + end + before { SiteSetting.ai_anthropic_api_key = "123" } it "does not eat spaces with tool calls" do @@ -165,16 +195,18 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AnthropicMessages do stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(status: 200, body: body) result = +"" - llm.generate(prompt, user: Discourse.system_user) { |partial| result << partial } + llm.generate(prompt_with_google_tool, user: Discourse.system_user) do |partial| + result << partial + end expected = (<<~TEXT).strip google - google top 10 things to do in japan for tourists + tool_0 TEXT @@ -232,7 +264,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AnthropicMessages do expected_body = { model: "claude-3-opus-20240229", max_tokens: 3000, - stop_sequences: [""], messages: [{ role: "user", content: "user1: hello" }], system: "You are hello bot", stream: true, @@ -245,6 +276,70 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AnthropicMessages do expect(log.response_tokens).to eq(15) end + it "can return multiple function calls" do + functions = <<~FUNCTIONS + + + echo + + something + + + + echo + + something else + + + FUNCTIONS + + body = <<~STRING + { + "content": [ + { + "text": "Hello!\n\n#{functions}\njunk", + "type": "text" + } + ], + "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF", + "model": "claude-3-opus-20240229", + "role": "assistant", + "stop_reason": "end_turn", + "stop_sequence": null, + "type": "message", + "usage": { + "input_tokens": 10, + "output_tokens": 
25 + } + } + STRING + + stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(status: 200, body: body) + + result = llm.generate(prompt_with_echo_tool, user: Discourse.system_user) + + expected = (<<~EXPECTED).strip + + + echo + + something + + tool_0 + + + echo + + something else + + tool_1 + + + EXPECTED + + expect(result.strip).to eq(expected) + end + it "can operate in regular mode" do body = <<~STRING { @@ -287,7 +382,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AnthropicMessages do expected_body = { model: "claude-3-opus-20240229", max_tokens: 3000, - stop_sequences: [""], messages: [{ role: "user", content: "user1: hello" }], system: "You are hello bot", } diff --git a/spec/lib/completions/endpoints/endpoint_compliance.rb b/spec/lib/completions/endpoints/endpoint_compliance.rb index fe0e44fd..31f5e86e 100644 --- a/spec/lib/completions/endpoints/endpoint_compliance.rb +++ b/spec/lib/completions/endpoints/endpoint_compliance.rb @@ -66,11 +66,11 @@ class EndpointMock get_weather - #{tool_id} Sydney c + tool_0 TEXT diff --git a/spec/lib/completions/endpoints/gemini_spec.rb b/spec/lib/completions/endpoints/gemini_spec.rb index 8f9cb367..dbc5fe31 100644 --- a/spec/lib/completions/endpoints/gemini_spec.rb +++ b/spec/lib/completions/endpoints/gemini_spec.rb @@ -132,7 +132,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do fab!(:user) - let(:bedrock_mock) { GeminiMock.new(endpoint) } + let(:gemini_mock) { GeminiMock.new(endpoint) } let(:compliance) do EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Gemini, user) @@ -142,13 +142,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do context "when using regular mode" do context "with simple prompts" do it "completes a trivial prompt and logs the response" do - compliance.regular_mode_simple_prompt(bedrock_mock) + compliance.regular_mode_simple_prompt(gemini_mock) end end context "with tools" do it "returns a function invocation" do - compliance.regular_mode_tools(bedrock_mock) + compliance.regular_mode_tools(gemini_mock) end end end @@ -156,13 +156,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do describe "when using streaming mode" do context "with simple prompts" do it "completes a trivial prompt and logs the response" do - compliance.streaming_mode_simple_prompt(bedrock_mock) + compliance.streaming_mode_simple_prompt(gemini_mock) end end context "with tools" do it "returns a function invocation" do - compliance.streaming_mode_tools(bedrock_mock) + compliance.streaming_mode_tools(gemini_mock) end end end diff --git a/spec/lib/completions/endpoints/open_ai_spec.rb b/spec/lib/completions/endpoints/open_ai_spec.rb index a83ed0a6..fe662ae2 100644 --- a/spec/lib/completions/endpoints/open_ai_spec.rb +++ b/spec/lib/completions/endpoints/open_ai_spec.rb @@ -102,7 +102,7 @@ class OpenAiMock < EndpointMock end def tool_id - "eujbuebfe" + "tool_0" end def tool_payload @@ -149,6 +149,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do fab!(:user) + let(:echo_tool) do + { + name: "echo", + description: "echo something", + parameters: [{ name: "text", type: "string", description: "text to echo", required: true }], + } + end + + let(:tools) { [echo_tool] } + let(:open_ai_mock) { OpenAiMock.new(endpoint) } let(:compliance) do @@ -260,23 +270,25 @@ TEXT open_ai_mock.stub_raw(raw_data) content = +"" - endpoint.perform_completion!(compliance.dialect, user) { |partial| content << partial } + dialect = compliance.dialect(prompt: 
compliance.generic_prompt(tools: tools)) + + endpoint.perform_completion!(dialect, user) { |partial| content << partial } expected = <<~TEXT search - call_3Gyr3HylFJwfrtKrL6NaIit1 Discourse AI bot + call_3Gyr3HylFJwfrtKrL6NaIit1 search - call_H7YkbgYurHpyJqzwUN4bghwN Discourse AI bot + call_H7YkbgYurHpyJqzwUN4bghwN TEXT @@ -321,7 +333,8 @@ TEXT open_ai_mock.stub_raw(chunks) partials = [] - endpoint.perform_completion!(compliance.dialect, user) { |partial| partials << partial } + dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools)) + endpoint.perform_completion!(dialect, user) { |partial| partials << partial } expect(partials.length).to eq(1) @@ -329,10 +342,10 @@ TEXT google - func_id Adabas 9.1 + func_id TXT diff --git a/spec/lib/modules/ai_bot/personas/persona_spec.rb b/spec/lib/modules/ai_bot/personas/persona_spec.rb index 0e2cf504..a60a6b16 100644 --- a/spec/lib/modules/ai_bot/personas/persona_spec.rb +++ b/spec/lib/modules/ai_bot/personas/persona_spec.rb @@ -153,6 +153,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do DiscourseAi::AiBot::Personas::Artist, DiscourseAi::AiBot::Personas::Creative, DiscourseAi::AiBot::Personas::DiscourseHelper, + DiscourseAi::AiBot::Personas::GithubHelper, DiscourseAi::AiBot::Personas::Researcher, DiscourseAi::AiBot::Personas::SettingsExplorer, DiscourseAi::AiBot::Personas::SqlHelper, @@ -169,6 +170,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do DiscourseAi::AiBot::Personas::SettingsExplorer, DiscourseAi::AiBot::Personas::Creative, DiscourseAi::AiBot::Personas::DiscourseHelper, + DiscourseAi::AiBot::Personas::GithubHelper, ) AiPersona.find( @@ -182,6 +184,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do DiscourseAi::AiBot::Personas::SettingsExplorer, DiscourseAi::AiBot::Personas::Creative, DiscourseAi::AiBot::Personas::DiscourseHelper, + DiscourseAi::AiBot::Personas::GithubHelper, ) end end diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb index f9d444de..d6524108 100644 --- a/spec/lib/modules/ai_bot/playground_spec.rb +++ b/spec/lib/modules/ai_bot/playground_spec.rb @@ -337,6 +337,15 @@ RSpec.describe DiscourseAi::AiBot::Playground do end end + describe "#available_bot_usernames" do + it "includes persona users" do + persona = Fabricate(:ai_persona) + persona.create_user! 
+ + expect(playground.available_bot_usernames).to include(persona.user.username) + end + end + describe "#conversation_context" do context "with limited context" do before do diff --git a/spec/lib/modules/ai_bot/tools/github_file_content_spec.rb b/spec/lib/modules/ai_bot/tools/github_file_content_spec.rb new file mode 100644 index 00000000..aa2263bd --- /dev/null +++ b/spec/lib/modules/ai_bot/tools/github_file_content_spec.rb @@ -0,0 +1,78 @@ +# frozen_string_literal: true + +require "rails_helper" + +RSpec.describe DiscourseAi::AiBot::Tools::GithubFileContent do + let(:tool) do + described_class.new( + { + repo_name: "discourse/discourse-ai", + file_paths: %w[lib/database/connection.rb lib/ai_bot/tools/github_pull_request_diff.rb], + branch: "8b382d6098fde879d28bbee68d3cbe0a193e4ffc", + }, + ) + end + + let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") } + + describe "#invoke" do + before do + stub_request( + :get, + "https://api.github.com/repos/discourse/discourse-ai/contents/lib/database/connection.rb?ref=8b382d6098fde879d28bbee68d3cbe0a193e4ffc", + ).to_return( + status: 200, + body: { content: Base64.encode64("content of connection.rb") }.to_json, + ) + + stub_request( + :get, + "https://api.github.com/repos/discourse/discourse-ai/contents/lib/ai_bot/tools/github_pull_request_diff.rb?ref=8b382d6098fde879d28bbee68d3cbe0a193e4ffc", + ).to_return( + status: 200, + body: { content: Base64.encode64("content of github_pull_request_diff.rb") }.to_json, + ) + end + + it "retrieves the content of the specified GitHub files" do + result = tool.invoke(nil, llm) + expected = { + file_contents: + "File Path: lib/database/connection.rb:\ncontent of connection.rb\nFile Path: lib/ai_bot/tools/github_pull_request_diff.rb:\ncontent of github_pull_request_diff.rb", + } + + expect(result).to eq(expected) + end + end + + describe ".signature" do + it "returns the tool signature" do + signature = described_class.signature + expect(signature[:name]).to eq("github_file_content") + expect(signature[:description]).to eq("Retrieves the content of specified GitHub files") + expect(signature[:parameters]).to eq( + [ + { + name: "repo_name", + description: "The name of the GitHub repository (e.g., 'discourse/discourse')", + type: "string", + required: true, + }, + { + name: "file_paths", + description: "The paths of the files to retrieve within the repository", + type: "array", + item_type: "string", + required: true, + }, + { + name: "branch", + description: "The branch or commit SHA to retrieve the files from (default: 'main')", + type: "string", + required: false, + }, + ], + ) + end + end +end diff --git a/spec/lib/modules/ai_bot/tools/github_pull_request_diff_spec.rb b/spec/lib/modules/ai_bot/tools/github_pull_request_diff_spec.rb new file mode 100644 index 00000000..4260643c --- /dev/null +++ b/spec/lib/modules/ai_bot/tools/github_pull_request_diff_spec.rb @@ -0,0 +1,61 @@ +# frozen_string_literal: true + +require "rails_helper" + +RSpec.describe DiscourseAi::AiBot::Tools::GithubPullRequestDiff do + let(:tool) { described_class.new({ repo: repo, pull_id: pull_id }) } + let(:bot_user) { Fabricate(:user) } + let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") } + + context "with a valid pull request" do + let(:repo) { "discourse/discourse-automation" } + let(:pull_id) { 253 } + + it "retrieves the diff for the pull request" do + stub_request(:get, "https://api.github.com/repos/#{repo}/pulls/#{pull_id}").with( + headers: { + "Accept" => "application/vnd.github.v3.diff", + 
"User-Agent" => DiscourseAi::AiBot::USER_AGENT, + }, + ).to_return(status: 200, body: "sample diff") + + result = tool.invoke(bot_user, llm) + expect(result[:diff]).to eq("sample diff") + expect(result[:error]).to be_nil + end + + it "uses the github access token if present" do + SiteSetting.ai_bot_github_access_token = "ABC" + + stub_request(:get, "https://api.github.com/repos/#{repo}/pulls/#{pull_id}").with( + headers: { + "Accept" => "application/vnd.github.v3.diff", + "User-Agent" => DiscourseAi::AiBot::USER_AGENT, + "Authorization" => "Bearer ABC", + }, + ).to_return(status: 200, body: "sample diff") + + result = tool.invoke(bot_user, llm) + expect(result[:diff]).to eq("sample diff") + expect(result[:error]).to be_nil + end + end + + context "with an invalid pull request" do + let(:repo) { "invalid/repo" } + let(:pull_id) { 999 } + + it "returns an error message" do + stub_request(:get, "https://api.github.com/repos/#{repo}/pulls/#{pull_id}").with( + headers: { + "Accept" => "application/vnd.github.v3.diff", + "User-Agent" => DiscourseAi::AiBot::USER_AGENT, + }, + ).to_return(status: 404) + + result = tool.invoke(bot_user, nil) + expect(result[:diff]).to be_nil + expect(result[:error]).to include("Failed to retrieve the diff") + end + end +end diff --git a/spec/lib/modules/ai_bot/tools/github_search_code_spec.rb b/spec/lib/modules/ai_bot/tools/github_search_code_spec.rb new file mode 100644 index 00000000..789ad8c2 --- /dev/null +++ b/spec/lib/modules/ai_bot/tools/github_search_code_spec.rb @@ -0,0 +1,93 @@ +# frozen_string_literal: true + +require "rails_helper" + +RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchCode do + let(:tool) { described_class.new({ repo: repo, query: query }) } + let(:bot_user) { Fabricate(:user) } + let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") } + + context "with valid search results" do + let(:repo) { "discourse/discourse" } + let(:query) { "def hello" } + + it "searches for code in the repository" do + stub_request( + :get, + "https://api.github.com/search/code?q=def%20hello+repo:discourse/discourse", + ).with( + headers: { + "Accept" => "application/vnd.github.v3.text-match+json", + "User-Agent" => DiscourseAi::AiBot::USER_AGENT, + }, + ).to_return( + status: 200, + body: { + total_count: 1, + items: [ + { + path: "test/hello.rb", + name: "hello.rb", + text_matches: [{ fragment: "def hello\n puts 'hello'\nend" }], + }, + ], + }.to_json, + ) + + result = tool.invoke(bot_user, llm) + expect(result[:search_results]).to include("def hello\n puts 'hello'\nend") + expect(result[:search_results]).to include("test/hello.rb") + expect(result[:error]).to be_nil + end + end + + context "with an empty search result" do + let(:repo) { "discourse/discourse" } + let(:query) { "nonexistent_method" } + + describe "#description_args" do + it "returns the repo and query" do + expect(tool.description_args).to eq(repo: repo, query: query) + end + end + + it "returns an empty result" do + SiteSetting.ai_bot_github_access_token = "ABC" + stub_request( + :get, + "https://api.github.com/search/code?q=nonexistent_method+repo:discourse/discourse", + ).with( + headers: { + "Accept" => "application/vnd.github.v3.text-match+json", + "User-Agent" => DiscourseAi::AiBot::USER_AGENT, + "Authorization" => "Bearer ABC", + }, + ).to_return(status: 200, body: { total_count: 0, items: [] }.to_json) + + result = tool.invoke(bot_user, llm) + expect(result[:search_results]).to be_empty + expect(result[:error]).to be_nil + end + end + + context "with an error response" do 
+    let(:repo) { "discourse/discourse" }
+    let(:query) { "def hello" }
+
+    it "returns an error message" do
+      stub_request(
+        :get,
+        "https://api.github.com/search/code?q=def%20hello+repo:discourse/discourse",
+      ).with(
+        headers: {
+          "Accept" => "application/vnd.github.v3.text-match+json",
+          "User-Agent" => DiscourseAi::AiBot::USER_AGENT,
+        },
+      ).to_return(status: 403)
+
+      result = tool.invoke(bot_user, llm)
+      expect(result[:search_results]).to be_nil
+      expect(result[:error]).to include("Failed to perform code search")
+    end
+  end
+end
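# --- illustrative aside (not part of the diff) ---------------------------------
# How the Tool#truncate helper used by the three GitHub tools sizes its budget:
# start from the model's max prompt tokens, scale by percent_length when given,
# then cap at max_length. A quick standalone check of that arithmetic (the token
# counts are hypothetical model limits, not values from the plugin):
def token_budget(max_prompt_tokens, percent_length: nil, max_length: nil)
  raise ArgumentError, "provide percent_length or max_length" if !percent_length && !max_length

  target = max_prompt_tokens
  target = (target * percent_length).to_i if percent_length
  target = max_length if max_length && target > max_length
  target
end

token_budget(8_192, percent_length: 0.3, max_length: 20_000)   # => 2457  (30% of a small context wins)
token_budget(200_000, percent_length: 0.3, max_length: 20_000) # => 20000 (the hard cap wins)
# The resulting budget is then passed to llm.tokenizer.truncate(text, target).
# --------------------------------------------------------------------------------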