FEATURE: Multi-model support for the AI Bot module. (#56)
We create one bot user for each available model; the bots listed in the `ai_bot_enabled_chat_bots` setting will reply. This PR also lets us use Claude-v1 in stream mode.
Contained in: parent e5537d4c77, commit 7e3cb0ea16
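
In outline: each model gets its own bot user, and the bot user's id picks the model-specific subclass of the new `Bot` base class. A minimal usage sketch assembled from the classes this diff introduces (`post` is a stand-in for any post in a PM the bot can see):

    # Sketch only; class and constant names come from the diff below.
    bot_user = User.find(DiscourseAi::AiBot::EntryPoint::CLAUDE_V1_ID) # id -112
    bot = DiscourseAi::AiBot::Bot.as(bot_user) # => DiscourseAi::AiBot::AnthropicBot
    bot.reply_to(post) # streams the completion into a reply, updating it as deltas arrive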
@@ -3,5 +3,21 @@
 class AiApiAuditLog < ActiveRecord::Base
   module Provider
     OpenAI = 1
+    Anthropic = 2
   end
 end
+
+# == Schema Information
+#
+# Table name: ai_api_audit_logs
+#
+#  id                   :bigint           not null, primary key
+#  provider_id          :integer          not null
+#  user_id              :integer
+#  request_tokens       :integer
+#  response_tokens      :integer
+#  raw_request_payload  :string
+#  raw_response_payload :string
+#  created_at           :datetime         not null
+#  updated_at           :datetime         not null
+#
@@ -4,52 +4,11 @@ class CompletionPrompt < ActiveRecord::Base
   # TODO(roman): Remove sept 2023.
   self.ignored_columns = ["value"]
 
-  # GPT 3.5 allows 4000 tokens
-  MAX_PROMPT_TOKENS = 3500
-
   enum :prompt_type, { text: 0, list: 1, diff: 2 }
 
   validates :messages, length: { maximum: 20 }
   validate :each_message_length
 
-  def self.bot_prompt_with_topic_context(post)
-    messages = []
-    conversation =
-      post
-        .topic
-        .posts
-        .includes(:user)
-        .where("post_number <= ?", post.post_number)
-        .order("post_number desc")
-        .pluck(:raw, :username)
-
-    total_prompt_tokens = 0
-    messages =
-      conversation.reduce([]) do |memo, (raw, username)|
-        break(memo) if total_prompt_tokens >= MAX_PROMPT_TOKENS
-
-        tokens = DiscourseAi::Tokenizer.tokenize(raw)
-
-        if tokens.length + total_prompt_tokens > MAX_PROMPT_TOKENS
-          tokens = tokens[0...(MAX_PROMPT_TOKENS - total_prompt_tokens)]
-          raw = tokens.join(" ")
-        end
-
-        total_prompt_tokens += tokens.length
-        role = username == Discourse.gpt_bot.username ? "system" : "user"
-
-        memo.unshift({ role: role, content: raw })
-      end
-
-    messages.unshift({ role: "system", content: <<~TEXT })
-      You are gpt-bot. You answer questions and generate text.
-      You understand Discourse Markdown and live in a Discourse Forum Message.
-      You are provided you with context of previous discussions.
-    TEXT
-
-    messages
-  end
-
   def messages_with_user_input(user_input)
     if ::DiscourseAi::AiHelper::LlmPrompt.new.enabled_provider == "openai"
       self.messages << { role: "user", content: user_input }
@@ -5,7 +5,7 @@ import { popupAjaxError } from "discourse/lib/ajax-error";
 import loadScript from "discourse/lib/load-script";
 
 function isGPTBot(user) {
-  return user && user.id === -110;
+  return user && [-110, -111, -112].includes(user.id);
 }
 
 function initializeAIBotReplies(api) {
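
Note: the hardcoded ids -110, -111 and -112 mirror the `GPT4_ID`, `GPT3_5_TURBO_ID` and `CLAUDE_V1_ID` constants added to `EntryPoint` further down this diff.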
@@ -74,7 +74,7 @@ function initializeAIBotReplies(api) {
       if (
         this.model.isPrivateMessage &&
         this.model.details.allowed_users &&
-        this.model.details.allowed_users.filter(isGPTBot).length === 1
+        this.model.details.allowed_users.filter(isGPTBot).length >= 1
       ) {
         this.messageBus.subscribe(
           `discourse-ai/ai-bot/topic/${this.model.id}`,
@@ -83,7 +83,7 @@ function initializeAIBotReplies(api) {
       }
     },
     unsubscribe: function () {
-      this.messageBus.unsubscribe("discourse-ai/ai-bot/topic/");
+      this.messageBus.unsubscribe("discourse-ai/ai-bot/topic/*");
       this._super();
     },
   });
@@ -32,6 +32,7 @@ en:
     ai_nsfw_models: "Models to use for NSFW inference."
 
     ai_openai_api_key: "API key for OpenAI API"
+    ai_anthropic_api_key: "API key for Anthropic API"
 
     composer_ai_helper_enabled: "Enable the Composer's AI helper."
     ai_helper_allowed_groups: "Users on these groups will see the AI helper button in the composer."
@@ -58,6 +59,7 @@ en:
 
     ai_bot_enabled: "Enable the AI Bot module."
     ai_bot_allowed_groups: "When the GPT Bot has access to the PM, it will reply to members of these groups."
+    ai_bot_enabled_chat_bots: "Available models to act as an AI Bot"
 
 
   reviewables:
@@ -180,3 +180,12 @@ plugins:
     type: group_list
     list_type: compact
     default: "3|14" # 3: @staff, 14: @trust_level_4
+  # Adding a new bot? Make sure to create a user for it on the seed file.
+  ai_bot_enabled_chat_bots:
+    type: list
+    default: "gpt-3.5-turbo"
+    client: true
+    choices:
+      - gpt-3.5-turbo
+      - gpt-4
+      - claude-v1
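
Because this is a `list` setting with `client: true`, the enabled bot names are available on both server and client. A hedged sketch of reading it server-side (Discourse list settings serialize as pipe-delimited strings; this accessor pattern is an assumption, not part of the diff):

    # Assumes the standard Discourse list-setting idiom; not shown in this commit.
    enabled_bots = SiteSetting.ai_bot_enabled_chat_bots.split("|")
    enabled_bots.include?("claude-v1") # => true once an admin enables the Claude bot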
@@ -1,20 +1,22 @@
 # frozen_string_literal: true
 
-UserEmail.seed do |ue|
-  ue.id = -110
-  ue.email = "no_email_gpt_bot"
-  ue.primary = true
-  ue.user_id = -110
-end
+DiscourseAi::AiBot::EntryPoint::BOTS.each do |id, bot_username|
+  UserEmail.seed do |ue|
+    ue.id = id
+    ue.email = "no_email_#{bot_username}"
+    ue.primary = true
+    ue.user_id = id
+  end
 
-User.seed do |u|
-  u.id = -110
-  u.name = "GPT Bot"
-  u.username = UserNameSuggester.suggest("gpt_bot")
-  u.password = SecureRandom.hex
-  u.active = true
-  u.admin = true
-  u.moderator = true
-  u.approved = true
-  u.trust_level = TrustLevel[4]
-end
+  User.seed do |u|
+    u.id = id
+    u.name = bot_username.titleize
+    u.username = UserNameSuggester.suggest(bot_username)
+    u.password = SecureRandom.hex
+    u.active = true
+    u.admin = true
+    u.moderator = true
+    u.approved = true
+    u.trust_level = TrustLevel[4]
+  end
+end
lib/modules/ai_bot/anthropic_bot.rb (new file, 45 lines)

# frozen_string_literal: true

module DiscourseAi
  module AiBot
    class AnthropicBot < Bot
      def self.can_reply_as?(bot_user)
        bot_user.id == DiscourseAi::AiBot::EntryPoint::CLAUDE_V1_ID
      end

      def bot_prompt_with_topic_context(post)
        super(post).join("\n\n")
      end

      def prompt_limit
        7500 # https://console.anthropic.com/docs/prompt-design#what-is-a-prompt
      end

      private

      def build_message(poster_username, content)
        role = poster_username == bot_user.username ? "Assistant" : "Human"

        "#{role}: #{content}"
      end

      def model_for
        "claude-v1"
      end

      def update_with_delta(_, partial)
        partial[:completion]
      end

      def submit_prompt_and_stream_reply(prompt, &blk)
        DiscourseAi::Inference::AnthropicCompletions.perform!(
          prompt,
          model_for,
          temperature: 0.4,
          max_tokens: 3000,
          &blk
        )
      end
    end
  end
end
lib/modules/ai_bot/bot.rb (new file, 144 lines)

# frozen_string_literal: true

module DiscourseAi
  module AiBot
    class Bot
      BOT_NOT_FOUND = Class.new(StandardError)

      def self.as(bot_user)
        available_bots = [DiscourseAi::AiBot::OpenAiBot, DiscourseAi::AiBot::AnthropicBot]

        bot =
          available_bots.detect(-> { raise BOT_NOT_FOUND }) do |bot_klass|
            bot_klass.can_reply_as?(bot_user)
          end

        bot.new(bot_user)
      end

      def initialize(bot_user)
        @bot_user = bot_user
      end

      def reply_to(post)
        prompt = bot_prompt_with_topic_context(post)

        redis_stream_key = nil
        reply = +""
        bot_reply_post = nil
        start = Time.now

        submit_prompt_and_stream_reply(prompt) do |partial, cancel|
          reply = update_with_delta(reply, partial)

          if redis_stream_key && !Discourse.redis.get(redis_stream_key)
            cancel&.call

            bot_reply_post.update!(raw: reply, cooked: PrettyText.cook(reply)) if bot_reply_post
          end

          next if reply.length < SiteSetting.min_personal_message_post_length
          # Minor hack to skip the delay during tests.
          next if (Time.now - start < 0.5) && !Rails.env.test?

          if bot_reply_post
            Discourse.redis.expire(redis_stream_key, 60)
            start = Time.now

            publish_update(bot_reply_post, raw: reply.dup)
          else
            bot_reply_post =
              PostCreator.create!(
                bot_user,
                topic_id: post.topic_id,
                raw: reply,
                skip_validations: false,
              )
            redis_stream_key = "gpt_cancel:#{bot_reply_post.id}"
            Discourse.redis.setex(redis_stream_key, 60, 1)
          end
        end

        if bot_reply_post
          publish_update(bot_reply_post, done: true)
          bot_reply_post.revise(
            bot_user,
            { raw: reply },
            skip_validations: true,
            skip_revision: true,
          )
        end
      rescue => e
        Discourse.warn_exception(e, message: "ai-bot: Reply failed")
      end

      def bot_prompt_with_topic_context(post)
        messages = []
        conversation = conversation_context(post)

        total_prompt_tokens = 0
        messages =
          conversation.reduce([]) do |memo, (raw, username)|
            break(memo) if total_prompt_tokens >= prompt_limit

            tokens = DiscourseAi::Tokenizer.tokenize(raw)

            if tokens.length + total_prompt_tokens > prompt_limit
              tokens = tokens[0...(prompt_limit - total_prompt_tokens)]
              raw = tokens.join(" ")
            end

            total_prompt_tokens += tokens.length

            memo.unshift(build_message(username, raw))
          end

        messages.unshift(build_message(bot_user.username, <<~TEXT))
          You are gpt-bot. You answer questions and generate text.
          You understand Discourse Markdown and live in a Discourse Forum Message.
          You are provided you with context of previous discussions.
        TEXT

        messages
      end

      def prompt_limit
        raise NotImplemented
      end

      protected

      attr_reader :bot_user

      def model_for(bot)
        raise NotImplemented
      end

      def get_delta_from(partial)
        raise NotImplemented
      end

      def submit_prompt_and_stream_reply(prompt, &blk)
        raise NotImplemented
      end

      def conversation_context(post)
        post
          .topic
          .posts
          .includes(:user)
          .where("post_number <= ?", post.post_number)
          .order("post_number desc")
          .pluck(:raw, :username)
      end

      def publish_update(bot_reply_post, payload)
        MessageBus.publish(
          "discourse-ai/ai-bot/topic/#{bot_reply_post.topic_id}",
          payload.merge(post_id: bot_reply_post.id, post_number: bot_reply_post.post_number),
          user_ids: bot_reply_post.topic.allowed_user_ids,
        )
      end
    end
  end
end
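
`Bot` is a template-method base class: `reply_to` owns streaming, Redis-backed cancellation and MessageBus updates, while subclasses supply `can_reply_as?`, `prompt_limit`, `build_message`, `model_for`, `update_with_delta` and `submit_prompt_and_stream_reply`. A hypothetical third provider would only need those hooks; a sketch (the `EchoBot` class, its id and model name are invented for illustration):

    # Hypothetical subclass, not part of this commit.
    class EchoBot < DiscourseAi::AiBot::Bot
      def self.can_reply_as?(bot_user)
        bot_user.id == -999 # invented id; would need a matching seeded user
      end

      def prompt_limit
        1_000
      end

      private

      def build_message(poster_username, content)
        "#{poster_username}: #{content}"
      end

      def model_for
        "echo-1" # invented model name
      end

      def update_with_delta(current, partial)
        current + partial.to_s
      end

      def submit_prompt_and_stream_reply(prompt, &blk)
        # A real implementation would call an inference client here,
        # yielding each partial together with a cancel lambda.
        blk.call(prompt.to_s, -> {})
      end
    end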
@@ -3,10 +3,20 @@
 module DiscourseAi
   module AiBot
     class EntryPoint
-      AI_BOT_ID = -110
+      GPT4_ID = -110
+      GPT3_5_TURBO_ID = -111
+      CLAUDE_V1_ID = -112
+      BOTS = [
+        [GPT4_ID, "gpt4_bot"],
+        [GPT3_5_TURBO_ID, "gpt3.5_bot"],
+        [CLAUDE_V1_ID, "claude_v1_bot"],
+      ]
 
       def load_files
         require_relative "jobs/regular/create_ai_reply"
+        require_relative "bot"
+        require_relative "anthropic_bot"
+        require_relative "open_ai_bot"
       end
 
       def inject_into(plugin)
@@ -14,21 +24,15 @@ module DiscourseAi
           Rails.root.join("plugins", "discourse-ai", "db", "fixtures", "ai_bot"),
         )
 
-        plugin.add_class_method(Discourse, :gpt_bot) do
-          @ai_bots ||= {}
-          current_db = RailsMultisite::ConnectionManagement.current_db
-          @ai_bots[current_db] ||= User.find(AI_BOT_ID)
-        end
-
         plugin.on(:post_created) do |post|
-          if post.topic.private_message? && post.user_id != AI_BOT_ID &&
-               post.topic.topic_allowed_users.exists?(user_id: Discourse.gpt_bot.id)
-            in_allowed_group =
-              SiteSetting.ai_bot_allowed_groups_map.any? do |group_id|
-                post.user.group_ids.include?(group_id)
-              end
-
-            Jobs.enqueue(:create_ai_reply, post_id: post.id) if in_allowed_group
+          bot_ids = BOTS.map(&:first)
+
+          if post.topic.private_message? && !bot_ids.include?(post.user_id)
+            if (SiteSetting.ai_bot_allowed_groups_map & post.user.group_ids).present?
+              bot_id = post.topic.topic_allowed_users.where(user_id: bot_ids).first&.user_id
+
+              Jobs.enqueue(:create_ai_reply, post_id: post.id, bot_user_id: bot_id) if bot_id
+            end
           end
         end
       end
     end
@@ -5,72 +5,11 @@ module ::Jobs
     sidekiq_options retry: false
 
     def execute(args)
+      return unless bot_user = User.find_by(id: args[:bot_user_id])
+      return unless bot = DiscourseAi::AiBot::Bot.as(bot_user)
       return unless post = Post.includes(:topic).find_by(id: args[:post_id])
 
-      prompt = CompletionPrompt.bot_prompt_with_topic_context(post)
-
-      redis_stream_key = nil
-      reply = +""
-      bot_reply_post = nil
-      start = Time.now
-
-      DiscourseAi::Inference::OpenAiCompletions.perform!(
-        prompt,
-        temperature: 0.4,
-        top_p: 0.9,
-        max_tokens: 3000,
-      ) do |partial, cancel|
-        content_delta = partial.dig(:choices, 0, :delta, :content)
-        reply << content_delta if content_delta
-
-        if redis_stream_key && !Discourse.redis.get(redis_stream_key)
-          cancel&.call
-
-          bot_reply_post.update!(raw: reply, cooked: PrettyText.cook(reply)) if bot_reply_post
-        end
-
-        next if reply.length < SiteSetting.min_personal_message_post_length
-        # Minor hack to skip the delay during tests.
-        next if (Time.now - start < 0.5) && !Rails.env.test?
-
-        if bot_reply_post
-          Discourse.redis.expire(redis_stream_key, 60)
-          start = Time.now
-
-          MessageBus.publish(
-            "discourse-ai/ai-bot/topic/#{post.topic_id}",
-            { raw: reply.dup, post_id: bot_reply_post.id, post_number: bot_reply_post.post_number },
-            user_ids: post.topic.allowed_user_ids,
-          )
-        else
-          bot_reply_post =
-            PostCreator.create!(
-              Discourse.gpt_bot,
-              topic_id: post.topic_id,
-              raw: reply,
-              skip_validations: false,
-            )
-          redis_stream_key = "gpt_cancel:#{bot_reply_post.id}"
-          Discourse.redis.setex(redis_stream_key, 60, 1)
-        end
-      end
-
-      MessageBus.publish(
-        "discourse-ai/ai-bot/topic/#{post.topic_id}",
-        { done: true, post_id: bot_reply_post.id, post_number: bot_reply_post.post_number },
-        user_ids: post.topic.allowed_user_ids,
-      )
-
-      if bot_reply_post
-        bot_reply_post.revise(
-          Discourse.gpt_bot,
-          { raw: reply },
-          skip_validations: true,
-          skip_revision: true,
-        )
-      end
-    rescue => e
-      Discourse.warn_exception(e, message: "ai-bot: Reply failed")
+      bot.reply_to(post)
     end
   end
 end
lib/modules/ai_bot/open_ai_bot.rb (new file, 48 lines)

# frozen_string_literal: true

module DiscourseAi
  module AiBot
    class OpenAiBot < Bot
      def self.can_reply_as?(bot_user)
        open_ai_bot_ids = [
          DiscourseAi::AiBot::EntryPoint::GPT4_ID,
          DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID,
        ]

        open_ai_bot_ids.include?(bot_user.id)
      end

      def prompt_limit
        3500
      end

      private

      def build_message(poster_username, content)
        role = poster_username == bot_user.username ? "system" : "user"

        { role: role, content: content }
      end

      def model_for
        return "gpt-4" if bot_user.id == DiscourseAi::AiBot::EntryPoint::GPT4_ID
        "gpt-3.5-turbo"
      end

      def update_with_delta(current_delta, partial)
        current_delta + partial.dig(:choices, 0, :delta, :content).to_s
      end

      def submit_prompt_and_stream_reply(prompt, &blk)
        DiscourseAi::Inference::OpenAiCompletions.perform!(
          prompt,
          model_for,
          temperature: 0.4,
          top_p: 0.9,
          max_tokens: 3000,
          &blk
        )
      end
    end
  end
end
@@ -64,7 +64,7 @@ module DiscourseAi
       messages = prompt.messages_with_user_input(text)
 
       result[:suggestions] = DiscourseAi::Inference::OpenAiCompletions
-        .perform!(messages)
+        .perform!(messages, SiteSetting.ai_helper_model)
         .dig(:choices)
         .to_a
         .flat_map { |choice| parse_content(prompt, choice.dig(:message, :content).to_s) }
@@ -4,32 +4,106 @@ module ::DiscourseAi
   module Inference
     class AnthropicCompletions
       CompletionFailed = Class.new(StandardError)
+      TIMEOUT = 60
 
-      def self.perform!(prompt)
+      def self.perform!(
+        prompt,
+        model = "claude-v1",
+        temperature: nil,
+        top_p: nil,
+        max_tokens: nil,
+        user_id: nil
+      )
+        url = URI("https://api.anthropic.com/v1/complete")
         headers = {
           "x-api-key" => SiteSetting.ai_anthropic_api_key,
           "Content-Type" => "application/json",
         }
 
-        model = "claude-v1"
-
-        connection_opts = { request: { write_timeout: 60, read_timeout: 60, open_timeout: 60 } }
-
-        response =
-          Faraday.new(nil, connection_opts).post(
-            "https://api.anthropic.com/v1/complete",
-            { model: model, prompt: prompt, max_tokens_to_sample: 300 }.to_json,
-            headers,
-          )
-
-        if response.status != 200
-          Rails.logger.error(
-            "AnthropicCompletions: status: #{response.status} - body: #{response.body}",
-          )
-          raise CompletionFailed
-        end
-
-        JSON.parse(response.body, symbolize_names: true)
+        payload = { model: model, prompt: prompt }
+
+        payload[:temperature] = temperature if temperature
+        payload[:top_p] = top_p if top_p
+        payload[:max_tokens_to_sample] = max_tokens || 300
+        payload[:stream] = true if block_given?
+
+        Net::HTTP.start(
+          url.host,
+          url.port,
+          use_ssl: true,
+          read_timeout: TIMEOUT,
+          open_timeout: TIMEOUT,
+          write_timeout: TIMEOUT,
+        ) do |http|
+          request = Net::HTTP::Post.new(url, headers)
+          request_body = payload.to_json
+          request.body = request_body
+
+          http.request(request) do |response|
+            if response.code.to_i != 200
+              Rails.logger.error(
+                "AnthropicCompletions: status: #{response.code.to_i} - body: #{response.body}",
+              )
+              raise CompletionFailed
+            end
+
+            log =
+              AiApiAuditLog.create!(
+                provider_id: AiApiAuditLog::Provider::Anthropic,
+                raw_request_payload: request_body,
+                user_id: user_id,
+              )
+
+            if !block_given?
+              response_body = response.read_body
+              parsed_response = JSON.parse(response_body, symbolize_names: true)
+
+              log.update!(
+                raw_response_payload: response_body,
+                request_tokens: DiscourseAi::Tokenizer.size(prompt),
+                response_tokens: DiscourseAi::Tokenizer.size(parsed_response[:completion]),
+              )
+              return parsed_response
+            end
+
+            begin
+              cancelled = false
+              cancel = lambda { cancelled = true }
+              response_data = +""
+              response_raw = +""
+
+              response.read_body do |chunk|
+                if cancelled
+                  http.finish
+                  return
+                end
+
+                response_raw << chunk
+
+                chunk
+                  .split("\n")
+                  .each do |line|
+                    data = line.split("data: ", 2)[1]
+                    next if !data || data.squish == "[DONE]"
+
+                    if !cancelled && partial = JSON.parse(data, symbolize_names: true)
+                      response_data << partial[:completion].to_s
+
+                      yield partial, cancel
+                    end
+                  end
+              end
+            rescue IOError
+              raise if !cancelled
+            ensure
+              log.update!(
+                raw_response_payload: response_raw,
+                request_tokens: DiscourseAi::Tokenizer.size(prompt),
+                response_tokens: DiscourseAi::Tokenizer.size(response_data),
+              )
+            end
+          end
+        end
       end
     end
   end
 end
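
The rewritten `perform!` now serves both calling modes; the sketch below mirrors the specs added later in this diff. Without a block it returns the parsed response; with a block it sets `stream: true` and yields each parsed SSE line alongside a cancel lambda:

    # Blocking: returns the parsed response hash.
    result =
      DiscourseAi::Inference::AnthropicCompletions.perform!(
        "Human: write 3 words\n\n",
        "claude-v1",
        max_tokens: 300,
      )
    result[:completion]

    # Streaming: each parsed chunk is yielded with a cancel lambda.
    DiscourseAi::Inference::AnthropicCompletions.perform!(
      "Human: write 3 words\n\n",
      "claude-v1",
    ) do |partial, cancel|
      print partial[:completion]
      cancel.call if partial[:stop_reason]
    end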
@@ -9,7 +9,7 @@ module ::DiscourseAi
 
     def self.perform!(
       messages,
-      model = SiteSetting.ai_helper_model,
+      model,
       temperature: nil,
       top_p: nil,
       max_tokens: nil,
@@ -4,7 +4,7 @@ RSpec.describe DiscourseAi::AiBot::EntryPoint do
   describe "#inject_into" do
     describe "subscribes to the post_created event" do
       fab!(:admin) { Fabricate(:admin) }
-      let(:gpt_bot) { Discourse.gpt_bot }
+      let(:gpt_bot) { User.find(described_class::GPT4_ID) }
       fab!(:bot_allowed_group) { Fabricate(:group) }
 
       let(:post_args) do
@@ -13,7 +13,6 @@ RSpec.describe DiscourseAi::AiBot::EntryPoint do
           raw: "Hello, Can you please tell me a story?",
           archetype: Archetype.private_message,
           target_usernames: [gpt_bot.username].join(","),
-          category: 1,
         }
       end
 
@@ -29,6 +28,19 @@ RSpec.describe DiscourseAi::AiBot::EntryPoint do
         ).by(1)
       end
 
+      it "includes the bot's user_id" do
+        claude_bot = User.find(described_class::CLAUDE_V1_ID)
+        claude_post_attrs = post_args.merge(target_usernames: [claude_bot.username].join(","))
+
+        expect { PostCreator.create!(admin, claude_post_attrs) }.to change(
+          Jobs::CreateAiReply.jobs,
+          :size,
+        ).by(1)
+
+        job_args = Jobs::CreateAiReply.jobs.last["args"].first
+        expect(job_args["bot_user_id"]).to eq(claude_bot.id)
+      end
+
       context "when the post is not from a PM" do
         it "does nothing" do
           expect {
@@ -1,6 +1,7 @@
 # frozen_string_literal: true
 
 require_relative "../../../../../support/openai_completions_inference_stubs"
+require_relative "../../../../../support/anthropic_completion_stubs"
 
 RSpec.describe Jobs::CreateAiReply do
   describe "#execute" do
@@ -10,44 +11,82 @@ RSpec.describe Jobs::CreateAiReply do
     let(:expected_response) do
       "Hello this is a bot and what you just said is an interesting question"
     end
-    let(:deltas) { expected_response.split(" ").map { |w| { content: "#{w} " } } }
 
-    before do
-      SiteSetting.min_personal_message_post_length = 5
-
-      OpenAiCompletionsInferenceStubs.stub_streamed_response(
-        CompletionPrompt.bot_prompt_with_topic_context(post),
-        deltas,
-        req_opts: {
-          temperature: 0.4,
-          top_p: 0.9,
-          max_tokens: 3000,
-          stream: true,
-        },
-      )
-    end
-
-    it "adds a reply from the GPT bot" do
-      subject.execute(post_id: topic.first_post.id)
-
-      expect(topic.posts.last.raw).to eq(expected_response)
-    end
-
-    it "streams the reply on the fly to the client through MB" do
-      messages =
-        MessageBus.track_publish("discourse-ai/ai-bot/topic/#{topic.id}") do
-          subject.execute(post_id: topic.first_post.id)
-        end
-
-      done_signal = messages.pop
-
-      expect(messages.length).to eq(deltas.length)
-
-      messages.each_with_index do |m, idx|
-        expect(m.data[:raw]).to eq(deltas[0..(idx + 1)].map { |d| d[:content] }.join)
-      end
-
-      expect(done_signal.data[:done]).to eq(true)
+    before { SiteSetting.min_personal_message_post_length = 5 }
+
+    context "when chatting with the Open AI bot" do
+      let(:deltas) { expected_response.split(" ").map { |w| { content: "#{w} " } } }
+
+      before do
+        bot_user = User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID)
+
+        OpenAiCompletionsInferenceStubs.stub_streamed_response(
+          DiscourseAi::AiBot::OpenAiBot.new(bot_user).bot_prompt_with_topic_context(post),
+          deltas,
+          req_opts: {
+            temperature: 0.4,
+            top_p: 0.9,
+            max_tokens: 3000,
+            stream: true,
+          },
+        )
+      end
+
+      it "adds a reply from the GPT bot" do
+        subject.execute(
+          post_id: topic.first_post.id,
+          bot_user_id: DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID,
+        )
+
+        expect(topic.posts.last.raw).to eq(expected_response)
+      end
+
+      it "streams the reply on the fly to the client through MB" do
+        messages =
+          MessageBus.track_publish("discourse-ai/ai-bot/topic/#{topic.id}") do
+            subject.execute(
+              post_id: topic.first_post.id,
+              bot_user_id: DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID,
+            )
+          end
+
+        done_signal = messages.pop
+
+        expect(messages.length).to eq(deltas.length)
+
+        messages.each_with_index do |m, idx|
+          expect(m.data[:raw]).to eq(deltas[0..(idx + 1)].map { |d| d[:content] }.join)
+        end
+
+        expect(done_signal.data[:done]).to eq(true)
+      end
+    end
+
+    context "when chatting with Claude from Anthropic" do
+      let(:deltas) { expected_response.split(" ").map { |w| "#{w} " } }
+
+      before do
+        bot_user = User.find(DiscourseAi::AiBot::EntryPoint::CLAUDE_V1_ID)
+
+        AnthropicCompletionStubs.stub_streamed_response(
+          DiscourseAi::AiBot::AnthropicBot.new(bot_user).bot_prompt_with_topic_context(post),
+          deltas,
+          req_opts: {
+            temperature: 0.4,
+            max_tokens_to_sample: 3000,
+            stream: true,
+          },
+        )
+      end
+
+      it "adds a reply from the Claude bot" do
+        subject.execute(
+          post_id: topic.first_post.id,
+          bot_user_id: DiscourseAi::AiBot::EntryPoint::CLAUDE_V1_ID,
+        )
+
+        expect(topic.posts.last.raw).to eq(expected_response)
+      end
     end
   end
 end
spec/lib/modules/ai_bot/open_ai_bot_spec.rb (new file, 64 lines)

# frozen_string_literal: true

RSpec.describe DiscourseAi::AiBot::OpenAiBot do
  describe "#bot_prompt_with_topic_context" do
    fab!(:topic) { Fabricate(:topic) }

    def post_body(post_number)
      "This is post #{post_number}"
    end

    def bot_user
      User.find(DiscourseAi::AiBot::EntryPoint::GPT4_ID)
    end

    subject { described_class.new(bot_user) }

    context "when the topic has one post" do
      fab!(:post_1) { Fabricate(:post, topic: topic, raw: post_body(1), post_number: 1) }

      it "includes it in the prompt" do
        prompt_messages = subject.bot_prompt_with_topic_context(post_1)

        post_1_message = prompt_messages[1]

        expect(post_1_message[:role]).to eq("user")
        expect(post_1_message[:content]).to eq(post_body(1))
      end
    end

    context "when prompt gets very long" do
      fab!(:post_1) { Fabricate(:post, topic: topic, raw: "test " * 6000, post_number: 1) }

      it "trims the prompt" do
        prompt_messages = subject.bot_prompt_with_topic_context(post_1)

        expect(prompt_messages[0][:role]).to eq("system")
        expect(prompt_messages[1][:role]).to eq("user")
        expected_length = ("test " * (subject.prompt_limit)).length
        expect(prompt_messages[1][:content].length).to eq(expected_length)
      end
    end

    context "when the topic has multiple posts" do
      fab!(:post_1) { Fabricate(:post, topic: topic, raw: post_body(1), post_number: 1) }
      fab!(:post_2) do
        Fabricate(:post, topic: topic, user: bot_user, raw: post_body(2), post_number: 2)
      end
      fab!(:post_3) { Fabricate(:post, topic: topic, raw: post_body(3), post_number: 3) }

      it "includes them in the prompt respecting the post number order" do
        prompt_messages = subject.bot_prompt_with_topic_context(post_3)

        expect(prompt_messages[1][:role]).to eq("user")
        expect(prompt_messages[1][:content]).to eq(post_body(1))

        expect(prompt_messages[2][:role]).to eq("system")
        expect(prompt_messages[2][:content]).to eq(post_body(2))

        expect(prompt_messages[3][:role]).to eq("user")
        expect(prompt_messages[3][:content]).to eq(post_body(3))
      end
    end
  end
end
@@ -18,59 +18,4 @@ RSpec.describe CompletionPrompt do
       end
     end
   end
-
-  describe ".bot_prompt_with_topic_context" do
-    fab!(:topic) { Fabricate(:topic) }
-
-    def post_body(post_number)
-      "This is post #{post_number}"
-    end
-
-    context "when the topic has one post" do
-      fab!(:post_1) { Fabricate(:post, topic: topic, raw: post_body(1), post_number: 1) }
-
-      it "includes it in the prompt" do
-        prompt_messages = described_class.bot_prompt_with_topic_context(post_1)
-
-        post_1_message = prompt_messages[1]
-
-        expect(post_1_message[:role]).to eq("user")
-        expect(post_1_message[:content]).to eq(post_body(1))
-      end
-    end
-
-    context "when prompt gets very long" do
-      fab!(:post_1) { Fabricate(:post, topic: topic, raw: "test " * 6000, post_number: 1) }
-
-      it "trims the prompt" do
-        prompt_messages = described_class.bot_prompt_with_topic_context(post_1)
-
-        expect(prompt_messages[0][:role]).to eq("system")
-        expect(prompt_messages[1][:role]).to eq("user")
-        expected_length = ("test " * (CompletionPrompt::MAX_PROMPT_TOKENS)).length
-        expect(prompt_messages[1][:content].length).to eq(expected_length)
-      end
-    end
-
-    context "when the topic has multiple posts" do
-      fab!(:post_1) { Fabricate(:post, topic: topic, raw: post_body(1), post_number: 1) }
-      fab!(:post_2) do
-        Fabricate(:post, topic: topic, user: Discourse.gpt_bot, raw: post_body(2), post_number: 2)
-      end
-      fab!(:post_3) { Fabricate(:post, topic: topic, raw: post_body(3), post_number: 3) }
-
-      it "includes them in the prompt respecting the post number order" do
-        prompt_messages = described_class.bot_prompt_with_topic_context(post_3)
-
-        expect(prompt_messages[1][:role]).to eq("user")
-        expect(prompt_messages[1][:content]).to eq(post_body(1))
-
-        expect(prompt_messages[2][:role]).to eq("system")
-        expect(prompt_messages[2][:content]).to eq(post_body(2))
-
-        expect(prompt_messages[3][:role]).to eq("user")
-        expect(prompt_messages[3][:content]).to eq(post_body(3))
-      end
-    end
-  end
 end
spec/shared/inference/anthropic_completions_spec.rb (new file, 67 lines)

# frozen_string_literal: true

require_relative "../../support/anthropic_completion_stubs"

RSpec.describe DiscourseAi::Inference::AnthropicCompletions do
  before { SiteSetting.ai_anthropic_api_key = "abc-123" }

  it "can complete a trivial prompt" do
    response_text = "1. Serenity\\n2. Laughter\\n3. Adventure"
    prompt = "Human: write 3 words\n\n"
    user_id = 183
    req_opts = { temperature: 0.5, max_tokens_to_sample: 700 }

    AnthropicCompletionStubs.stub_response(prompt, response_text, req_opts: req_opts)

    completions =
      DiscourseAi::Inference::AnthropicCompletions.perform!(
        prompt,
        "claude-v1",
        temperature: req_opts[:temperature],
        max_tokens: req_opts[:max_tokens_to_sample],
        user_id: user_id,
      )

    expect(completions[:completion]).to eq(response_text)

    expect(AiApiAuditLog.count).to eq(1)
    log = AiApiAuditLog.first

    request_body = { model: "claude-v1", prompt: prompt }.merge(req_opts).to_json
    response_body = AnthropicCompletionStubs.response(response_text).to_json

    expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
    expect(log.request_tokens).to eq(7)
    expect(log.response_tokens).to eq(16)
    expect(log.raw_request_payload).to eq(request_body)
    expect(log.raw_response_payload).to eq(response_body)
  end

  it "supports streaming mode" do
    deltas = ["Mount", "ain", " ", "Tree ", "Frog"]
    prompt = "Human: write 3 words\n\n"
    req_opts = { max_tokens_to_sample: 300, stream: true }
    content = +""

    AnthropicCompletionStubs.stub_streamed_response(prompt, deltas, req_opts: req_opts)

    DiscourseAi::Inference::AnthropicCompletions.perform!(prompt, "claude-v1") do |partial, cancel|
      data = partial[:completion]
      content = data if data

      cancel.call if content.split(" ").length == 2
    end

    expect(content).to eq("Mountain Tree ")

    expect(AiApiAuditLog.count).to eq(1)
    log = AiApiAuditLog.first

    request_body = { model: "claude-v1", prompt: prompt }.merge(req_opts).to_json

    expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
    expect(log.request_tokens).to eq(7)
    expect(log.response_tokens).to eq(9)
    expect(log.raw_request_payload).to eq(request_body)
    expect(log.raw_response_payload).to be_present
  end
end
spec/support/anthropic_completion_stubs.rb (new file, 55 lines)

# frozen_string_literal: true

class AnthropicCompletionStubs
  class << self
    def response(content)
      {
        completion: content,
        stop: "\n\nHuman:",
        stop_reason: "stop_sequence",
        truncated: false,
        log_id: "12dcc7feafbee4a394e0de9dffde3ac5",
        model: "claude-v1",
        exception: nil,
      }
    end

    def stub_response(prompt, response_text, req_opts: {})
      WebMock
        .stub_request(:post, "https://api.anthropic.com/v1/complete")
        .with(body: { model: "claude-v1", prompt: prompt }.merge(req_opts).to_json)
        .to_return(status: 200, body: JSON.dump(response(response_text)))
    end

    def stream_line(delta, finish_reason: nil)
      +"data: " << {
        completion: delta,
        stop: finish_reason ? "\n\nHuman:" : nil,
        stop_reason: finish_reason,
        truncated: false,
        log_id: "12b029451c6d18094d868bc04ce83f63",
        model: "claude-v1",
        exception: nil,
      }.to_json
    end

    def stub_streamed_response(prompt, deltas, req_opts: {})
      chunks =
        deltas.each_with_index.map do |_, index|
          if index == (deltas.length - 1)
            stream_line(deltas.join(""), finish_reason: "stop_sequence")
          else
            stream_line(deltas[0..index].join(""))
          end
        end

      chunks << "[DONE]"
      chunks = chunks.join("\n\n")

      WebMock
        .stub_request(:post, "https://api.anthropic.com/v1/complete")
        .with(body: { model: "claude-v1", prompt: prompt }.merge(req_opts).to_json)
        .to_return(status: 200, body: chunks)
    end
  end
end