From 79638c2f5055f1833e58d4c059b701495bafc1df Mon Sep 17 00:00:00 2001 From: Sam Date: Sat, 9 Mar 2024 08:46:40 +1100 Subject: [PATCH] FIX: Tune function calling (#519) Adds support for "name" on functions, which can be used for tool calls. For function calls we need to keep track of both id and name; previously we only supported one or the other. Also attempts to improve the SQL helper. --- lib/ai_bot/bot.rb | 14 +++-- lib/ai_bot/personas/persona.rb | 4 +- lib/ai_bot/personas/sql_helper.rb | 51 ++++++++++++++----- lib/ai_bot/playground.rb | 1 + lib/completions/dialects/chat_gpt.rb | 3 +- lib/completions/dialects/claude.rb | 20 ++------ lib/completions/dialects/dialect.rb | 13 +++-- lib/completions/dialects/gemini.rb | 4 +- lib/completions/dialects/mixtral.rb | 19 ++----- lib/completions/dialects/orca_style.rb | 21 +++----- lib/completions/prompt.rb | 5 +- .../lib/completions/dialects/chat_gpt_spec.rb | 7 ++- spec/lib/completions/dialects/claude_spec.rb | 14 ++++- .../completions/dialects/dialect_context.rb | 5 +- spec/lib/completions/dialects/gemini_spec.rb | 2 +- spec/lib/completions/dialects/mixtral_spec.rb | 11 +++- .../completions/dialects/orca_style_spec.rb | 12 ++++- .../modules/ai_bot/personas/persona_spec.rb | 10 +++- 18 files changed, 134 insertions(+), 82 deletions(-) diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb index 4b969a0b..71702469 100644 --- a/lib/ai_bot/bot.rb +++ b/lib/ai_bot/bot.rb @@ -105,10 +105,16 @@ module DiscourseAi tool_call_message = { type: :tool_call, id: tool_call_id, - content: { name: tool.name, arguments: tool.parameters }.to_json, + content: { arguments: tool.parameters }.to_json, + name: tool.name, } - tool_message = { type: :tool, id: tool_call_id, content: invocation_result_json } + tool_message = { + type: :tool, + id: tool_call_id, + content: invocation_result_json, + name: tool.name, + } if tool.standalone? 
standalone_context = @@ -125,8 +131,8 @@ module DiscourseAi prompt.push(**tool_message) end - raw_context << [tool_call_message[:content], tool_call_id, "tool_call"] - raw_context << [invocation_result_json, tool_call_id, "tool"] + raw_context << [tool_call_message[:content], tool_call_id, "tool_call", tool.name] + raw_context << [invocation_result_json, tool_call_id, "tool", tool.name] end def invoke_tool(tool, llm, cancel, &update_blk) diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb index f29cc6b3..e568a24d 100644 --- a/lib/ai_bot/personas/persona.rb +++ b/lib/ai_bot/personas/persona.rb @@ -143,10 +143,10 @@ module DiscourseAi def find_tool(parsed_function) function_id = parsed_function.at("tool_id")&.text function_name = parsed_function.at("tool_name")&.text - return false if function_name.nil? + return nil if function_name.nil? tool_klass = available_tools.find { |c| c.signature.dig(:name) == function_name } - return false if tool_klass.nil? + return nil if tool_klass.nil? 
arguments = {} tool_klass.signature[:parameters].to_a.each do |param| diff --git a/lib/ai_bot/personas/sql_helper.rb b/lib/ai_bot/personas/sql_helper.rb index 6afd76a8..a31dc191 100644 --- a/lib/ai_bot/personas/sql_helper.rb +++ b/lib/ai_bot/personas/sql_helper.rb @@ -8,7 +8,16 @@ module DiscourseAi return @schema if defined?(@schema) tables = Hash.new - priority_tables = %w[posts topics notifications users user_actions user_emails] + priority_tables = %w[ + posts + topics + notifications + users + user_actions + user_emails + categories + groups + ] DB.query(<<~SQL).each { |row| (tables[row.table_name] ||= []) << row.column_name } select table_name, column_name from information_schema.columns @@ -16,15 +25,16 @@ module DiscourseAi order by table_name SQL - schema = +(priority_tables.map { |name| "#{name}(#{tables[name].join(",")})" }.join("\n")) + priority = + +(priority_tables.map { |name| "#{name}(#{tables[name].join(",")})" }.join("\n")) - schema << "\nOther tables (schema redacted, available on request): " + other_tables = +"" tables.each do |table_name, _| next if priority_tables.include?(table_name) - schema << "#{table_name} " + other_tables << "#{table_name} " end - @schema = schema + @schema = { priority_tables: priority, other_tables: other_tables } end def tools @@ -38,12 +48,15 @@ module DiscourseAi def system_prompt <<~PROMPT You are a PostgreSQL expert. + - Avoid returning any text to the user prior to a tool call. - You understand and generate Discourse Markdown but specialize in creating queries. - You live in a Discourse Forum Message. - - The schema in your training set MAY be out of date. + - Format SQL for maximum readability. Use line breaks, indentation, and spaces around operators. Add comments if needed to explain complex logic. + - Never warn or inform end user you are going to look up schema. + - Always try to get ALL the schema you need in the least tool calls. 
+ - Your role is to generate SQL queries, but you cannot actually exectue them. + - When generating SQL always use ```sql Markdown code blocks. - When generating SQL NEVER end SQL samples with a semicolon (;). - - When generating SQL always use ```sql markdown code blocks. - - Always format SQL in a highly readable format. Eg: @@ -52,17 +65,29 @@ module DiscourseAi ``` The user_actions tables stores likes (action_type 1). - the topics table stores private/personal messages it uses archetype private_message for them. + The topics table stores private/personal messages it uses archetype private_message for them. notification_level can be: {muted: 0, regular: 1, tracking: 2, watching: 3, watching_first_post: 4}. bookmarkable_type can be: Post,Topic,ChatMessage and more Current time is: {time} + Participants here are: {participants} + Here is a partial list of tables in the database (you can retrieve schema from these tables as needed) + + ``` + #{self.class.schema[:other_tables]} + ``` + + You may look up schema for the tables listed above. + + Here is full information on priority tables: + + ``` + #{self.class.schema[:priority_tables]} + ``` + + NEVER look up schema for the tables listed above, as their full schema is already provided. 
- The current schema for the current DB is: - {{ - #{self.class.schema} - }} PROMPT end end diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb index e546f328..421033ce 100644 --- a/lib/ai_bot/playground.rb +++ b/lib/ai_bot/playground.rb @@ -122,6 +122,7 @@ module DiscourseAi } custom_context[:id] = message[1] if custom_context[:type] != :model + custom_context[:name] = message[3] if message[3] result << custom_context end diff --git a/lib/completions/dialects/chat_gpt.rb b/lib/completions/dialects/chat_gpt.rb index b98c9461..6cd9f3b8 100644 --- a/lib/completions/dialects/chat_gpt.rb +++ b/lib/completions/dialects/chat_gpt.rb @@ -48,6 +48,7 @@ module DiscourseAi elsif msg[:type] == :tool_call call_details = JSON.parse(msg[:content], symbolize_names: true) call_details[:arguments] = call_details[:arguments].to_json + call_details[:name] = msg[:name] { role: "assistant", @@ -55,7 +56,7 @@ module DiscourseAi tool_calls: [{ type: "function", function: call_details, id: msg[:id] }], } elsif msg[:type] == :tool - { role: "tool", tool_call_id: msg[:id], content: msg[:content] } + { role: "tool", tool_call_id: msg[:id], content: msg[:content], name: msg[:name] } else user_message = { role: "user", content: msg[:content] } if msg[:id] diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb index f05739c0..6043a042 100644 --- a/lib/completions/dialects/claude.rb +++ b/lib/completions/dialects/claude.rb @@ -24,9 +24,9 @@ module DiscourseAi claude_prompt = trimmed_messages.reduce(+"") do |memo, msg| - next(memo) if msg[:type] == :tool_call - - if msg[:type] == :system + if msg[:type] == :tool_call + memo << "\n\nAssistant: #{tool_call_to_xml(msg)}" + elsif msg[:type] == :system memo << "Human: " unless uses_system_message? memo << msg[:content] if prompt.tools.present? 
@@ -36,18 +36,8 @@ module DiscourseAi elsif msg[:type] == :model memo << "\n\nAssistant: #{msg[:content]}" elsif msg[:type] == :tool - memo << "\n\nAssistant:\n" - - memo << (<<~TEXT).strip - - - #{msg[:id]} - - #{msg[:content]} - - - - TEXT + memo << "\n\nHuman:\n" + memo << tool_result_to_xml(msg) else memo << "\n\nHuman: " memo << "#{msg[:id]}: " if msg[:id] diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb index a9d6a957..fb7b48da 100644 --- a/lib/completions/dialects/dialect.rb +++ b/lib/completions/dialects/dialect.rb @@ -51,10 +51,13 @@ module DiscourseAi If a parameter type is an array, return a JSON array of values. For example: [1,"two",3.0] - Always wrap calls in tags. - You may call multiple function via in a single block. + If you wish to call multiple function in one reply, wrap multiple + block in a single block. - Here are the tools available: + Always prefer to lead with tool calls, if you need to execute any. + Avoid all niceties prior to tool calls, Eg: "Let me look this up for you.." etc. 
+ + Here are the complete list of tools available: TEXT end end @@ -73,7 +76,7 @@ module DiscourseAi (<<~TEXT).strip - #{message[:id]} + #{message[:name] || message[:id]} #{message[:content]} @@ -95,7 +98,7 @@ module DiscourseAi (<<~TEXT).strip - #{parsed[:name]} + #{message[:name] || parsed[:name]} #{parameters} TEXT diff --git a/lib/completions/dialects/gemini.rb b/lib/completions/dialects/gemini.rb index 4987df54..c052a5d2 100644 --- a/lib/completions/dialects/gemini.rb +++ b/lib/completions/dialects/gemini.rb @@ -38,7 +38,7 @@ module DiscourseAi role: "model", parts: { functionCall: { - name: call_details[:name], + name: msg[:name] || call_details[:name], args: call_details[:arguments], }, }, @@ -48,7 +48,7 @@ module DiscourseAi role: "function", parts: { functionResponse: { - name: msg[:id], + name: msg[:name] || msg[:id], response: { content: msg[:content], }, diff --git a/lib/completions/dialects/mixtral.rb b/lib/completions/dialects/mixtral.rb index fd529b24..425d741e 100644 --- a/lib/completions/dialects/mixtral.rb +++ b/lib/completions/dialects/mixtral.rb @@ -21,9 +21,10 @@ module DiscourseAi mixtral_prompt = trim_messages(messages).reduce(+"") do |memo, msg| - next(memo) if msg[:type] == :tool_call - - if msg[:type] == :system + if msg[:type] == :tool_call + memo << "\n" + memo << tool_call_to_xml(msg) + elsif msg[:type] == :system memo << (<<~TEXT).strip [INST] #{msg[:content]} @@ -34,17 +35,7 @@ module DiscourseAi memo << "\n#{msg[:content]}" elsif msg[:type] == :tool memo << "\n" - - memo << (<<~TEXT).strip - - - #{msg[:id]} - - #{msg[:content]} - - - - TEXT + memo << tool_result_to_xml(msg) else memo << "\n[INST]#{msg[:content]}[/INST]" end diff --git a/lib/completions/dialects/orca_style.rb b/lib/completions/dialects/orca_style.rb index a8c9e939..4ec42a36 100644 --- a/lib/completions/dialects/orca_style.rb +++ b/lib/completions/dialects/orca_style.rb @@ -23,9 +23,10 @@ module DiscourseAi llama2_prompt = trimmed_messages.reduce(+"") do |memo, msg| - 
next(memo) if msg[:type] == :tool_call - - if msg[:type] == :system + if msg[:type] == :tool_call + memo << "\n### Assistant:\n" + memo << tool_call_to_xml(msg) + elsif msg[:type] == :system memo << (<<~TEXT).strip ### System: #{msg[:content]} @@ -34,18 +35,8 @@ module DiscourseAi elsif msg[:type] == :model memo << "\n### Assistant:\n#{msg[:content]}" elsif msg[:type] == :tool - memo << "\n### Assistant:\n" - - memo << (<<~TEXT).strip - - - #{msg[:id]} - - #{msg[:content]} - - - - TEXT + memo << "\n### User:\n" + memo << tool_result_to_xml(msg) else memo << "\n### User:\n#{msg[:content]}" end diff --git a/lib/completions/prompt.rb b/lib/completions/prompt.rb index 959b3d14..9e29384e 100644 --- a/lib/completions/prompt.rb +++ b/lib/completions/prompt.rb @@ -38,9 +38,10 @@ module DiscourseAi @tools = tools end - def push(type:, content:, id: nil) + def push(type:, content:, id: nil, name: nil) return if type == :system new_message = { type: type, content: content } + new_message[:name] = name.to_s if name new_message[:id] = id.to_s if id validate_message(new_message) @@ -62,7 +63,7 @@ module DiscourseAi raise ArgumentError, "message type must be one of #{valid_types}" end - valid_keys = %i[type content id] + valid_keys = %i[type content id name] if (invalid_keys = message.keys - valid_keys).any? 
raise ArgumentError, "message contains invalid keys: #{invalid_keys}" end diff --git a/spec/lib/completions/dialects/chat_gpt_spec.rb b/spec/lib/completions/dialects/chat_gpt_spec.rb index c1714d75..ff1137e6 100644 --- a/spec/lib/completions/dialects/chat_gpt_spec.rb +++ b/spec/lib/completions/dialects/chat_gpt_spec.rb @@ -62,7 +62,12 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do }, ], }, - { role: "tool", content: "I'm a tool result".to_json, tool_call_id: "tool_id" }, + { + role: "tool", + content: "I'm a tool result".to_json, + tool_call_id: "tool_id", + name: "get_weather", + }, ], ) end diff --git a/spec/lib/completions/dialects/claude_spec.rb b/spec/lib/completions/dialects/claude_spec.rb index 5c4619f0..efb62bb3 100644 --- a/spec/lib/completions/dialects/claude_spec.rb +++ b/spec/lib/completions/dialects/claude_spec.rb @@ -37,10 +37,20 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do Human: user1: This is a new message by a user - Assistant: + Assistant: + + get_weather + + Sydney + c + + + + + Human: - tool_id + get_weather "I'm a tool result" diff --git a/spec/lib/completions/dialects/dialect_context.rb b/spec/lib/completions/dialects/dialect_context.rb index f97eddb0..a1a4a221 100644 --- a/spec/lib/completions/dialects/dialect_context.rb +++ b/spec/lib/completions/dialects/dialect_context.rb @@ -51,9 +51,10 @@ class DialectContext { type: :tool_call, id: "tool_id", - content: { name: "get_weather", arguments: { location: "Sydney", unit: "c" } }.to_json, + name: "get_weather", + content: { arguments: { location: "Sydney", unit: "c" } }.to_json, }, - { type: :tool, id: "tool_id", content: "I'm a tool result".to_json }, + { type: :tool, id: "tool_id", name: "get_weather", content: "I'm a tool result".to_json }, ] a_prompt = prompt diff --git a/spec/lib/completions/dialects/gemini_spec.rb b/spec/lib/completions/dialects/gemini_spec.rb index 534de8f4..338eaadd 100644 --- a/spec/lib/completions/dialects/gemini_spec.rb +++ 
b/spec/lib/completions/dialects/gemini_spec.rb @@ -72,7 +72,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do role: "function", parts: { functionResponse: { - name: "tool_id", + name: "get_weather", response: { content: "I'm a tool result".to_json, }, diff --git a/spec/lib/completions/dialects/mixtral_spec.rb b/spec/lib/completions/dialects/mixtral_spec.rb index 528b2ac9..533d3954 100644 --- a/spec/lib/completions/dialects/mixtral_spec.rb +++ b/spec/lib/completions/dialects/mixtral_spec.rb @@ -34,9 +34,18 @@ RSpec.describe DiscourseAi::Completions::Dialects::Mixtral do [INST]This is a message by a user[/INST] I'm a previous bot reply, that's why there's no user [INST]This is a new message by a user[/INST] + + + get_weather + + Sydney + c + + + - tool_id + get_weather "I'm a tool result" diff --git a/spec/lib/completions/dialects/orca_style_spec.rb b/spec/lib/completions/dialects/orca_style_spec.rb index 6c683505..5513f184 100644 --- a/spec/lib/completions/dialects/orca_style_spec.rb +++ b/spec/lib/completions/dialects/orca_style_spec.rb @@ -38,9 +38,19 @@ RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do ### User: This is a new message by a user ### Assistant: + + + get_weather + + Sydney + c + + + + ### User: - tool_id + get_weather "I'm a tool result" diff --git a/spec/lib/modules/ai_bot/personas/persona_spec.rb b/spec/lib/modules/ai_bot/personas/persona_spec.rb index a60a6b16..dcbb32c0 100644 --- a/spec/lib/modules/ai_bot/personas/persona_spec.rb +++ b/spec/lib/modules/ai_bot/personas/persona_spec.rb @@ -89,11 +89,19 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do ["pic3"] + + unknown + abc + + ["pic3"] + + XML - dall_e1, dall_e2 = DiscourseAi::AiBot::Personas::DallE3.new.find_tools(xml) + dall_e1, dall_e2 = tools = DiscourseAi::AiBot::Personas::DallE3.new.find_tools(xml) expect(dall_e1.parameters[:prompts]).to eq(["cat oil painting", "big car"]) expect(dall_e2.parameters[:prompts]).to eq(["pic3"]) + 
expect(tools.length).to eq(2) end describe "custom personas" do