# frozen_string_literal: true

module DiscourseAi
  module AiHelper
    # AI helper entry point. It lists the enabled completion prompts, runs a
    # prompt against the configured LLM (either returning the result or
    # streaming partial results over MessageBus) and generates image captions.
    class Assistant
      # Tags the LLM is asked to wrap its answers in. They are stripped from
      # every result (case-insensitively), along with one adjacent newline.
      SANITIZE_REGEX_STR =
        %w[term context topic replyTo input output result]
          .map { |tag| "<#{tag}>\\n?|\\n?</#{tag}>" }
          .join("|")
          .freeze
      SANITIZE_REGEX = Regexp.new(SANITIZE_REGEX_STR, Regexp::IGNORECASE | Regexp::MULTILINE)

      # Minimum number of seconds between two streamed partial updates.
      STREAM_THROTTLE_SECONDS = 0.5

      # Cache of serialized prompt lists, keyed by "prompt_cache_<locale>".
      #
      # @return [DiscourseAi::MultisiteHash]
      def self.prompt_cache
        @prompt_cache ||= ::DiscourseAi::MultisiteHash.new("prompt_cache")
      end

      # Flushes the prompt cache so the next #available_prompts call rebuilds
      # the list.
      def self.clear_prompt_cache!
        prompt_cache.flush!
      end

      # Serializes the enabled completion prompts for the client, with names
      # translated for the current I18n locale. The result is cached per
      # locale in .prompt_cache. "illustrate_post" is omitted while
      # SiteSetting.ai_helper_illustrate_post_model is "disabled".
      #
      # @return [Array<Hash>] one hash per prompt with :id, :name,
      #   :translated_name, :prompt_type, :icon and :location
      def available_prompts
        key = "prompt_cache_#{I18n.locale}"

        self.class.prompt_cache.fetch(key) do
          prompts = CompletionPrompt.where(enabled: true)

          # Hide illustrate_post if disabled
          if SiteSetting.ai_helper_illustrate_post_model == "disabled"
            prompts = prompts.where.not(name: "illustrate_post")
          end

          prompts.map do |prompt|
            # Prefer the I18n string, then the record's own translation, then the raw name.
            translation =
              I18n.t("discourse_ai.ai_helper.prompts.#{prompt.name}", default: nil) ||
                prompt.translated_name || prompt.name

            {
              id: prompt.id,
              name: prompt.name,
              translated_name: translation,
              prompt_type: prompt.prompt_type,
              icon: icon_map(prompt.name),
              location: location_map(prompt.name),
            }
          end
        end
      end

      # Builds an instruction telling the LLM to answer in the effective
      # locale's language. The effective locale is the user's own locale when
      # SiteSetting.allow_user_locale is on and the user has a non-blank
      # locale; otherwise it is SiteSetting.default_locale.
      #
      # @param user [User, nil]
      # @return [String, nil] nil for "en" or for locales without a
      #   LocaleSiteSetting.language_names entry
      def custom_locale_instructions(user = nil)
        locale = resolve_locale(user)
        locale_hash = LocaleSiteSetting.language_names[locale]

        if locale != "en" && locale_hash
          locale_description = "#{locale_hash["name"]} (#{locale_hash["nativeName"]})"
          "It is imperative that you write your answer in #{locale_description}, you are interacting with a #{locale_description} speaking user. Leave tag names in English."
        else
          nil
        end
      end

      # Localizes the first message of +prompt+ in place. It appends
      # #custom_locale_instructions (when present) and then replaces every
      # "%LANGUAGE%" placeholder with the effective locale's language name.
      #
      # @param prompt [DiscourseAi::Completions::Prompt] mutated in place
      # @param user [User, nil]
      def localize_prompt!(prompt, user = nil)
        locale_instructions = custom_locale_instructions(user)
        if locale_instructions
          prompt.messages[0][:content] = prompt.messages[0][:content] + locale_instructions
        end

        return unless prompt.messages[0][:content].include?("%LANGUAGE%")

        locale = resolve_locale(user)
        locale_hash = LocaleSiteSetting.language_names[locale]
        # Use the raw locale code instead of raising when the locale has no
        # language_names entry.
        language_name = locale_hash ? locale_hash["name"] : locale

        # Block form, so the name is inserted literally (no backreference parsing).
        prompt.messages[0][:content] = prompt.messages[0][:content].gsub("%LANGUAGE%") do
          language_name.to_s
        end
      end

      # Builds the prompt for +input+, localizes it, and runs it on the LLM
      # named by SiteSetting.ai_helper_model. When a block is given, it is
      # forwarded to the LLM's #generate call; #stream_prompt uses this to
      # receive partial responses.
      #
      # @return the value returned by the LLM's #generate call
      def generate_prompt(completion_prompt, input, user, &block)
        llm = DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_model)
        prompt = completion_prompt.messages_with_input(input)
        localize_prompt!(prompt, user)

        llm.generate(
          prompt,
          user: user,
          temperature: completion_prompt.temperature,
          stop_sequences: completion_prompt.stop_sequences,
          &block
        )
      end

      # Runs the prompt to completion and shapes the result for the client.
      #
      # @return [Hash] :type is the prompt type. :suggestions holds the
      #   sanitized <item> texts for list prompts, or the single sanitized
      #   completion otherwise. :diff (an inline HTML diff against +input+) is
      #   set only for diff prompts.
      def generate_and_send_prompt(completion_prompt, input, user)
        completion_result = generate_prompt(completion_prompt, input, user)
        result = { type: completion_prompt.prompt_type }

        result[:suggestions] = (
          if completion_prompt.list?
            parse_list(completion_result).map { |suggestion| sanitize_result(suggestion) }
          else
            sanitized = sanitize_result(completion_result)
            result[:diff] = parse_diff(input, sanitized) if completion_prompt.diff?
            [sanitized]
          end
        )

        result
      end

      # Streams the completion to +user+ on MessageBus +channel+.
      # Intermediate payloads have done: false and are throttled to one every
      # STREAM_THROTTLE_SECONDS; in the test env every partial is published.
      # A final payload with done: true is published only when the sanitized
      # result is non-blank.
      def stream_prompt(completion_prompt, input, user, channel)
        streamed_result = +""
        # Monotonic clock, so wall-clock adjustments cannot distort the throttle.
        last_published_at = Process.clock_gettime(Process::CLOCK_MONOTONIC)

        generate_prompt(completion_prompt, input, user) do |partial_response, _cancel_function|
          streamed_result << partial_response

          now = Process.clock_gettime(Process::CLOCK_MONOTONIC)
          # Throttle the updates
          if (now - last_published_at > STREAM_THROTTLE_SECONDS) || Rails.env.test?
            payload = { result: sanitize_result(streamed_result), done: false }
            publish_update(channel, payload, user)
            last_published_at = now
          end
        end

        sanitized_result = sanitize_result(streamed_result)
        if sanitized_result.present?
          publish_update(channel, { result: sanitized_result, done: true }, user)
        end
      end

      # Produces a one-sentence caption for +image_url+.
      # When SiteSetting.ai_helper_image_caption_model is "llava", the Llava
      # inference endpoint is called. Any other value is treated as a
      # multimodal LLM name and called as the system user, with the locale
      # instructions for +user+ appended to the request.
      def generate_image_caption(image_url, user)
        if SiteSetting.ai_helper_image_caption_model == "llava"
          parameters = {
            input: {
              image: image_url,
              top_p: 1,
              max_tokens: 1024,
              temperature: 0.2,
              prompt: "Please describe this image in a single sentence",
            },
          }

          ::DiscourseAi::Inference::Llava.perform!(parameters).dig(:output).join
        else
          prompt =
            DiscourseAi::Completions::Prompt.new(
              messages: [
                {
                  type: :user,
                  content: [
                    {
                      type: "text",
                      text:
                        "Describe this image in a single sentence #{custom_locale_instructions(user)}",
                    },
                    { type: "image_url", image_url: image_url },
                  ],
                },
              ],
              skip_validations: true,
            )

          DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_image_caption_model).generate(
            prompt,
            user: Discourse.system_user,
            max_tokens: 1024,
          )
        end
      end

      private

      # Effective locale: the user's locale when user locales are allowed and
      # the user has a non-blank one, otherwise the site default.
      def resolve_locale(user)
        if SiteSetting.allow_user_locale && user&.locale.present?
          user.locale
        else
          SiteSetting.default_locale
        end
      end

      # Removes the wrapper tags listed in SANITIZE_REGEX_STR from +result+.
      def sanitize_result(result)
        result.gsub(SANITIZE_REGEX, "")
      end

      # Publishes +payload+ on +channel+, visible only to +user+.
      def publish_update(channel, payload, user)
        MessageBus.publish(channel, payload, user_ids: [user.id])
      end

      # Icon name sent to the client for a prompt; nil when the prompt has no icon.
      def icon_map(name)
        case name
        when "translate"
          "language"
        when "generate_titles"
          "heading"
        when "proofread"
          "spell-check"
        when "markdown_table"
          "table"
        when "tone"
          "microphone"
        when "custom_prompt"
          "comment"
        when "rewrite"
          "pen"
        when "explain"
          "question"
        when "illustrate_post"
          "images"
        else
          nil
        end
      end

      # Locations ("composer" and/or "post") sent to the client for a prompt.
      # Unknown prompts get both.
      def location_map(name)
        case name
        when "translate"
          %w[composer post]
        when "generate_titles"
          %w[composer]
        when "proofread"
          %w[composer post]
        when "markdown_table"
          %w[composer]
        when "tone"
          %w[composer]
        when "custom_prompt"
          %w[composer post]
        when "rewrite"
          %w[composer]
        when "explain"
          %w[post]
        when "summarize"
          %w[post]
        when "illustrate_post"
          %w[composer]
        else
          %w[composer post]
        end
      end

      # Cooks both texts to HTML and returns an inline HTML diff between them.
      def parse_diff(text, suggestion)
        cooked_text = PrettyText.cook(text)
        cooked_suggestion = PrettyText.cook(suggestion)

        DiscourseDiff.new(cooked_text, cooked_suggestion).inline_html
      end

      # Extracts the text of every <item> element in the LLM output.
      def parse_list(list)
        Nokogiri::HTML5.fragment(list).css("item").map(&:text)
      end
    end
  end
end