FEATURE: GPT4o support and better auditing (#618)
- Introduce new support for GPT4o (automation / bot / summary / helper) - Properly account for token counts on OpenAI models - Track feature that was used when generating AI completions - Remove custom llm support for summarization as we need better interfaces to control registration and de-registration
This commit is contained in:
parent
8b00c47087
commit
8eee6893d6
|
@ -22,12 +22,13 @@ end
|
|||
# id :bigint not null, primary key
|
||||
# provider_id :integer not null
|
||||
# user_id :integer
|
||||
# topic_id :integer
|
||||
# post_id :integer
|
||||
# request_tokens :integer
|
||||
# response_tokens :integer
|
||||
# raw_request_payload :string
|
||||
# raw_response_payload :string
|
||||
# created_at :datetime not null
|
||||
# updated_at :datetime not null
|
||||
# topic_id :integer
|
||||
# post_id :integer
|
||||
# feature_name :string(255)
|
||||
#
|
||||
|
|
|
@ -20,6 +20,7 @@ en:
|
|||
mistral_7b_instruct_v0_2: Mistral 7B Instruct V0.2
|
||||
command_r: Cohere Command R
|
||||
command_r_plus: Cohere Command R+
|
||||
gpt_4o: GPT 4 Omni
|
||||
scriptables:
|
||||
llm_report:
|
||||
fields:
|
||||
|
@ -328,6 +329,7 @@ en:
|
|||
cohere-command-r-plus: "Cohere Command R Plus"
|
||||
gpt-4: "GPT-4"
|
||||
gpt-4-turbo: "GPT-4 Turbo"
|
||||
gpt-4o: "GPT-4 Omni"
|
||||
gpt-3:
|
||||
5-turbo: "GPT-3.5"
|
||||
claude-2: "Claude 2"
|
||||
|
|
|
@ -215,6 +215,7 @@ pl_PL:
|
|||
bot_names:
|
||||
gpt-4: "GPT-4"
|
||||
gpt-4-turbo: "GPT-4 Turbo"
|
||||
gpt-4o: "GPT-4 Omni"
|
||||
gpt-3:
|
||||
5-turbo: "GPT-3.5"
|
||||
claude-2: "Claude 2"
|
||||
|
|
|
@ -43,6 +43,7 @@ en:
|
|||
ai_openai_gpt35_url: "Custom URL used for GPT 3.5 chat completions. (for Azure support)"
|
||||
ai_openai_gpt35_16k_url: "Custom URL used for GPT 3.5 16k chat completions. (for Azure support)"
|
||||
ai_openai_gpt4_url: "Custom URL used for GPT 4 chat completions. (for Azure support)"
|
||||
ai_openai_gpt4o_url: "Custom URL used for GPT 4 Omni chat completions. (for Azure support)"
|
||||
ai_openai_gpt4_32k_url: "Custom URL used for GPT 4 32k chat completions. (for Azure support)"
|
||||
ai_openai_gpt4_turbo_url: "Custom URL used for GPT 4 Turbo chat completions. (for Azure support)"
|
||||
ai_openai_dall_e_3_url: "Custom URL used for DALL-E 3 image generation. (for Azure support)"
|
||||
|
|
|
@ -98,6 +98,7 @@ discourse_ai:
|
|||
|
||||
ai_openai_gpt35_url: "https://api.openai.com/v1/chat/completions"
|
||||
ai_openai_gpt35_16k_url: "https://api.openai.com/v1/chat/completions"
|
||||
ai_openai_gpt4o_url: "https://api.openai.com/v1/chat/completions"
|
||||
ai_openai_gpt4_url: "https://api.openai.com/v1/chat/completions"
|
||||
ai_openai_gpt4_32k_url: "https://api.openai.com/v1/chat/completions"
|
||||
ai_openai_gpt4_turbo_url: "https://api.openai.com/v1/chat/completions"
|
||||
|
@ -343,6 +344,7 @@ discourse_ai:
|
|||
- gpt-3.5-turbo
|
||||
- gpt-4
|
||||
- gpt-4-turbo
|
||||
- gpt-4o
|
||||
- claude-2
|
||||
- gemini-1.5-pro
|
||||
- mixtral-8x7B-Instruct-V0.1
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
class AddFeatureNameToAiApiAuditLog < ActiveRecord::Migration[7.0]
|
||||
def change
|
||||
add_column :ai_api_audit_logs, :feature_name, :string, limit: 255
|
||||
end
|
||||
end
|
|
@ -52,7 +52,7 @@ class ExplicitProviderBackwardsCompat < ActiveRecord::Migration[7.0]
|
|||
end
|
||||
|
||||
def append_provider(value)
|
||||
open_ai_models = %w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k gpt-4-turbo]
|
||||
open_ai_models = %w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k gpt-4-turbo gpt-4o]
|
||||
return "open_ai:#{value}" if open_ai_models.include?(value)
|
||||
return "google:#{value}" if value == "gemini-pro"
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ module DiscourseAi
|
|||
|
||||
DiscourseAi::Completions::Llm
|
||||
.proxy(model)
|
||||
.generate(title_prompt, user: post.user)
|
||||
.generate(title_prompt, user: post.user, feature_name: "bot_title")
|
||||
.strip
|
||||
.split("\n")
|
||||
.last
|
||||
|
@ -67,7 +67,7 @@ module DiscourseAi
|
|||
tool_found = false
|
||||
|
||||
result =
|
||||
llm.generate(prompt, **llm_kwargs) do |partial, cancel|
|
||||
llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
|
||||
tools = persona.find_tools(partial, bot_user: user, llm: llm, context: context)
|
||||
|
||||
if (tools.present?)
|
||||
|
@ -162,6 +162,8 @@ module DiscourseAi
|
|||
"open_ai:gpt-4"
|
||||
when DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID
|
||||
"open_ai:gpt-4-turbo"
|
||||
when DiscourseAi::AiBot::EntryPoint::GPT4O_ID
|
||||
"open_ai:gpt-4o"
|
||||
when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
|
||||
"open_ai:gpt-3.5-turbo-16k"
|
||||
when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
|
||||
|
|
|
@ -18,6 +18,7 @@ module DiscourseAi
|
|||
CLAUDE_3_SONNET_ID = -118
|
||||
CLAUDE_3_HAIKU_ID = -119
|
||||
COHERE_COMMAND_R_PLUS = -120
|
||||
GPT4O_ID = -121
|
||||
|
||||
BOTS = [
|
||||
[GPT4_ID, "gpt4_bot", "gpt-4"],
|
||||
|
@ -31,6 +32,7 @@ module DiscourseAi
|
|||
[CLAUDE_3_SONNET_ID, "claude_3_sonnet_bot", "claude-3-sonnet"],
|
||||
[CLAUDE_3_HAIKU_ID, "claude_3_haiku_bot", "claude-3-haiku"],
|
||||
[COHERE_COMMAND_R_PLUS, "cohere_command_bot", "cohere-command-r-plus"],
|
||||
[GPT4O_ID, "gpt4o_bot", "gpt-4o"],
|
||||
]
|
||||
|
||||
BOT_USER_IDS = BOTS.map(&:first)
|
||||
|
@ -49,6 +51,8 @@ module DiscourseAi
|
|||
|
||||
def self.map_bot_model_to_user_id(model_name)
|
||||
case model_name
|
||||
in "gpt-4o"
|
||||
GPT4O_ID
|
||||
in "gpt-4-turbo"
|
||||
GPT4_TURBO_ID
|
||||
in "gpt-3.5-turbo"
|
||||
|
|
|
@ -17,7 +17,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def consolidate_question
|
||||
@llm.generate(revised_prompt, user: @user)
|
||||
@llm.generate(revised_prompt, user: @user, feature_name: "question_consolidator")
|
||||
end
|
||||
|
||||
def revised_prompt
|
||||
|
|
|
@ -135,7 +135,14 @@ module DiscourseAi
|
|||
|
||||
prompt = section_prompt(topic, section, guidance)
|
||||
|
||||
summary = llm.generate(prompt, temperature: 0.6, max_tokens: 400, user: bot_user)
|
||||
summary =
|
||||
llm.generate(
|
||||
prompt,
|
||||
temperature: 0.6,
|
||||
max_tokens: 400,
|
||||
user: bot_user,
|
||||
feature_name: "summarize_tool",
|
||||
)
|
||||
|
||||
summaries << summary
|
||||
end
|
||||
|
@ -150,7 +157,13 @@ module DiscourseAi
|
|||
"concatenated the disjoint summaries, creating a cohesive narrative:\n#{summaries.join("\n")}}",
|
||||
}
|
||||
|
||||
llm.generate(concatenation_prompt, temperature: 0.6, max_tokens: 500, user: bot_user)
|
||||
llm.generate(
|
||||
concatenation_prompt,
|
||||
temperature: 0.6,
|
||||
max_tokens: 500,
|
||||
user: bot_user,
|
||||
feature_name: "summarize_tool",
|
||||
)
|
||||
else
|
||||
summaries.first
|
||||
end
|
||||
|
|
|
@ -85,6 +85,7 @@ module DiscourseAi
|
|||
user: user,
|
||||
temperature: completion_prompt.temperature,
|
||||
stop_sequences: completion_prompt.stop_sequences,
|
||||
feature_name: "ai_helper",
|
||||
&block
|
||||
)
|
||||
end
|
||||
|
@ -163,6 +164,7 @@ module DiscourseAi
|
|||
prompt,
|
||||
user: Discourse.system_user,
|
||||
max_tokens: 1024,
|
||||
feature_name: "image_caption",
|
||||
)
|
||||
end
|
||||
end
|
||||
|
|
|
@ -32,6 +32,7 @@ module DiscourseAi
|
|||
prompt,
|
||||
user: Discourse.system_user,
|
||||
stop_sequences: ["</input>"],
|
||||
feature_name: "chat_thread_title",
|
||||
)
|
||||
end
|
||||
|
||||
|
|
|
@ -68,6 +68,7 @@ module DiscourseAi
|
|||
DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_model).generate(
|
||||
prompt,
|
||||
user: user,
|
||||
feature_name: "illustrate_post",
|
||||
)
|
||||
end
|
||||
end
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
module DiscourseAi
|
||||
module Automation
|
||||
AVAILABLE_MODELS = [
|
||||
{ id: "gpt-4o", name: "discourse_automation.ai_models.gpt_4o" },
|
||||
{ id: "gpt-4-turbo", name: "discourse_automation.ai_models.gpt_4_turbo" },
|
||||
{ id: "gpt-4", name: "discourse_automation.ai_models.gpt_4" },
|
||||
{ id: "gpt-3.5-turbo", name: "discourse_automation.ai_models.gpt_3_5_turbo" },
|
||||
|
|
|
@ -41,6 +41,7 @@ module DiscourseAi
|
|||
temperature: 0,
|
||||
max_tokens: llm.tokenizer.tokenize(search_for_text).length * 2 + 10,
|
||||
user: Discourse.system_user,
|
||||
feature_name: "llm_triage",
|
||||
)
|
||||
|
||||
if result.present? && result.strip.downcase.include?(search_for_text)
|
||||
|
|
|
@ -154,6 +154,7 @@ Follow the provided writing composition instructions carefully and precisely ste
|
|||
temperature: @temperature,
|
||||
top_p: @top_p,
|
||||
user: Discourse.system_user,
|
||||
feature_name: "ai_report",
|
||||
) do |response|
|
||||
print response if Rails.env.development? && @debug_mode
|
||||
result << response
|
||||
|
|
|
@ -83,7 +83,8 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def inline_images(content, message)
|
||||
if model_name.include?("gpt-4-vision") || model_name == "gpt-4-turbo"
|
||||
if model_name.include?("gpt-4-vision") || model_name == "gpt-4-turbo" ||
|
||||
model_name == "gpt-4o"
|
||||
content = message[:content]
|
||||
encoded_uploads = prompt.encoded_uploads(message)
|
||||
if encoded_uploads.present?
|
||||
|
@ -125,6 +126,8 @@ module DiscourseAi
|
|||
32_768
|
||||
when "gpt-4-turbo"
|
||||
131_072
|
||||
when "gpt-4o"
|
||||
131_072
|
||||
else
|
||||
8192
|
||||
end
|
||||
|
|
|
@ -73,7 +73,7 @@ module DiscourseAi
|
|||
true
|
||||
end
|
||||
|
||||
def perform_completion!(dialect, user, model_params = {}, &blk)
|
||||
def perform_completion!(dialect, user, model_params = {}, feature_name: nil, &blk)
|
||||
allow_tools = dialect.prompt.has_tools?
|
||||
model_params = normalize_model_params(model_params)
|
||||
|
||||
|
@ -114,6 +114,7 @@ module DiscourseAi
|
|||
request_tokens: prompt_size(prompt),
|
||||
topic_id: dialect.prompt.topic_id,
|
||||
post_id: dialect.prompt.post_id,
|
||||
feature_name: feature_name,
|
||||
)
|
||||
|
||||
if !@streaming_mode
|
||||
|
|
|
@ -23,7 +23,7 @@ module DiscourseAi
|
|||
|
||||
attr_reader :responses, :completions, :prompt
|
||||
|
||||
def perform_completion!(prompt, _user, _model_params)
|
||||
def perform_completion!(prompt, _user, _model_params, feature_name: nil)
|
||||
@prompt = prompt
|
||||
response = responses[completions]
|
||||
if response.nil?
|
||||
|
|
|
@ -110,7 +110,7 @@ module DiscourseAi
|
|||
@last_call = params
|
||||
end
|
||||
|
||||
def perform_completion!(dialect, user, model_params = {})
|
||||
def perform_completion!(dialect, user, model_params = {}, feature_name: nil)
|
||||
self.class.last_call = { dialect: dialect, user: user, model_params: model_params }
|
||||
|
||||
content = self.class.fake_content
|
||||
|
|
|
@ -12,6 +12,7 @@ module DiscourseAi
|
|||
def dependant_setting_names
|
||||
%w[
|
||||
ai_openai_api_key
|
||||
ai_openai_gpt4o_url
|
||||
ai_openai_gpt4_32k_url
|
||||
ai_openai_gpt4_turbo_url
|
||||
ai_openai_gpt4_url
|
||||
|
@ -33,6 +34,8 @@ module DiscourseAi
|
|||
else
|
||||
if model.include?("1106") || model.include?("turbo")
|
||||
SiteSetting.ai_openai_gpt4_turbo_url
|
||||
elsif model.include?("gpt-4o")
|
||||
SiteSetting.ai_openai_gpt4o_url
|
||||
else
|
||||
SiteSetting.ai_openai_gpt4_url
|
||||
end
|
||||
|
@ -98,35 +101,47 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def prepare_payload(prompt, model_params, dialect)
|
||||
default_options
|
||||
.merge(model_params)
|
||||
.merge(messages: prompt)
|
||||
.tap do |payload|
|
||||
payload[:stream] = true if @streaming_mode
|
||||
payload[:tools] = dialect.tools if dialect.tools.present?
|
||||
end
|
||||
payload = default_options.merge(model_params).merge(messages: prompt)
|
||||
|
||||
if @streaming_mode
|
||||
payload[:stream] = true
|
||||
payload[:stream_options] = { include_usage: true }
|
||||
end
|
||||
|
||||
payload[:tools] = dialect.tools if dialect.tools.present?
|
||||
payload
|
||||
end
|
||||
|
||||
def prepare_request(payload)
|
||||
headers =
|
||||
{ "Content-Type" => "application/json" }.tap do |h|
|
||||
if model_uri.host.include?("azure")
|
||||
h["api-key"] = SiteSetting.ai_openai_api_key
|
||||
else
|
||||
h["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
|
||||
end
|
||||
headers = { "Content-Type" => "application/json" }
|
||||
|
||||
if SiteSetting.ai_openai_organization.present?
|
||||
h["OpenAI-Organization"] = SiteSetting.ai_openai_organization
|
||||
end
|
||||
end
|
||||
if model_uri.host.include?("azure")
|
||||
headers["api-key"] = SiteSetting.ai_openai_api_key
|
||||
else
|
||||
headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
|
||||
end
|
||||
|
||||
if SiteSetting.ai_openai_organization.present?
|
||||
headers["OpenAI-Organization"] = SiteSetting.ai_openai_organization
|
||||
end
|
||||
|
||||
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
|
||||
end
|
||||
|
||||
def final_log_update(log)
|
||||
log.request_tokens = @prompt_tokens if @prompt_tokens
|
||||
log.response_tokens = @completion_tokens if @completion_tokens
|
||||
end
|
||||
|
||||
def extract_completion_from(response_raw)
|
||||
parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
|
||||
# half a line sent here
|
||||
json = JSON.parse(response_raw, symbolize_names: true)
|
||||
|
||||
if @streaming_mode
|
||||
@prompt_tokens ||= json.dig(:usage, :prompt_tokens)
|
||||
@completion_tokens ||= json.dig(:usage, :completion_tokens)
|
||||
end
|
||||
|
||||
parsed = json.dig(:choices, 0)
|
||||
return if !parsed
|
||||
|
||||
response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
|
||||
|
|
|
@ -54,6 +54,7 @@ module DiscourseAi
|
|||
gpt-4-32k
|
||||
gpt-4-turbo
|
||||
gpt-4-vision-preview
|
||||
gpt-4o
|
||||
],
|
||||
google: %w[gemini-pro gemini-1.5-pro],
|
||||
}.tap do |h|
|
||||
|
@ -106,12 +107,6 @@ module DiscourseAi
|
|||
dialect_klass =
|
||||
DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name_without_prov)
|
||||
|
||||
if is_custom_model
|
||||
tokenizer = llm_model.tokenizer_class
|
||||
else
|
||||
tokenizer = dialect_klass.tokenizer
|
||||
end
|
||||
|
||||
if @canned_response
|
||||
if @canned_llm && @canned_llm != model_name
|
||||
raise "Invalid call LLM call, expected #{@canned_llm} but got #{model_name}"
|
||||
|
@ -164,6 +159,7 @@ module DiscourseAi
|
|||
max_tokens: nil,
|
||||
stop_sequences: nil,
|
||||
user:,
|
||||
feature_name: nil,
|
||||
&partial_read_blk
|
||||
)
|
||||
self.class.record_prompt(prompt)
|
||||
|
@ -196,7 +192,13 @@ module DiscourseAi
|
|||
model_name,
|
||||
opts: model_params.merge(max_prompt_tokens: @max_prompt_tokens),
|
||||
)
|
||||
gateway.perform_completion!(dialect, user, model_params, &partial_read_blk)
|
||||
gateway.perform_completion!(
|
||||
dialect,
|
||||
user,
|
||||
model_params,
|
||||
feature_name: feature_name,
|
||||
&partial_read_blk
|
||||
)
|
||||
end
|
||||
|
||||
def max_prompt_tokens
|
||||
|
|
|
@ -69,7 +69,7 @@ module DiscourseAi
|
|||
def can_talk_to_model?(model_name)
|
||||
DiscourseAi::Completions::Llm
|
||||
.proxy(model_name)
|
||||
.generate("How much is 1 + 1?", user: nil)
|
||||
.generate("How much is 1 + 1?", user: nil, feature_name: "llm_validator")
|
||||
.present?
|
||||
rescue StandardError
|
||||
false
|
||||
|
|
|
@ -169,7 +169,7 @@ module DiscourseAi
|
|||
llm_response =
|
||||
DiscourseAi::Completions::Llm.proxy(
|
||||
SiteSetting.ai_embeddings_semantic_search_hyde_model,
|
||||
).generate(prompt, user: @guardian.user)
|
||||
).generate(prompt, user: @guardian.user, feature_name: "semantic_search_hyde")
|
||||
|
||||
Nokogiri::HTML5.fragment(llm_response).at("ai")&.text&.presence || llm_response
|
||||
end
|
||||
|
|
|
@ -8,6 +8,7 @@ module DiscourseAi
|
|||
Models::OpenAi.new("open_ai:gpt-4", max_tokens: 8192),
|
||||
Models::OpenAi.new("open_ai:gpt-4-32k", max_tokens: 32_768),
|
||||
Models::OpenAi.new("open_ai:gpt-4-turbo", max_tokens: 100_000),
|
||||
Models::OpenAi.new("open_ai:gpt-4o", max_tokens: 100_000),
|
||||
Models::OpenAi.new("open_ai:gpt-3.5-turbo", max_tokens: 4096),
|
||||
Models::OpenAi.new("open_ai:gpt-3.5-turbo-16k", max_tokens: 16_384),
|
||||
Models::Gemini.new("google:gemini-pro", max_tokens: 32_768),
|
||||
|
@ -50,24 +51,31 @@ module DiscourseAi
|
|||
max_tokens: 32_000,
|
||||
)
|
||||
|
||||
LlmModel.all.each do |model|
|
||||
foldable_models << Models::CustomLlm.new(
|
||||
"custom:#{model.id}",
|
||||
max_tokens: model.max_prompt_tokens,
|
||||
)
|
||||
end
|
||||
# TODO: Roman, we need to de-register custom LLMs on destroy from summarization
|
||||
# strategy and clear cache
|
||||
# it may be better to pull all of this code into Discourse AI cause as it stands
|
||||
# the coupling is making it really hard to reason about summarization
|
||||
#
|
||||
# Auto registration and de-registration needs to be tested
|
||||
|
||||
#LlmModel.all.each do |model|
|
||||
# foldable_models << Models::CustomLlm.new(
|
||||
# "custom:#{model.id}",
|
||||
# max_tokens: model.max_prompt_tokens,
|
||||
# )
|
||||
#end
|
||||
|
||||
foldable_models.each do |model|
|
||||
plugin.register_summarization_strategy(Strategies::FoldContent.new(model))
|
||||
end
|
||||
|
||||
plugin.add_model_callback(LlmModel, :after_create) do
|
||||
new_model = Models::CustomLlm.new("custom:#{self.id}", max_tokens: self.max_prompt_tokens)
|
||||
#plugin.add_model_callback(LlmModel, :after_create) do
|
||||
# new_model = Models::CustomLlm.new("custom:#{self.id}", max_tokens: self.max_prompt_tokens)
|
||||
|
||||
if ::Summarization::Base.find_strategy("custom:#{self.id}").nil?
|
||||
plugin.register_summarization_strategy(Strategies::FoldContent.new(new_model))
|
||||
end
|
||||
end
|
||||
# if ::Summarization::Base.find_strategy("custom:#{self.id}").nil?
|
||||
# plugin.register_summarization_strategy(Strategies::FoldContent.new(new_model))
|
||||
# end
|
||||
#end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -99,14 +99,19 @@ module DiscourseAi
|
|||
def summarize_single(llm, text, user, opts, &on_partial_blk)
|
||||
prompt = summarization_prompt(text, opts)
|
||||
|
||||
llm.generate(prompt, user: user, &on_partial_blk)
|
||||
llm.generate(prompt, user: user, feature_name: "summarize", &on_partial_blk)
|
||||
end
|
||||
|
||||
def summarize_in_chunks(llm, chunks, user, opts)
|
||||
chunks.map do |chunk|
|
||||
prompt = summarization_prompt(chunk[:summary], opts)
|
||||
|
||||
chunk[:summary] = llm.generate(prompt, user: user, max_tokens: 300)
|
||||
chunk[:summary] = llm.generate(
|
||||
prompt,
|
||||
user: user,
|
||||
max_tokens: 300,
|
||||
feature_name: "summarize",
|
||||
)
|
||||
chunk
|
||||
end
|
||||
end
|
||||
|
|
|
@ -268,7 +268,9 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
|||
).to_return(status: 200, body: body)
|
||||
|
||||
result = +""
|
||||
llm.generate(prompt, user: Discourse.system_user) { |partial, cancel| result << partial }
|
||||
llm.generate(prompt, user: Discourse.system_user, feature_name: "testing") do |partial, cancel|
|
||||
result << partial
|
||||
end
|
||||
|
||||
expect(result).to eq("Hello!")
|
||||
|
||||
|
@ -285,6 +287,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
|||
expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
|
||||
expect(log.request_tokens).to eq(25)
|
||||
expect(log.response_tokens).to eq(15)
|
||||
expect(log.feature_name).to eq("testing")
|
||||
end
|
||||
|
||||
it "can return multiple function calls" do
|
||||
|
|
|
@ -135,7 +135,10 @@ class OpenAiMock < EndpointMock
|
|||
.default_options
|
||||
.merge(messages: prompt)
|
||||
.tap do |b|
|
||||
b[:stream] = true if stream
|
||||
if stream
|
||||
b[:stream] = true
|
||||
b[:stream_options] = { include_usage: true }
|
||||
end
|
||||
b[:tools] = [tool_payload] if tool_call
|
||||
end
|
||||
.to_json
|
||||
|
@ -431,6 +434,36 @@ TEXT
|
|||
expect(content).to eq(expected)
|
||||
end
|
||||
|
||||
it "uses proper token accounting" do
|
||||
response = <<~TEXT.strip
|
||||
data: {"id":"chatcmpl-9OZidiHncpBhhNMcqCus9XiJ3TkqR","object":"chat.completion.chunk","created":1715644203,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_729ea513f7","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}],"usage":null}|
|
||||
|
||||
data: {"id":"chatcmpl-9OZidiHncpBhhNMcqCus9XiJ3TkqR","object":"chat.completion.chunk","created":1715644203,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_729ea513f7","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}],"usage":null}|
|
||||
|
||||
data: {"id":"chatcmpl-9OZidiHncpBhhNMcqCus9XiJ3TkqR","object":"chat.completion.chunk","created":1715644203,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_729ea513f7","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null}|
|
||||
|
||||
data: {"id":"chatcmpl-9OZidiHncpBhhNMcqCus9XiJ3TkqR","object":"chat.completion.chunk","created":1715644203,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_729ea513f7","choices":[],"usage":{"prompt_tokens":20,"completion_tokens":9,"total_tokens":29}}|
|
||||
|
||||
data: [DONE]
|
||||
TEXT
|
||||
|
||||
chunks = response.split("|")
|
||||
open_ai_mock.with_chunk_array_support do
|
||||
open_ai_mock.stub_raw(chunks)
|
||||
partials = []
|
||||
|
||||
dialect = compliance.dialect(prompt: compliance.generic_prompt)
|
||||
endpoint.perform_completion!(dialect, user) { |partial| partials << partial }
|
||||
|
||||
expect(partials).to eq(["Hello"])
|
||||
|
||||
log = AiApiAuditLog.order("id desc").first
|
||||
|
||||
expect(log.request_tokens).to eq(20)
|
||||
expect(log.response_tokens).to eq(9)
|
||||
end
|
||||
end
|
||||
|
||||
it "properly handles spaces in tools payload" do
|
||||
raw_data = <<~TEXT.strip
|
||||
data: {"choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"func_id","type":"function","function":{"name":"go|ogle","arg|uments":""}}]}}]}
|
||||
|
|
Loading…
Reference in New Issue