FEATURE: Add GitHub Helper AI Bot persona and tools (#513)
Introduces a new AI Bot persona called 'GitHub Helper', specialized in assisting with GitHub-related tasks and questions. Key changes:

- Implements the GitHub Helper persona class with its system prompt and available tools
- Adds three new AI Bot tools for GitHub interactions:
  - github_file_content: retrieves the content of files from a GitHub repository
  - github_pull_request_diff: retrieves the diff for a GitHub pull request
  - github_search_code: searches for code in a GitHub repository
- Updates the AI Bot dialects to support the new GitHub tools
- Implements multiple function calls for the standard tool dialect
parent 176a4458f2
commit 2ad743d246
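For reviewers, this is the shape of a multi-tool invocation that the standard (XML) tool dialect now accepts; it is a sketch mirroring the `echo` example exercised in the specs below, with missing tool_id values normalized to sequential tool_N identifiers:

    <function_calls>
    <invoke>
    <tool_name>echo</tool_name>
    <parameters>
    <text>something</text>
    </parameters>
    <tool_id>tool_0</tool_id>
    </invoke>
    <invoke>
    <tool_name>echo</tool_name>
    <parameters>
    <text>something else</text>
    </parameters>
    <tool_id>tool_1</tool_id>
    </invoke>
    </function_calls>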
@@ -96,6 +96,7 @@ en:
     ai_bot_allowed_groups: "When the GPT Bot has access to the PM, it will reply to members of these groups."
     ai_bot_enabled_chat_bots: "Available models to act as an AI Bot"
     ai_bot_add_to_header: "Display a button in the header to start a PM with a AI Bot"
+    ai_bot_github_access_token: "GitHub access token for use with GitHub AI tools (required for search support)"
     ai_stability_api_key: "API key for the stability.ai API"
     ai_stability_engine: "Image generation engine to use for the stability.ai API"
@@ -154,6 +155,9 @@ en:
     personas:
       cannot_delete_system_persona: "System personas cannot be deleted, please disable it instead"
       cannot_edit_system_persona: "System personas can only be renamed, you may not edit commands or system prompt, instead disable and make a copy"
+      github_helper:
+        name: "GitHub Helper"
+        description: "AI Bot specialized in assisting with GitHub-related tasks and questions"
       general:
         name: Forum Helper
         description: "General purpose AI Bot capable of performing various tasks"
@@ -190,6 +194,9 @@ en:
       name: "Base Search Query"
       description: "Base query to use when searching. Example: '#urgent' will prepend '#urgent' to the search query and only include topics with the urgent category or tag."
     command_summary:
+      github_search_code: "GitHub code search"
+      github_file_content: "GitHub file content"
+      github_pull_request_diff: "GitHub pull request diff"
       random_picker: "Random Picker"
       categories: "List categories"
       search: "Search"
@@ -205,6 +212,9 @@ en:
       dall_e: "Generate image"
       search_meta_discourse: "Search Meta Discourse"
     command_help:
+      github_search_code: "Search for code in a GitHub repository"
+      github_file_content: "Retrieve content of files from a GitHub repository"
+      github_pull_request_diff: "Retrieve a GitHub pull request diff"
       random_picker: "Pick a random number or a random element of a list"
       categories: "List all publicly visible categories on the forum"
       search: "Search all public topics on the forum"
@@ -220,6 +230,9 @@ en:
       dall_e: "Generate image using DALL-E 3"
       search_meta_discourse: "Search Meta Discourse"
     command_description:
+      github_search_code: "Searched for '%{query}' in %{repo}"
+      github_pull_request_diff: "<a href='%{url}'>%{repo} %{pull_id}</a>"
+      github_file_content: "Retrieved content of %{file_paths} from %{repo_name}@%{branch}"
       random_picker: "Picking from %{options}, picked: %{result}"
       read: "Reading: <a href='%{url}'>%{title}</a>"
       time: "Time in %{timezone} is %{time}"
@@ -318,3 +318,6 @@ discourse_ai:
   ai_bot_add_to_header:
     default: true
     client: true
+  ai_bot_github_access_token:
+    default: ""
+    secret: true
@@ -2,6 +2,8 @@

 module DiscourseAi
   module AiBot
+    USER_AGENT = "Discourse AI Bot 1.0 (https://www.discourse.org)"
+
     class EntryPoint
       REQUIRE_TITLE_UPDATE = "discourse-ai-title-update"
@@ -0,0 +1,24 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module AiBot
+    module Personas
+      class GithubHelper < Persona
+        def tools
+          [Tools::GithubFileContent, Tools::GithubPullRequestDiff, Tools::GithubSearchCode]
+        end
+
+        def system_prompt
+          <<~PROMPT
+            You are a helpful GitHub assistant.
+            You _understand_ and **generate** Discourse Flavored Markdown.
+            You live in a Discourse Forum Message.
+
+            Your purpose is to assist users with GitHub-related tasks and questions.
+            When asked about a specific repository, pull request, or file, try to use the available tools to provide accurate and helpful information.
+          PROMPT
+        end
+      end
+    end
+  end
+end
@@ -15,6 +15,7 @@ module DiscourseAi
        Personas::Creative => -6,
        Personas::DallE3 => -7,
        Personas::DiscourseHelper => -8,
+       Personas::GithubHelper => -9,
      }
    end
@@ -64,8 +65,12 @@ module DiscourseAi
        Tools::SettingContext,
        Tools::RandomPicker,
        Tools::DiscourseMetaSearch,
+       Tools::GithubFileContent,
+       Tools::GithubPullRequestDiff,
      ]

+     tools << Tools::GithubSearchCode if SiteSetting.ai_bot_github_access_token.present?
+
      tools << Tools::ListTags if SiteSetting.tagging_enabled
      tools << Tools::Image if SiteSetting.ai_stability_api_key.present?
@@ -162,7 +167,7 @@ module DiscourseAi
      tool_klass.new(
        arguments,
-       tool_call_id: function_id,
+       tool_call_id: function_id || function_name,
        persona_options: options[tool_klass].to_h,
      )
    end
@@ -276,6 +276,14 @@ module DiscourseAi
      publish_final_update(reply_post) if stream_reply
    end

+   def available_bot_usernames
+     @bot_usernames ||=
+       AiPersona
+         .joins(:user)
+         .pluck(:username)
+         .concat(DiscourseAi::AiBot::EntryPoint::BOTS.map(&:second))
+   end
+
    private

    def publish_final_update(reply_post)
@@ -348,10 +356,6 @@ module DiscourseAi
        max_backlog_age: 60,
      )
    end
-
-   def available_bot_usernames
-     @bot_usernames ||= DiscourseAi::AiBot::EntryPoint::BOTS.map(&:second)
-   end
   end
  end
end
@@ -0,0 +1,97 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module AiBot
+    module Tools
+      class GithubFileContent < Tool
+        def self.signature
+          {
+            name: name,
+            description: "Retrieves the content of specified GitHub files",
+            parameters: [
+              {
+                name: "repo_name",
+                description: "The name of the GitHub repository (e.g., 'discourse/discourse')",
+                type: "string",
+                required: true,
+              },
+              {
+                name: "file_paths",
+                description: "The paths of the files to retrieve within the repository",
+                type: "array",
+                item_type: "string",
+                required: true,
+              },
+              {
+                name: "branch",
+                description:
+                  "The branch or commit SHA to retrieve the files from (default: 'main')",
+                type: "string",
+                required: false,
+              },
+            ],
+          }
+        end
+
+        def self.name
+          "github_file_content"
+        end
+
+        def repo_name
+          parameters[:repo_name]
+        end
+
+        def file_paths
+          parameters[:file_paths]
+        end
+
+        def branch
+          parameters[:branch] || "main"
+        end
+
+        def description_args
+          { repo_name: repo_name, file_paths: file_paths.join(", "), branch: branch }
+        end
+
+        def invoke(_bot_user, llm)
+          owner, repo = repo_name.split("/")
+          file_contents = {}
+          missing_files = []
+
+          file_paths.each do |file_path|
+            api_url =
+              "https://api.github.com/repos/#{owner}/#{repo}/contents/#{file_path}?ref=#{branch}"
+
+            response =
+              send_http_request(
+                api_url,
+                headers: {
+                  "Accept" => "application/vnd.github.v3+json",
+                },
+                authenticate_github: true,
+              )
+
+            if response.code == "200"
+              file_data = JSON.parse(response.body)
+              content = Base64.decode64(file_data["content"])
+              file_contents[file_path] = content
+            else
+              missing_files << file_path
+            end
+          end
+
+          result = {}
+          unless file_contents.empty?
+            blob =
+              file_contents.map { |path, content| "File Path: #{path}:\n#{content}" }.join("\n")
+            truncated_blob = truncate(blob, max_length: 20_000, percent_length: 0.3, llm: llm)
+            result[:file_contents] = truncated_blob
+          end
+
+          result[:missing_files] = missing_files unless missing_files.empty?
+
+          result.empty? ? { error: "No files found or retrieved." } : result
+        end
+      end
+    end
+  end
+end
@@ -0,0 +1,72 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module AiBot
+    module Tools
+      class GithubPullRequestDiff < Tool
+        def self.signature
+          {
+            name: name,
+            description: "Retrieves the diff for a GitHub pull request",
+            parameters: [
+              {
+                name: "repo",
+                description: "The repository name in the format 'owner/repo'",
+                type: "string",
+                required: true,
+              },
+              {
+                name: "pull_id",
+                description: "The ID of the pull request",
+                type: "integer",
+                required: true,
+              },
+            ],
+          }
+        end
+
+        def self.name
+          "github_pull_request_diff"
+        end
+
+        def repo
+          parameters[:repo]
+        end
+
+        def pull_id
+          parameters[:pull_id]
+        end
+
+        def url
+          @url
+        end
+
+        def invoke(_bot_user, llm)
+          api_url = "https://api.github.com/repos/#{repo}/pulls/#{pull_id}"
+          @url = "https://github.com/#{repo}/pull/#{pull_id}"
+
+          response =
+            send_http_request(
+              api_url,
+              headers: {
+                "Accept" => "application/vnd.github.v3.diff",
+              },
+              authenticate_github: true,
+            )
+
+          if response.code == "200"
+            diff = response.body
+            diff = truncate(diff, max_length: 20_000, percent_length: 0.3, llm: llm)
+            { diff: diff }
+          else
+            { error: "Failed to retrieve the diff. Status code: #{response.code}" }
+          end
+        end
+
+        def description_args
+          { repo: repo, pull_id: pull_id, url: url }
+        end
+      end
+    end
+  end
+end
@@ -0,0 +1,72 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module AiBot
+    module Tools
+      class GithubSearchCode < Tool
+        def self.signature
+          {
+            name: name,
+            description: "Searches for code in a GitHub repository",
+            parameters: [
+              {
+                name: "repo",
+                description: "The repository name in the format 'owner/repo'",
+                type: "string",
+                required: true,
+              },
+              {
+                name: "query",
+                description: "The search query (e.g., a function name, variable, or code snippet)",
+                type: "string",
+                required: true,
+              },
+            ],
+          }
+        end
+
+        def self.name
+          "github_search_code"
+        end
+
+        def repo
+          parameters[:repo]
+        end
+
+        def query
+          parameters[:query]
+        end
+
+        def description_args
+          { repo: repo, query: query }
+        end
+
+        def invoke(_bot_user, llm)
+          api_url = "https://api.github.com/search/code?q=#{query}+repo:#{repo}"
+
+          response =
+            send_http_request(
+              api_url,
+              headers: {
+                "Accept" => "application/vnd.github.v3.text-match+json",
+              },
+              authenticate_github: true,
+            )
+
+          if response.code == "200"
+            search_data = JSON.parse(response.body)
+            results =
+              search_data["items"]
+                .map { |item| "#{item["path"]}:\n#{item["text_matches"][0]["fragment"]}" }
+                .join("\n---\n")
+
+            results = truncate(results, max_length: 20_000, percent_length: 0.3, llm: llm)
+            { search_results: results }
+          else
+            { error: "Failed to perform code search. Status code: #{response.code}" }
+          end
+        end
+      end
+    end
+  end
+end
@@ -77,6 +77,35 @@ module DiscourseAi
      protected

+     def send_http_request(url, headers: {}, authenticate_github: false)
+       uri = URI(url)
+       request = FinalDestination::HTTP::Get.new(uri)
+       request["User-Agent"] = DiscourseAi::AiBot::USER_AGENT
+       headers.each { |k, v| request[k] = v }
+       if authenticate_github
+         request["Authorization"] = "Bearer #{SiteSetting.ai_bot_github_access_token}"
+       end
+
+       FinalDestination::HTTP.start(uri.hostname, uri.port, use_ssl: uri.port != 80) do |http|
+         http.request(request)
+       end
+     end
+
+     def truncate(text, llm:, percent_length: nil, max_length: nil)
+       if !percent_length && !max_length
+         raise ArgumentError, "You must provide either percent_length or max_length"
+       end
+
+       target = llm.max_prompt_tokens
+       target = (target * percent_length).to_i if percent_length
+
+       if max_length
+         target = max_length if target > max_length
+       end
+
+       llm.tokenizer.truncate(text, target)
+     end
+
      def accepted_options
        []
      end
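For reference, a minimal sketch of how the new GitHub tools call this truncation helper (mirroring GithubFileContent#invoke above; `llm` is whatever LLM proxy the tool was invoked with): the result is capped at 20,000 tokens or 30% of the model's prompt budget, whichever is smaller.

    # inside a Tool#invoke implementation (illustrative)
    blob = file_contents.map { |path, content| "File Path: #{path}:\n#{content}" }.join("\n")
    truncated = truncate(blob, max_length: 20_000, percent_length: 0.3, llm: llm)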
@@ -36,7 +36,8 @@ module DiscourseAi
      def tool_preamble
        <<~TEXT
          In this environment you have access to a set of tools you can use to answer the user's question.
-         You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
+         You may call them like this.
+
          <function_calls>
          <invoke>
          <tool_name>$TOOL_NAME</tool_name>
|
@ -47,9 +48,12 @@ module DiscourseAi
|
||||||
</invoke>
|
</invoke>
|
||||||
</function_calls>
|
</function_calls>
|
||||||
|
|
||||||
if a parameter type is an array, return a JSON array of values. For example:
|
If a parameter type is an array, return a JSON array of values. For example:
|
||||||
[1,"two",3.0]
|
[1,"two",3.0]
|
||||||
|
|
||||||
|
Always wrap <invoke> calls in <function_calls> tags.
|
||||||
|
You may call multiple function via <invoke> in a single <function_calls> block.
|
||||||
|
|
||||||
Here are the tools available:
|
Here are the tools available:
|
||||||
TEXT
|
TEXT
|
||||||
end
|
end
|
||||||
|
|
|
@@ -27,8 +27,11 @@ module DiscourseAi
          model_params
        end

-       def default_options
-         { model: model + "-20240229", max_tokens: 3_000, stop_sequences: ["</function_calls>"] }
+       def default_options(dialect)
+         options = { model: model + "-20240229", max_tokens: 3_000 }
+
+         options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
+         options
        end

        def provider_id
@@ -46,8 +49,8 @@ module DiscourseAi
          @uri ||= URI("https://api.anthropic.com/v1/messages")
        end

-       def prepare_payload(prompt, model_params, _dialect)
-         payload = default_options.merge(model_params).merge(messages: prompt.messages)
+       def prepare_payload(prompt, model_params, dialect)
+         payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)

          payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
          payload[:stream] = true if @streaming_mode
@@ -64,6 +64,7 @@ module DiscourseAi
        end

        def perform_completion!(dialect, user, model_params = {})
+         allow_tools = dialect.prompt.has_tools?
          model_params = normalize_model_params(model_params)

          @streaming_mode = block_given?
@@ -110,9 +111,11 @@ module DiscourseAi
            response_data = extract_completion_from(response_raw)
            partials_raw = response_data.to_s

-           if has_tool?(response_data)
+           if allow_tools && has_tool?(response_data)
              function_buffer = build_buffer # Nokogiri document
-             function_buffer = add_to_buffer(function_buffer, "", response_data)
+             function_buffer = add_to_function_buffer(function_buffer, payload: response_data)
+
+             normalize_function_ids!(function_buffer)

              response_data = +function_buffer.at("function_calls").to_s
              response_data << "\n"
@@ -156,6 +159,7 @@ module DiscourseAi
            end

            json_error = false
+           buffered_partials = []

            raw_partials.each do |raw_partial|
              json_error = false
@@ -173,15 +177,31 @@ module DiscourseAi

              # Stop streaming the response as soon as you find a tool.
              # We'll buffer and yield it later.
-             has_tool = true if has_tool?(partials_raw)
+             has_tool = true if allow_tools && has_tool?(partials_raw)

              if has_tool
-               function_buffer = add_to_buffer(function_buffer, partials_raw, partial)
+               if buffered_partials.present?
+                 joined = buffered_partials.join
+                 joined = joined.gsub(/<.+/, "")
+                 yield joined, cancel if joined.present?
+                 buffered_partials = []
+               end
+               function_buffer = add_to_function_buffer(function_buffer, partial: partial)
              else
+               if maybe_has_tool?(partials_raw)
+                 buffered_partials << partial
+               else
+                 if buffered_partials.present?
+                   buffered_partials.each do |buffered_partial|
+                     response_data << buffered_partial
+                     yield buffered_partial, cancel
+                   end
+                   buffered_partials = []
+                 end
                  response_data << partial

                  yield partial, cancel if partial
+               end
              end
            rescue JSON::ParserError
              leftover = redo_chunk
              json_error = true
@@ -201,7 +221,11 @@ module DiscourseAi

          # Once we have the full response, try to return the tool as a XML doc.
          if has_tool
+           function_buffer = add_to_function_buffer(function_buffer, payload: partials_raw)
+
            if function_buffer.at("tool_name").text.present?
+             normalize_function_ids!(function_buffer)
+
              invocation = +function_buffer.at("function_calls").to_s
              invocation << "\n"
|
@ -226,6 +250,21 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def normalize_function_ids!(function_buffer)
|
||||||
|
function_buffer
|
||||||
|
.css("invoke")
|
||||||
|
.each_with_index do |invoke, index|
|
||||||
|
if invoke.at("tool_id")
|
||||||
|
invoke.at("tool_id").content = "tool_#{index}" if invoke
|
||||||
|
.at("tool_id")
|
||||||
|
.content
|
||||||
|
.blank?
|
||||||
|
else
|
||||||
|
invoke.add_child("<tool_id>tool_#{index}</tool_id>\n") if !invoke.at("tool_id")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
def final_log_update(log)
|
def final_log_update(log)
|
||||||
# for people that need to override
|
# for people that need to override
|
||||||
end
|
end
|
||||||
|
@@ -291,49 +330,37 @@ module DiscourseAi
          (<<~TEXT).strip
            <invoke>
            <tool_name></tool_name>
-           <tool_id></tool_id>
            <parameters>
            </parameters>
+           <tool_id></tool_id>
            </invoke>
          TEXT
        end

        def has_tool?(response)
-         response.include?("<function")
+         response.include?("<function_calls>")
        end

-       def add_to_buffer(function_buffer, response_data, partial)
-         raw_data = (response_data + partial)
-
-         # recover stop word potentially
-         raw_data =
-           raw_data.split("</invoke>").first + "</invoke>\n</function_calls>" if raw_data.split(
-           "</invoke>",
-         ).length > 1
-
-         return function_buffer unless raw_data.include?("</invoke>")
-
-         read_function = Nokogiri::HTML5.fragment(raw_data)
-
-         if tool_name = read_function.at("tool_name")&.text
-           function_buffer.at("tool_name").inner_html = tool_name
-           function_buffer.at("tool_id").inner_html = tool_name
-         end
-
-         read_function
-           .at("parameters")
-           &.elements
-           .to_a
-           .each do |elem|
-             if parameter = function_buffer.at(elem.name)&.text
-               function_buffer.at(elem.name).inner_html = parameter
-             else
-               param_node = read_function.at(elem.name)
-               function_buffer.at("parameters").add_child(param_node)
-               function_buffer.at("parameters").add_child("\n")
-             end
-           end
+       def maybe_has_tool?(response)
+         # 16 is the length of function calls
+         substring = response[-16..-1] || response
+         split = substring.split("<")
+
+         if split.length > 1
+           match = "<" + split.last
+           "<function_calls>".start_with?(match)
+         else
+           false
+         end
+       end
+
+       def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
+         if payload&.include?("</invoke>")
+           matches = payload.match(%r{<function_calls>.*</invoke>}m)
+           function_buffer =
+             Nokogiri::HTML5.fragment(matches[0] + "\n</function_calls>") if matches
+         end

          function_buffer
        end
      end
|
@ -104,12 +104,20 @@ module DiscourseAi
|
||||||
@has_function_call
|
@has_function_call
|
||||||
end
|
end
|
||||||
|
|
||||||
def add_to_buffer(function_buffer, _response_data, partial)
|
def maybe_has_tool?(_partial_raw)
|
||||||
if partial[:name].present?
|
# we always get a full partial
|
||||||
function_buffer.at("tool_name").content = partial[:name]
|
false
|
||||||
function_buffer.at("tool_id").content = partial[:name]
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def add_to_function_buffer(function_buffer, payload: nil, partial: nil)
|
||||||
|
if @streaming_mode
|
||||||
|
return function_buffer if !partial
|
||||||
|
else
|
||||||
|
partial = payload
|
||||||
|
end
|
||||||
|
|
||||||
|
function_buffer.at("tool_name").content = partial[:name] if partial[:name].present?
|
||||||
|
|
||||||
if partial[:args]
|
if partial[:args]
|
||||||
argument_fragments =
|
argument_fragments =
|
||||||
partial[:args].reduce(+"") do |memo, (arg_name, value)|
|
partial[:args].reduce(+"") do |memo, (arg_name, value)|
|
||||||
|
|
|
@ -163,7 +163,18 @@ module DiscourseAi
|
||||||
@has_function_call
|
@has_function_call
|
||||||
end
|
end
|
||||||
|
|
||||||
def add_to_buffer(function_buffer, _response_data, partial)
|
def maybe_has_tool?(_partial_raw)
|
||||||
|
# we always get a full partial
|
||||||
|
false
|
||||||
|
end
|
||||||
|
|
||||||
|
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
|
||||||
|
if @streaming_mode
|
||||||
|
return function_buffer if !partial
|
||||||
|
else
|
||||||
|
partial = payload
|
||||||
|
end
|
||||||
|
|
||||||
@args_buffer ||= +""
|
@args_buffer ||= +""
|
||||||
|
|
||||||
f_name = partial.dig(:function, :name)
|
f_name = partial.dig(:function, :name)
|
||||||
|
|
|
@@ -49,6 +49,10 @@ module DiscourseAi
        messages << new_message
      end

+     def has_tools?
+       tools.present?
+     end
+
      private

      def validate_message(message)
@@ -10,6 +10,36 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AnthropicMessages do
    )
  end

+ let(:echo_tool) do
+   {
+     name: "echo",
+     description: "echo something",
+     parameters: [{ name: "text", type: "string", description: "text to echo", required: true }],
+   }
+ end
+
+ let(:google_tool) do
+   {
+     name: "google",
+     description: "google something",
+     parameters: [
+       { name: "query", type: "string", description: "text to google", required: true },
+     ],
+   }
+ end
+
+ let(:prompt_with_echo_tool) do
+   prompt_with_tools = prompt
+   prompt.tools = [echo_tool]
+   prompt_with_tools
+ end
+
+ let(:prompt_with_google_tool) do
+   prompt_with_tools = prompt
+   prompt.tools = [echo_tool]
+   prompt_with_tools
+ end
+
  before { SiteSetting.ai_anthropic_api_key = "123" }

  it "does not eat spaces with tool calls" do
@@ -165,16 +195,18 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AnthropicMessages do
    stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(status: 200, body: body)

    result = +""
-   llm.generate(prompt, user: Discourse.system_user) { |partial| result << partial }
+   llm.generate(prompt_with_google_tool, user: Discourse.system_user) do |partial|
+     result << partial
+   end

    expected = (<<~TEXT).strip
      <function_calls>
      <invoke>
      <tool_name>google</tool_name>
-     <tool_id>google</tool_id>
      <parameters>
      <query>top 10 things to do in japan for tourists</query>
      </parameters>
+     <tool_id>tool_0</tool_id>
      </invoke>
      </function_calls>
    TEXT
@@ -232,7 +264,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AnthropicMessages do
    expected_body = {
      model: "claude-3-opus-20240229",
      max_tokens: 3000,
-     stop_sequences: ["</function_calls>"],
      messages: [{ role: "user", content: "user1: hello" }],
      system: "You are hello bot",
      stream: true,
@@ -245,6 +276,70 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AnthropicMessages do
    expect(log.response_tokens).to eq(15)
  end

+ it "can return multiple function calls" do
+   functions = <<~FUNCTIONS
+     <function_calls>
+     <invoke>
+     <tool_name>echo</tool_name>
+     <parameters>
+     <text>something</text>
+     </parameters>
+     </invoke>
+     <invoke>
+     <tool_name>echo</tool_name>
+     <parameters>
+     <text>something else</text>
+     </parameters>
+     </invoke>
+   FUNCTIONS
+
+   body = <<~STRING
+     {
+       "content": [
+         {
+           "text": "Hello!\n\n#{functions}\njunk",
+           "type": "text"
+         }
+       ],
+       "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
+       "model": "claude-3-opus-20240229",
+       "role": "assistant",
+       "stop_reason": "end_turn",
+       "stop_sequence": null,
+       "type": "message",
+       "usage": {
+         "input_tokens": 10,
+         "output_tokens": 25
+       }
+     }
+   STRING
+
+   stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(status: 200, body: body)
+
+   result = llm.generate(prompt_with_echo_tool, user: Discourse.system_user)
+
+   expected = (<<~EXPECTED).strip
+     <function_calls>
+     <invoke>
+     <tool_name>echo</tool_name>
+     <parameters>
+     <text>something</text>
+     </parameters>
+     <tool_id>tool_0</tool_id>
+     </invoke>
+     <invoke>
+     <tool_name>echo</tool_name>
+     <parameters>
+     <text>something else</text>
+     </parameters>
+     <tool_id>tool_1</tool_id>
+     </invoke>
+     </function_calls>
+   EXPECTED
+
+   expect(result.strip).to eq(expected)
+ end
+
  it "can operate in regular mode" do
    body = <<~STRING
      {
@@ -287,7 +382,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AnthropicMessages do
    expected_body = {
      model: "claude-3-opus-20240229",
      max_tokens: 3000,
-     stop_sequences: ["</function_calls>"],
      messages: [{ role: "user", content: "user1: hello" }],
      system: "You are hello bot",
    }
@@ -66,11 +66,11 @@ class EndpointMock
      <function_calls>
      <invoke>
      <tool_name>get_weather</tool_name>
-     <tool_id>#{tool_id}</tool_id>
      <parameters>
      <location>Sydney</location>
      <unit>c</unit>
      </parameters>
+     <tool_id>tool_0</tool_id>
      </invoke>
      </function_calls>
    TEXT
@@ -132,7 +132,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do

  fab!(:user)

- let(:bedrock_mock) { GeminiMock.new(endpoint) }
+ let(:gemini_mock) { GeminiMock.new(endpoint) }

  let(:compliance) do
    EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Gemini, user)
@@ -142,13 +142,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
  context "when using regular mode" do
    context "with simple prompts" do
      it "completes a trivial prompt and logs the response" do
-       compliance.regular_mode_simple_prompt(bedrock_mock)
+       compliance.regular_mode_simple_prompt(gemini_mock)
      end
    end

    context "with tools" do
      it "returns a function invocation" do
-       compliance.regular_mode_tools(bedrock_mock)
+       compliance.regular_mode_tools(gemini_mock)
      end
    end
  end
@@ -156,13 +156,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
  describe "when using streaming mode" do
    context "with simple prompts" do
      it "completes a trivial prompt and logs the response" do
-       compliance.streaming_mode_simple_prompt(bedrock_mock)
+       compliance.streaming_mode_simple_prompt(gemini_mock)
      end
    end

    context "with tools" do
      it "returns a function invocation" do
-       compliance.streaming_mode_tools(bedrock_mock)
+       compliance.streaming_mode_tools(gemini_mock)
      end
    end
  end
@@ -102,7 +102,7 @@ class OpenAiMock < EndpointMock
  end

  def tool_id
-   "eujbuebfe"
+   "tool_0"
  end

  def tool_payload
@@ -149,6 +149,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do

  fab!(:user)

+ let(:echo_tool) do
+   {
+     name: "echo",
+     description: "echo something",
+     parameters: [{ name: "text", type: "string", description: "text to echo", required: true }],
+   }
+ end
+
+ let(:tools) { [echo_tool] }
+
  let(:open_ai_mock) { OpenAiMock.new(endpoint) }

  let(:compliance) do
@@ -260,23 +270,25 @@ TEXT
      open_ai_mock.stub_raw(raw_data)
      content = +""

-     endpoint.perform_completion!(compliance.dialect, user) { |partial| content << partial }
+     dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))
+
+     endpoint.perform_completion!(dialect, user) { |partial| content << partial }

      expected = <<~TEXT
        <function_calls>
        <invoke>
        <tool_name>search</tool_name>
-       <tool_id>call_3Gyr3HylFJwfrtKrL6NaIit1</tool_id>
        <parameters>
        <search_query>Discourse AI bot</search_query>
        </parameters>
+       <tool_id>call_3Gyr3HylFJwfrtKrL6NaIit1</tool_id>
        </invoke>
        <invoke>
        <tool_name>search</tool_name>
-       <tool_id>call_H7YkbgYurHpyJqzwUN4bghwN</tool_id>
        <parameters>
        <query>Discourse AI bot</query>
        </parameters>
+       <tool_id>call_H7YkbgYurHpyJqzwUN4bghwN</tool_id>
        </invoke>
        </function_calls>
      TEXT
@@ -321,7 +333,8 @@ TEXT
      open_ai_mock.stub_raw(chunks)
      partials = []

-     endpoint.perform_completion!(compliance.dialect, user) { |partial| partials << partial }
+     dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))
+     endpoint.perform_completion!(dialect, user) { |partial| partials << partial }

      expect(partials.length).to eq(1)
@@ -329,10 +342,10 @@ TEXT
        <function_calls>
        <invoke>
        <tool_name>google</tool_name>
-       <tool_id>func_id</tool_id>
        <parameters>
        <query>Adabas 9.1</query>
        </parameters>
+       <tool_id>func_id</tool_id>
        </invoke>
        </function_calls>
      TXT
@@ -153,6 +153,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
        DiscourseAi::AiBot::Personas::Artist,
        DiscourseAi::AiBot::Personas::Creative,
        DiscourseAi::AiBot::Personas::DiscourseHelper,
+       DiscourseAi::AiBot::Personas::GithubHelper,
        DiscourseAi::AiBot::Personas::Researcher,
        DiscourseAi::AiBot::Personas::SettingsExplorer,
        DiscourseAi::AiBot::Personas::SqlHelper,
@@ -169,6 +170,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
        DiscourseAi::AiBot::Personas::SettingsExplorer,
        DiscourseAi::AiBot::Personas::Creative,
        DiscourseAi::AiBot::Personas::DiscourseHelper,
+       DiscourseAi::AiBot::Personas::GithubHelper,
      )

      AiPersona.find(
@@ -182,6 +184,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
        DiscourseAi::AiBot::Personas::SettingsExplorer,
        DiscourseAi::AiBot::Personas::Creative,
        DiscourseAi::AiBot::Personas::DiscourseHelper,
+       DiscourseAi::AiBot::Personas::GithubHelper,
      )
    end
  end
@@ -337,6 +337,15 @@ RSpec.describe DiscourseAi::AiBot::Playground do
    end
  end

+ describe "#available_bot_usernames" do
+   it "includes persona users" do
+     persona = Fabricate(:ai_persona)
+     persona.create_user!
+
+     expect(playground.available_bot_usernames).to include(persona.user.username)
+   end
+ end
+
  describe "#conversation_context" do
    context "with limited context" do
      before do
@@ -0,0 +1,78 @@
+# frozen_string_literal: true
+
+require "rails_helper"
+
+RSpec.describe DiscourseAi::AiBot::Tools::GithubFileContent do
+  let(:tool) do
+    described_class.new(
+      {
+        repo_name: "discourse/discourse-ai",
+        file_paths: %w[lib/database/connection.rb lib/ai_bot/tools/github_pull_request_diff.rb],
+        branch: "8b382d6098fde879d28bbee68d3cbe0a193e4ffc",
+      },
+    )
+  end
+
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }
+
+  describe "#invoke" do
+    before do
+      stub_request(
+        :get,
+        "https://api.github.com/repos/discourse/discourse-ai/contents/lib/database/connection.rb?ref=8b382d6098fde879d28bbee68d3cbe0a193e4ffc",
+      ).to_return(
+        status: 200,
+        body: { content: Base64.encode64("content of connection.rb") }.to_json,
+      )
+
+      stub_request(
+        :get,
+        "https://api.github.com/repos/discourse/discourse-ai/contents/lib/ai_bot/tools/github_pull_request_diff.rb?ref=8b382d6098fde879d28bbee68d3cbe0a193e4ffc",
+      ).to_return(
+        status: 200,
+        body: { content: Base64.encode64("content of github_pull_request_diff.rb") }.to_json,
+      )
+    end
+
+    it "retrieves the content of the specified GitHub files" do
+      result = tool.invoke(nil, llm)
+      expected = {
+        file_contents:
+          "File Path: lib/database/connection.rb:\ncontent of connection.rb\nFile Path: lib/ai_bot/tools/github_pull_request_diff.rb:\ncontent of github_pull_request_diff.rb",
+      }
+
+      expect(result).to eq(expected)
+    end
+  end
+
+  describe ".signature" do
+    it "returns the tool signature" do
+      signature = described_class.signature
+      expect(signature[:name]).to eq("github_file_content")
+      expect(signature[:description]).to eq("Retrieves the content of specified GitHub files")
+      expect(signature[:parameters]).to eq(
+        [
+          {
+            name: "repo_name",
+            description: "The name of the GitHub repository (e.g., 'discourse/discourse')",
+            type: "string",
+            required: true,
+          },
+          {
+            name: "file_paths",
+            description: "The paths of the files to retrieve within the repository",
+            type: "array",
+            item_type: "string",
+            required: true,
+          },
+          {
+            name: "branch",
+            description: "The branch or commit SHA to retrieve the files from (default: 'main')",
+            type: "string",
+            required: false,
+          },
+        ],
+      )
+    end
+  end
+end
@@ -0,0 +1,61 @@
+# frozen_string_literal: true
+
+require "rails_helper"
+
+RSpec.describe DiscourseAi::AiBot::Tools::GithubPullRequestDiff do
+  let(:tool) { described_class.new({ repo: repo, pull_id: pull_id }) }
+  let(:bot_user) { Fabricate(:user) }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }
+
+  context "with a valid pull request" do
+    let(:repo) { "discourse/discourse-automation" }
+    let(:pull_id) { 253 }
+
+    it "retrieves the diff for the pull request" do
+      stub_request(:get, "https://api.github.com/repos/#{repo}/pulls/#{pull_id}").with(
+        headers: {
+          "Accept" => "application/vnd.github.v3.diff",
+          "User-Agent" => DiscourseAi::AiBot::USER_AGENT,
+        },
+      ).to_return(status: 200, body: "sample diff")
+
+      result = tool.invoke(bot_user, llm)
+      expect(result[:diff]).to eq("sample diff")
+      expect(result[:error]).to be_nil
+    end
+
+    it "uses the github access token if present" do
+      SiteSetting.ai_bot_github_access_token = "ABC"
+
+      stub_request(:get, "https://api.github.com/repos/#{repo}/pulls/#{pull_id}").with(
+        headers: {
+          "Accept" => "application/vnd.github.v3.diff",
+          "User-Agent" => DiscourseAi::AiBot::USER_AGENT,
+          "Authorization" => "Bearer ABC",
+        },
+      ).to_return(status: 200, body: "sample diff")
+
+      result = tool.invoke(bot_user, llm)
+      expect(result[:diff]).to eq("sample diff")
+      expect(result[:error]).to be_nil
+    end
+  end
+
+  context "with an invalid pull request" do
+    let(:repo) { "invalid/repo" }
+    let(:pull_id) { 999 }
+
+    it "returns an error message" do
+      stub_request(:get, "https://api.github.com/repos/#{repo}/pulls/#{pull_id}").with(
+        headers: {
+          "Accept" => "application/vnd.github.v3.diff",
+          "User-Agent" => DiscourseAi::AiBot::USER_AGENT,
+        },
+      ).to_return(status: 404)
+
+      result = tool.invoke(bot_user, nil)
+      expect(result[:diff]).to be_nil
+      expect(result[:error]).to include("Failed to retrieve the diff")
+    end
+  end
+end
@@ -0,0 +1,93 @@
+# frozen_string_literal: true
+
+require "rails_helper"
+
+RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchCode do
+  let(:tool) { described_class.new({ repo: repo, query: query }) }
+  let(:bot_user) { Fabricate(:user) }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }
+
+  context "with valid search results" do
+    let(:repo) { "discourse/discourse" }
+    let(:query) { "def hello" }
+
+    it "searches for code in the repository" do
+      stub_request(
+        :get,
+        "https://api.github.com/search/code?q=def%20hello+repo:discourse/discourse",
+      ).with(
+        headers: {
+          "Accept" => "application/vnd.github.v3.text-match+json",
+          "User-Agent" => DiscourseAi::AiBot::USER_AGENT,
+        },
+      ).to_return(
+        status: 200,
+        body: {
+          total_count: 1,
+          items: [
+            {
+              path: "test/hello.rb",
+              name: "hello.rb",
+              text_matches: [{ fragment: "def hello\n puts 'hello'\nend" }],
+            },
+          ],
+        }.to_json,
+      )
+
+      result = tool.invoke(bot_user, llm)
+      expect(result[:search_results]).to include("def hello\n puts 'hello'\nend")
+      expect(result[:search_results]).to include("test/hello.rb")
+      expect(result[:error]).to be_nil
+    end
+  end
+
+  context "with an empty search result" do
+    let(:repo) { "discourse/discourse" }
+    let(:query) { "nonexistent_method" }
+
+    describe "#description_args" do
+      it "returns the repo and query" do
+        expect(tool.description_args).to eq(repo: repo, query: query)
+      end
+    end
+
+    it "returns an empty result" do
+      SiteSetting.ai_bot_github_access_token = "ABC"
+      stub_request(
+        :get,
+        "https://api.github.com/search/code?q=nonexistent_method+repo:discourse/discourse",
+      ).with(
+        headers: {
+          "Accept" => "application/vnd.github.v3.text-match+json",
+          "User-Agent" => DiscourseAi::AiBot::USER_AGENT,
+          "Authorization" => "Bearer ABC",
+        },
+      ).to_return(status: 200, body: { total_count: 0, items: [] }.to_json)
+
+      result = tool.invoke(bot_user, llm)
+      expect(result[:search_results]).to be_empty
+      expect(result[:error]).to be_nil
+    end
+  end
+
+  context "with an error response" do
+    let(:repo) { "discourse/discourse" }
+    let(:query) { "def hello" }
+
+    it "returns an error message" do
+      stub_request(
+        :get,
+        "https://api.github.com/search/code?q=def%20hello+repo:discourse/discourse",
+      ).with(
+        headers: {
+          "Accept" => "application/vnd.github.v3.text-match+json",
+          "User-Agent" => DiscourseAi::AiBot::USER_AGENT,
+        },
+      ).to_return(status: 403)
+
+      result = tool.invoke(bot_user, llm)
+      expect(result[:search_results]).to be_nil
+      expect(result[:error]).to include("Failed to perform code search")
+    end
+  end
+end