Mirror of https://github.com/discourse/discourse-ai.git (synced 2025-02-27 14:08:10 +00:00)
This is a significant PR that introduces AI Artifacts functionality to the discourse-ai plugin, along with several other improvements. Here are the key changes:

1. AI Artifacts System:
   - Adds a new `AiArtifact` model and database migration
   - Allows creation of web artifacts with HTML, CSS, and JavaScript content
   - Introduces security settings (`strict`, `lax`, `disabled`) for controlling artifact execution
   - Implements artifact rendering in iframes with sandbox protection
   - New `CreateArtifact` tool for AI to generate interactive content

2. Tool System Improvements:
   - Adds support for partial tool calls, allowing incremental updates during generation
   - Better handling of tool call states and progress tracking
   - Improved XML tool processing with CDATA support
   - Fixes for tool parameter handling and duplicate invocations

3. LLM Provider Updates:
   - Updates for Anthropic Claude models with correct token limits
   - Adds support for native/XML tool modes in Gemini integration
   - Adds new model configurations, including Llama 3.1 models
   - Improvements to streaming response handling

4. UI Enhancements:
   - New artifact viewer component with expand/collapse functionality
   - Security controls for artifact execution (click-to-run in strict mode)
   - Improved dialog and response handling
   - Better error management for tool execution

5. Security Improvements:
   - Sandbox controls for artifact execution
   - Public/private artifact sharing controls
   - Security settings to control artifact behavior
   - CSP and frame-options handling for artifacts

6. Technical Improvements:
   - Better post streaming implementation
   - Improved error handling in completions
   - Better memory management for partial tool calls
   - Enhanced testing coverage

7. Configuration:
   - New site settings for artifact security
   - Extended LLM model configurations
   - Additional tool configuration options

This PR significantly enhances the plugin's capabilities for generating and displaying interactive content while maintaining security and providing flexible configuration options for administrators.
325 lines
9.7 KiB
Ruby
# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Endpoints
      class Base
        attr_reader :partial_tool_calls

        CompletionFailed = Class.new(StandardError)
        TIMEOUT = 60
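
        # Maps a provider name (e.g. "anthropic") to the first endpoint class
        # whose .can_contact? accepts it; raises Llm::UNKNOWN_MODEL when no
        # endpoint matches.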
        class << self
          def endpoint_for(provider_name)
            endpoints = [
              DiscourseAi::Completions::Endpoints::AwsBedrock,
              DiscourseAi::Completions::Endpoints::OpenAi,
              DiscourseAi::Completions::Endpoints::HuggingFace,
              DiscourseAi::Completions::Endpoints::Gemini,
              DiscourseAi::Completions::Endpoints::Vllm,
              DiscourseAi::Completions::Endpoints::Anthropic,
              DiscourseAi::Completions::Endpoints::Cohere,
              DiscourseAi::Completions::Endpoints::SambaNova,
            ]

            endpoints << DiscourseAi::Completions::Endpoints::Ollama if Rails.env.development?

            if Rails.env.test? || Rails.env.development?
              endpoints << DiscourseAi::Completions::Endpoints::Fake
            end

            endpoints.detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
              ek.can_contact?(provider_name)
            end
          end

          def can_contact?(_model_provider)
            raise NotImplementedError
          end
        end

        def initialize(llm_model)
          @llm_model = llm_model
        end
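
        # Assume HTTPS unless the model URI explicitly declares a scheme.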
        def use_ssl?
          if model_uri&.scheme.present?
            model_uri.scheme == "https"
          else
            true
          end
        end
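
        # Subclasses may return a list of XML tag names to strip from the
        # response before it reaches the caller (see XmlTagStripper usage in
        # perform_completion! below).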
        def xml_tags_to_strip(dialect)
          []
        end
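
        # Main entry point: translates the dialect, issues the HTTP request,
        # and either streams decoded partials to the supplied block or returns
        # the full decoded response when no block is given. Every request is
        # recorded in AiApiAuditLog.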
        def perform_completion!(
          dialect,
          user,
          model_params = {},
          feature_name: nil,
          feature_context: nil,
          partial_tool_calls: false,
          &blk
        )
          @partial_tool_calls = partial_tool_calls
          model_params = normalize_model_params(model_params)
          orig_blk = blk

          @streaming_mode = block_given?

          prompt = dialect.translate

          FinalDestination::HTTP.start(
            model_uri.host,
            model_uri.port,
            use_ssl: use_ssl?,
            read_timeout: TIMEOUT,
            open_timeout: TIMEOUT,
            write_timeout: TIMEOUT,
          ) do |http|
            response_data = +""
            response_raw = +""

            # Needed for response token calculations. Cannot rely on response_data due to function buffering.
            partials_raw = +""
            request_body = prepare_payload(prompt, model_params, dialect).to_json

            request = prepare_request(request_body)

            http.request(request) do |response|
              if response.code.to_i != 200
                Rails.logger.error(
                  "#{self.class.name}: status: #{response.code.to_i} - body: #{response.body}",
                )
                raise CompletionFailed, response.body
              end

              xml_tool_processor =
                XmlToolProcessor.new(
                  partial_tool_calls: partial_tool_calls,
                ) if xml_tools_enabled? && dialect.prompt.has_tools?

              to_strip = xml_tags_to_strip(dialect)
              xml_stripper =
                DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?

              if @streaming_mode && xml_stripper
                blk =
                  lambda do |partial, cancel|
                    partial = xml_stripper << partial if partial.is_a?(String)
                    orig_blk.call(partial, cancel) if partial
                  end
              end

              log =
                start_log(
                  provider_id: provider_id,
                  request_body: request_body,
                  dialect: dialect,
                  prompt: prompt,
                  user: user,
                  feature_name: feature_name,
                  feature_context: feature_context,
                )

              if !@streaming_mode
                return(
                  non_streaming_response(
                    response: response,
                    xml_tool_processor: xml_tool_processor,
                    xml_stripper: xml_stripper,
                    partials_raw: partials_raw,
                    response_raw: response_raw,
                  )
                )
              end
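
              # `cancel` is passed to the caller's block and lets it abort an
              # in-flight stream; the tool processor can also request
              # cancellation via should_cancel? below.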
              begin
                cancelled = false
                cancel = -> { cancelled = true }
                if cancelled
                  http.finish
                  break
                end

                response.read_body do |chunk|
                  response_raw << chunk
                  decode_chunk(chunk).each do |partial|
                    partials_raw << partial.to_s
                    response_data << partial if partial.is_a?(String)
                    partials = [partial]
                    if xml_tool_processor && partial.is_a?(String)
                      partials = (xml_tool_processor << partial)
                      if xml_tool_processor.should_cancel?
                        cancel.call
                        break
                      end
                    end
                    partials.each { |inner_partial| blk.call(inner_partial, cancel) }
                  end
                end
              rescue IOError, StandardError
                raise if !cancelled
              end
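
              # The stream is over: flush anything still buffered in the
              # stripper and the tool processor so trailing partials are
              # delivered before returning.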
              if xml_stripper
                stripped = xml_stripper.finish
                if stripped.present?
                  response_data << stripped
                  result = []
                  result = (xml_tool_processor << stripped) if xml_tool_processor
                  result.each { |partial| blk.call(partial, cancel) }
                end
              end

              if xml_tool_processor
                xml_tool_processor.finish.each { |partial| blk.call(partial, cancel) }
              end

              decode_chunk_finish.each { |partial| blk.call(partial, cancel) }
              return response_data
            ensure
              if log
                log.raw_response_payload = response_raw
                final_log_update(log)

                log.response_tokens = tokenizer.size(partials_raw) if log.response_tokens.blank?
                log.save!

                if Rails.env.development?
                  puts "#{self.class.name}: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}"
                end
              end
            end
          end
        end

        def final_log_update(log)
          # for people that need to override
        end

        def default_options
          raise NotImplementedError
        end

        def provider_id
          raise NotImplementedError
        end

        def prompt_size(prompt)
          tokenizer.size(extract_prompt_for_tokenizer(prompt))
        end

        attr_reader :llm_model
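
        # Subclass contract: each concrete endpoint (Anthropic, OpenAi, Gemini,
        # and the rest) implements the provider-specific hooks below.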
        protected

        def tokenizer
          llm_model.tokenizer_class
        end

        # should normalize temperature, max_tokens, stop_words to endpoint specific values
        def normalize_model_params(model_params)
          raise NotImplementedError
        end

        def model_uri
          raise NotImplementedError
        end

        def prepare_payload(_prompt, _model_params, _dialect)
          raise NotImplementedError
        end

        def prepare_request(_payload)
          raise NotImplementedError
        end

        def decode(_response_raw)
          raise NotImplementedError
        end

        def decode_chunk_finish
          []
        end

        def decode_chunk(_chunk)
          raise NotImplementedError
        end

        def extract_prompt_for_tokenizer(prompt)
          prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
        end

        def xml_tools_enabled?
          raise NotImplementedError
        end

        private
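
        # Builds (without saving) the AiApiAuditLog row; perform_completion!
        # attaches the raw response and token counts in its ensure block and
        # saves the record there.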
        def start_log(
          provider_id:,
          request_body:,
          dialect:,
          prompt:,
          user:,
          feature_name:,
          feature_context:
        )
          AiApiAuditLog.new(
            provider_id: provider_id,
            user_id: user&.id,
            raw_request_payload: request_body,
            request_tokens: prompt_size(prompt),
            topic_id: dialect.prompt.topic_id,
            post_id: dialect.prompt.post_id,
            feature_name: feature_name,
            language_model: llm_model.name,
            feature_context: feature_context.present? ? feature_context.as_json : nil,
          )
        end
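
        # Blocking (non-streaming) path: read the whole body, decode it, then
        # apply the same tool-processing and tag-stripping passes the streaming
        # path applies incrementally.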
        def non_streaming_response(
          response:,
          xml_tool_processor:,
          xml_stripper:,
          partials_raw:,
          response_raw:
        )
          response_raw << response.read_body
          response_data = decode(response_raw)

          response_data.each { |partial| partials_raw << partial.to_s }

          if xml_tool_processor
            response_data.each do |partial|
              processed = (xml_tool_processor << partial)
              processed << xml_tool_processor.finish
              response_data = []
              processed.flatten.compact.each { |inner| response_data << inner }
            end
          end

          if xml_stripper
            response_data.map! do |partial|
              stripped = (xml_stripper << partial) if partial.is_a?(String)
              if stripped.present?
                stripped
              else
                partial
              end
            end
            response_data << xml_stripper.finish
          end

          response_data.reject!(&:blank?)

          # backwards compatibility: unwrap single-element arrays to the bare value
          response_data = response_data.first if response_data.length == 1

          response_data
        end
      end
    end
  end
end
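
# Usage sketch (illustrative only): a concrete endpoint fills in the protected
# contract above. `MyProvider`, its URL, and its JSON shapes are hypothetical
# assumptions for this sketch, not part of the plugin.
#
#   class MyProvider < DiscourseAi::Completions::Endpoints::Base
#     def self.can_contact?(provider_name)
#       provider_name == "my_provider"
#     end
#
#     def provider_id
#       999 # integer id recorded on AiApiAuditLog rows
#     end
#
#     def default_options
#       { model: llm_model.name }
#     end
#
#     protected
#
#     def normalize_model_params(model_params)
#       model_params.slice(:temperature, :max_tokens)
#     end
#
#     def model_uri
#       URI("https://api.example.com/v1/complete")
#     end
#
#     def prepare_payload(prompt, model_params, _dialect)
#       default_options.merge(model_params).merge(messages: prompt)
#     end
#
#     def prepare_request(payload)
#       Net::HTTP::Post
#         .new(model_uri, "Content-Type" => "application/json")
#         .tap { |r| r.body = payload }
#     end
#
#     def decode(response_raw)
#       [JSON.parse(response_raw).dig("choices", 0, "message", "content")]
#     end
#
#     def decode_chunk(chunk)
#       [chunk] # parse the provider's streaming format here
#     end
#
#     def xml_tools_enabled?
#       true
#     end
#   end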