FEATURE: Add vision support to AI personas (Claude 3) (#546)
This commit adds the ability to enable vision for AI personas, allowing them to understand images posted in a conversation. For personas with vision enabled, any images the user has posted are resized to fit within the configured max_pixels limit, base64 encoded, and included in the prompt sent to the AI provider.

The persona editor allows enabling/disabling vision and has a dropdown to select the maximum supported image size (low, medium, high). Vision is disabled by default.

This initial vision support has been implemented and tested with Anthropic's Claude 3 models, which accept images in a special format as part of the prompt. Other integrations will need to be updated to support images. Several specs were added to test the new functionality at the persona, prompt-building, and API layers.

- Gemini is omitted, pending API support for Gemini 1.5. The current Gemini bot is not performing well, and adding images is unlikely to make it perform any better.
- OpenAI is omitted: vision support on GPT-4 is limited in that the API has no tool support when images are enabled, so we would need to fall back to a different prompting technique, which would add a lot of complexity.

---------

Co-authored-by: Martin Brennan <martin@discourse.org>
Parent: 82387cc51d
Commit: 61e4c56e1a
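For context, here is a minimal sketch (not part of the commit) of the two pieces the diff below wires together: downscaling an image so its pixel count stays within the persona's max_pixels budget, mirroring the ratio calculation in the new UploadEncoder, and wrapping the base64 payload in the content shape Claude 3 expects, as exercised by the Anthropic endpoint spec. The helper names are illustrative only.

require "base64"

# Illustrative helpers; names are not from the commit.
# Scale (width, height) so width * height stays within max_pixels,
# using the same linear ratio the new UploadEncoder applies.
def fit_within_max_pixels(width, height, max_pixels)
  original_pixels = width * height
  return [width, height] if original_pixels <= max_pixels

  ratio = max_pixels.to_f / original_pixels
  [(ratio * width).to_i, (ratio * height).to_i]
end

# Build the mixed image/text content array Claude 3 accepts for a user message.
def claude3_image_content(image_bytes, mime_type, text)
  [
    {
      type: "image",
      source: {
        type: "base64",
        media_type: mime_type,
        data: Base64.strict_encode64(image_bytes),
      },
    },
    { type: "text", text: text },
  ]
end

# A 2048x2048 image against the "high" preset (1,048,576 pixels):
p fit_within_max_pixels(2048, 2048, 1_048_576) # => [512, 512]

Note that the ratio is applied to each dimension, so the resized image can land well under the pixel budget; that matches the behaviour of the UploadEncoder added in this commit.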
@@ -79,6 +79,8 @@ module DiscourseAi
           :user_id,
           :mentionable,
           :max_context_posts,
+          :vision_enabled,
+          :vision_max_pixels,
           allowed_group_ids: [],
         )
@@ -9,6 +9,9 @@ class AiPersona < ActiveRecord::Base
   validates :system_prompt, presence: true, length: { maximum: 10_000_000 }
   validate :system_persona_unchangeable, on: :update, if: :system
   validates :max_context_posts, numericality: { greater_than: 0 }, allow_nil: true
+  # leaves some room for growth but sets a maximum to avoid memory issues
+  # we may want to revisit this in the future
+  validates :vision_max_pixels, numericality: { greater_than: 0, maximum: 4_000_000 }
 
   belongs_to :created_by, class_name: "User"
   belongs_to :user
@@ -98,6 +101,8 @@ class AiPersona < ActiveRecord::Base
     mentionable = self.mentionable
     default_llm = self.default_llm
     max_context_posts = self.max_context_posts
+    vision_enabled = self.vision_enabled
+    vision_max_pixels = self.vision_max_pixels
 
     persona_class = DiscourseAi::AiBot::Personas::Persona.system_personas_by_id[self.id]
     if persona_class
@@ -129,6 +134,14 @@ class AiPersona < ActiveRecord::Base
        max_context_posts
      end
 
+     persona_class.define_singleton_method :vision_enabled do
+       vision_enabled
+     end
+
+     persona_class.define_singleton_method :vision_max_pixels do
+       vision_max_pixels
+     end
+
      return persona_class
    end
 
@@ -204,6 +217,14 @@ class AiPersona < ActiveRecord::Base
         max_context_posts
       end
 
+      define_singleton_method :vision_enabled do
+        vision_enabled
+      end
+
+      define_singleton_method :vision_max_pixels do
+        vision_max_pixels
+      end
+
       define_singleton_method :to_s do
         "#<DiscourseAi::AiBot::Personas::Persona::Custom @name=#{self.name} @allowed_group_ids=#{self.allowed_group_ids.join(",")}>"
       end
@@ -17,7 +17,9 @@ class LocalizedAiPersonaSerializer < ApplicationSerializer
              :mentionable,
              :default_llm,
              :user_id,
-             :max_context_posts
+             :max_context_posts,
+             :vision_enabled,
+             :vision_max_pixels
 
  has_one :user, serializer: BasicUserSerializer, embed: :object
 
@@ -18,6 +18,8 @@ const ATTRIBUTES = [
   "default_llm",
   "user",
   "max_context_posts",
+  "vision_enabled",
+  "vision_max_pixels",
 ];
 
 const SYSTEM_ATTRIBUTES = [
@@ -30,6 +32,8 @@ const SYSTEM_ATTRIBUTES = [
   "default_llm",
   "user",
   "max_context_posts",
+  "vision_enabled",
+  "vision_max_pixels",
 ];
 
 class CommandOption {
@@ -1,5 +1,5 @@
 import Component from "@glimmer/component";
-import { tracked } from "@glimmer/tracking";
+import { cached, tracked } from "@glimmer/tracking";
 import { Input } from "@ember/component";
 import { on } from "@ember/modifier";
 import { action } from "@ember/object";
@@ -17,6 +17,7 @@ import { popupAjaxError } from "discourse/lib/ajax-error";
 import Group from "discourse/models/group";
 import I18n from "discourse-i18n";
 import AdminUser from "admin/models/admin-user";
+import ComboBox from "select-kit/components/combo-box";
 import GroupChooser from "select-kit/components/group-chooser";
 import DTooltip from "float-kit/components/d-tooltip";
 import AiCommandSelector from "./ai-command-selector";
@@ -34,10 +35,36 @@ export default class PersonaEditor extends Component {
   @tracked editingModel = null;
   @tracked showDelete = false;
 
+  @tracked maxPixelsValue = null;
+
   @action
   updateModel() {
     this.editingModel = this.args.model.workingCopy();
     this.showDelete = !this.args.model.isNew && !this.args.model.system;
+    this.maxPixelsValue = this.findClosestPixelValue(
+      this.editingModel.vision_max_pixels
+    );
+  }
+
+  findClosestPixelValue(pixels) {
+    let value = "high";
+    this.maxPixelValues.forEach((info) => {
+      if (pixels === info.pixels) {
+        value = info.id;
+      }
+    });
+    return value;
+  }
+
+  @cached
+  get maxPixelValues() {
+    const l = (key) =>
+      I18n.t(`discourse_ai.ai_persona.vision_max_pixel_sizes.${key}`);
+    return [
+      { id: "low", name: l("low"), pixels: 65536 },
+      { id: "medium", name: l("medium"), pixels: 262144 },
+      { id: "high", name: l("high"), pixels: 1048576 },
+    ];
   }
 
   @action
@@ -102,6 +129,16 @@ export default class PersonaEditor extends Component {
     }
   }
 
+  @action
+  onChangeMaxPixels(value) {
+    const entry = this.maxPixelValues.findBy("id", value);
+    if (!entry) {
+      return;
+    }
+    this.maxPixelsValue = value;
+    this.editingModel.vision_max_pixels = entry.pixels;
+  }
+
   @action
   delete() {
     return this.dialog.confirm({
@@ -137,6 +174,11 @@ export default class PersonaEditor extends Component {
     await this.toggleField("mentionable");
   }
 
+  @action
+  async toggleVisionEnabled() {
+    await this.toggleField("vision_enabled");
+  }
+
   @action
   async createUser() {
     try {
@@ -225,6 +267,17 @@ export default class PersonaEditor extends Component {
             />
           </div>
         {{/if}}
+        <div class="control-group ai-persona-editor__vision_enabled">
+          <DToggleSwitch
+            @state={{@model.vision_enabled}}
+            @label="discourse_ai.ai_persona.vision_enabled"
+            {{on "click" this.toggleVisionEnabled}}
+          />
+          <DTooltip
+            @icon="question-circle"
+            @content={{I18n.t "discourse_ai.ai_persona.vision_enabled_help"}}
+          />
+        </div>
         <div class="control-group">
           <label>{{I18n.t "discourse_ai.ai_persona.name"}}</label>
           <Input
@@ -329,6 +382,16 @@ export default class PersonaEditor extends Component {
             @content={{I18n.t "discourse_ai.ai_persona.max_context_posts_help"}}
           />
         </div>
+        {{#if @model.vision_enabled}}
+          <div class="control-group">
+            <label>{{I18n.t "discourse_ai.ai_persona.vision_max_pixels"}}</label>
+            <ComboBox
+              @value={{this.maxPixelsValue}}
+              @content={{this.maxPixelValues}}
+              @onChange={{this.onChangeMaxPixels}}
+            />
+          </div>
+        {{/if}}
         {{#if this.showTemperature}}
           <div class="control-group">
             <label>{{I18n.t "discourse_ai.ai_persona.temperature"}}</label>
@@ -71,4 +71,9 @@
     display: flex;
     align-items: center;
   }
+
+  &__vision_enabled {
+    display: flex;
+    align-items: center;
+  }
 }
@@ -121,6 +121,13 @@ en:
           no_llm_selected: "No language model selected"
           max_context_posts: "Max Context Posts"
           max_context_posts_help: "The maximum number of posts to use as context for the AI when responding to a user. (empty for default)"
+          vision_enabled: Vision Enabled
+          vision_enabled_help: If enabled, the AI will attempt to understand images users post in the topic, depends on the model being used supporting vision. Anthropic Claude 3 models support vision.
+          vision_max_pixels: Supported image size
+          vision_max_pixel_sizes:
+            low: Low Quality - cheapest (256x256)
+            medium: Medium Quality (512x512)
+            high: High Quality - slowest (1024x1024)
           mentionable: Mentionable
           mentionable_help: If enabled, users in allowed groups can mention this user in posts and messages, the AI will respond as this persona.
           user: User
@@ -0,0 +1,10 @@
+# frozen_string_literal: true
+
+class AddImagesToAiPersonas < ActiveRecord::Migration[7.0]
+  def change
+    change_table :ai_personas do |t|
+      add_column :ai_personas, :vision_enabled, :boolean, default: false, null: false
+      add_column :ai_personas, :vision_max_pixels, :integer, default: 1_048_576, null: false
+    end
+  end
+end
@@ -5,6 +5,14 @@ module DiscourseAi
   module Personas
     class Persona
       class << self
+        def vision_enabled
+          false
+        end
+
+        def vision_max_pixels
+          1_048_576
+        end
+
         def system_personas
           @system_personas ||= {
             Personas::General => -1,
@@ -126,6 +134,7 @@ module DiscourseAi
             post_id: context[:post_id],
           )
 
+        prompt.max_pixels = self.class.vision_max_pixels if self.class.vision_enabled
        prompt.tools = available_tools.map(&:signature) if available_tools
 
        prompt
@@ -111,17 +111,26 @@ module DiscourseAi
         post
           .topic
           .posts
-          .includes(:user)
+          .joins(:user)
           .joins("LEFT JOIN post_custom_prompts ON post_custom_prompts.post_id = posts.id")
           .where("post_number <= ?", post.post_number)
           .order("post_number desc")
           .where("post_type in (?)", post_types)
           .limit(max_posts)
-          .pluck(:raw, :username, "post_custom_prompts.custom_prompt")
+          .pluck(
+            "posts.raw",
+            "users.username",
+            "post_custom_prompts.custom_prompt",
+            "(
+              SELECT array_agg(ref.upload_id)
+              FROM upload_references ref
+              WHERE ref.target_type = 'Post' AND ref.target_id = posts.id
+            ) as upload_ids",
+          )
 
       result = []
 
-      context.reverse_each do |raw, username, custom_prompt|
+      context.reverse_each do |raw, username, custom_prompt, upload_ids|
         custom_prompt_translation =
           Proc.new do |message|
             # We can't keep backwards-compatibility for stored functions.
@@ -149,6 +158,10 @@ module DiscourseAi
 
         context[:id] = username if context[:type] == :user
 
+        if upload_ids.present? && context[:type] == :user && bot.persona.class.vision_enabled
+          context[:upload_ids] = upload_ids.compact
+        end
+
         result << context
       end
     end
@@ -47,6 +47,7 @@ module DiscourseAi
           content = +""
           content << "#{msg[:id]}: " if msg[:id]
           content << msg[:content]
+          content = inline_images(content, msg)
 
           { role: "user", content: content }
         end
@@ -80,6 +81,33 @@ module DiscourseAi
         # Longer term it will have over 1 million
         200_000 # Claude-3 has a 200k context window for now
       end
+
+      private
+
+      def inline_images(content, message)
+        if model_name.include?("claude-3")
+          encoded_uploads = prompt.encoded_uploads(message)
+          if encoded_uploads.present?
+            new_content = []
+            new_content.concat(
+              encoded_uploads.map do |details|
+                {
+                  source: {
+                    type: "base64",
+                    data: details[:base64],
+                    media_type: details[:mime_type],
+                  },
+                  type: "image",
+                }
+              end,
+            )
+            new_content << { type: "text", text: content }
+            content = new_content
+          end
+        end
+
+        content
+      end
     end
   end
 end
@@ -66,12 +66,18 @@ module DiscourseAi
       def with_prepared_responses(responses, llm: nil)
         @canned_response = DiscourseAi::Completions::Endpoints::CannedResponse.new(responses)
         @canned_llm = llm
+        @prompts = []
 
-        yield(@canned_response, llm)
+        yield(@canned_response, llm, @prompts)
       ensure
         # Don't leak prepared response if there's an exception.
         @canned_response = nil
         @canned_llm = nil
+        @prompts = nil
+      end
+
+      def record_prompt(prompt)
+        @prompts << prompt if @prompts
       end
 
       def proxy(model_name)
@@ -138,6 +144,8 @@ module DiscourseAi
       user:,
       &partial_read_blk
     )
+      self.class.record_prompt(prompt)
+
       model_params = { max_tokens: max_tokens, stop_sequences: stop_sequences }
 
       model_params[:temperature] = temperature if temperature
@@ -6,7 +6,7 @@ module DiscourseAi
       INVALID_TURN = Class.new(StandardError)
 
       attr_reader :messages
-      attr_accessor :tools, :topic_id, :post_id
+      attr_accessor :tools, :topic_id, :post_id, :max_pixels
 
       def initialize(
         system_message_text = nil,
@@ -14,11 +14,14 @@ module DiscourseAi
         tools: [],
         skip_validations: false,
         topic_id: nil,
-        post_id: nil
+        post_id: nil,
+        max_pixels: nil
       )
         raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
         raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)
 
+        @max_pixels = max_pixels || 1_048_576
+
         @topic_id = topic_id
         @post_id = post_id
 
@@ -38,11 +41,12 @@ module DiscourseAi
         @tools = tools
       end
 
-      def push(type:, content:, id: nil, name: nil)
+      def push(type:, content:, id: nil, name: nil, upload_ids: nil)
         return if type == :system
         new_message = { type: type, content: content }
         new_message[:name] = name.to_s if name
         new_message[:id] = id.to_s if id
+        new_message[:upload_ids] = upload_ids if upload_ids
 
         validate_message(new_message)
         validate_turn(messages.last, new_message)
@@ -54,6 +58,13 @@ module DiscourseAi
         tools.present?
       end
 
+      # helper method to get base64 encoded uploads
+      # at the correct dimentions
+      def encoded_uploads(message)
+        return [] if message[:upload_ids].blank?
+        UploadEncoder.encode(upload_ids: message[:upload_ids], max_pixels: max_pixels)
+      end
+
       private
 
       def validate_message(message)
@@ -63,11 +74,19 @@ module DiscourseAi
           raise ArgumentError, "message type must be one of #{valid_types}"
         end
 
-        valid_keys = %i[type content id name]
+        valid_keys = %i[type content id name upload_ids]
         if (invalid_keys = message.keys - valid_keys).any?
           raise ArgumentError, "message contains invalid keys: #{invalid_keys}"
         end
 
+        if message[:type] == :upload_ids && !message[:upload_ids].is_a?(Array)
+          raise ArgumentError, "upload_ids must be an array of ids"
+        end
+
+        if message[:upload_ids].present? && message[:type] != :user
+          raise ArgumentError, "upload_ids are only supported for users"
+        end
+
         raise ArgumentError, "message content must be a string" if !message[:content].is_a?(String)
       end
 
@@ -0,0 +1,45 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    class UploadEncoder
+      def self.encode(upload_ids:, max_pixels:)
+        uploads = []
+        upload_ids.each do |upload_id|
+          upload = Upload.find(upload_id)
+          next if upload.blank?
+          next if upload.width.to_i == 0 || upload.height.to_i == 0
+
+          original_pixels = upload.width * upload.height
+
+          image = upload
+
+          if original_pixels > max_pixels
+            ratio = max_pixels.to_f / original_pixels
+
+            new_width = (ratio * upload.width).to_i
+            new_height = (ratio * upload.height).to_i
+
+            image = upload.get_optimized_image(new_width, new_height)
+          end
+
+          next if !image
+
+          mime_type = MiniMime.lookup_by_filename(upload.original_filename).content_type
+
+          path = Discourse.store.path_for(image)
+          if path.blank?
+            # download is protected with a DistributedMutex
+            external_copy = Discourse.store.download_safe(image)
+            path = external_copy&.path
+          end
+
+          encoded = Base64.strict_encode64(File.read(path))
+
+          uploads << { base64: encoded, mime_type: mime_type }
+        end
+        uploads
+      end
+    end
+  end
+end
New binary file added (spec fixture image, 415 B); contents not shown.
@@ -2,6 +2,10 @@
 
 RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
   let(:llm) { DiscourseAi::Completions::Llm.proxy("anthropic:claude-3-opus") }
+  let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
+  let(:upload100x100) do
+    UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
+  end
 
   let(:prompt) do
     DiscourseAi::Completions::Prompt.new(
@@ -340,6 +344,73 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
     expect(result.strip).to eq(expected)
   end
 
+  it "can send images via a completion prompt" do
+    prompt =
+      DiscourseAi::Completions::Prompt.new(
+        "You are image bot",
+        messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]],
+      )
+
+    encoded = prompt.encoded_uploads(prompt.messages.last)
+
+    request_body = {
+      model: "claude-3-opus-20240229",
+      max_tokens: 3000,
+      messages: [
+        {
+          role: "user",
+          content: [
+            {
+              type: "image",
+              source: {
+                type: "base64",
+                media_type: "image/jpeg",
+                data: encoded[0][:base64],
+              },
+            },
+            { type: "text", text: "user1: hello" },
+          ],
+        },
+      ],
+      system: "You are image bot",
+    }
+
+    response_body = <<~STRING
+      {
+        "content": [
+          {
+            "text": "What a cool image",
+            "type": "text"
+          }
+        ],
+        "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
+        "model": "claude-3-opus-20240229",
+        "role": "assistant",
+        "stop_reason": "end_turn",
+        "stop_sequence": null,
+        "type": "message",
+        "usage": {
+          "input_tokens": 10,
+          "output_tokens": 25
+        }
+      }
+    STRING
+
+    requested_body = nil
+    stub_request(:post, "https://api.anthropic.com/v1/messages").with(
+      body:
+        proc do |req_body|
+          requested_body = JSON.parse(req_body, symbolize_names: true)
+          true
+        end,
+    ).to_return(status: 200, body: response_body)
+
+    result = llm.generate(prompt, user: Discourse.system_user)
+
+    expect(result).to eq("What a cool image")
+    expect(requested_body).to eq(request_body)
+  end
+
   it "can operate in regular mode" do
     body = <<~STRING
       {
@@ -6,6 +6,7 @@ RSpec.describe DiscourseAi::Completions::Prompt do
   let(:system_insts) { "These are the system instructions." }
   let(:user_msg) { "Write something nice" }
   let(:username) { "username1" }
+  let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
 
   describe ".new" do
     it "raises for invalid attributes" do
@@ -23,6 +24,38 @@ RSpec.describe DiscourseAi::Completions::Prompt do
     end
   end
 
+  describe "image support" do
+    it "allows adding uploads to messages" do
+      upload = UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
+
+      prompt.max_pixels = 300
+      prompt.push(type: :user, content: "hello", upload_ids: [upload.id])
+
+      expect(prompt.messages.last[:upload_ids]).to eq([upload.id])
+      expect(prompt.max_pixels).to eq(300)
+
+      encoded = prompt.encoded_uploads(prompt.messages.last)
+
+      expect(encoded.length).to eq(1)
+      expect(encoded[0][:mime_type]).to eq("image/jpeg")
+
+      old_base64 = encoded[0][:base64]
+
+      prompt.max_pixels = 1_000_000
+
+      encoded = prompt.encoded_uploads(prompt.messages.last)
+
+      expect(encoded.length).to eq(1)
+      expect(encoded[0][:mime_type]).to eq("image/jpeg")
+
+      new_base64 = encoded[0][:base64]
+
+      expect(new_base64.length).to be > old_base64.length
+      expect(new_base64.length).to be > 0
+      expect(old_base64.length).to be > 0
+    end
+  end
+
   describe "#push" do
     describe "turn validations" do
       it "validates that tool messages have a previous tool_call message" do
@@ -61,6 +61,55 @@ RSpec.describe DiscourseAi::AiBot::Playground do
     end
   end
 
+  describe "image support" do
+    before do
+      Jobs.run_immediately!
+      SiteSetting.ai_bot_allowed_groups = "#{Group::AUTO_GROUPS[:trust_level_0]}"
+    end
+
+    fab!(:persona) do
+      AiPersona.create!(
+        name: "Test Persona",
+        description: "A test persona",
+        allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
+        enabled: true,
+        system_prompt: "You are a helpful bot",
+        vision_enabled: true,
+        vision_max_pixels: 1_000,
+        default_llm: "anthropic:claude-3-opus",
+        mentionable: true,
+      )
+    end
+
+    fab!(:upload)
+
+    it "sends images to llm" do
+      post = nil
+
+      persona.create_user!
+
+      image = "![image](upload://#{upload.base62_sha1}.jpg)"
+      body = "Hey @#{persona.user.username}, can you help me with this image? #{image}"
+
+      prompts = nil
+      DiscourseAi::Completions::Llm.with_prepared_responses(
+        ["I understood image"],
+      ) do |_, _, inner_prompts|
+        post = create_post(title: "some new topic I created", raw: body)
+
+        prompts = inner_prompts
+      end
+
+      expect(prompts[0].messages[1][:upload_ids]).to eq([upload.id])
+      expect(prompts[0].max_pixels).to eq(1000)
+
+      post.topic.reload
+      last_post = post.topic.posts.order(:post_number).last
+
+      expect(last_post.raw).to eq("I understood image")
+    end
+  end
+
   describe "persona with user support" do
     before do
       Jobs.run_immediately!
@@ -9,13 +9,9 @@ RSpec.describe DiscourseAi::AiBot::Tools::DiscourseMetaSearch do
   let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
   let(:progress_blk) { Proc.new {} }
 
-  let(:mock_search_json) do
-    File.read(File.expand_path("../../../../../fixtures/search_meta/search.json", __FILE__))
-  end
+  let(:mock_search_json) { plugin_file_from_fixtures("search.json", "search_meta").read }
 
-  let(:mock_site_json) do
-    File.read(File.expand_path("../../../../../fixtures/search_meta/site.json", __FILE__))
-  end
+  let(:mock_site_json) { plugin_file_from_fixtures("site.json", "search_meta").read }
 
   before do
     stub_request(:get, "https://meta.discourse.org/site.json").to_return(
@@ -204,6 +204,23 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
       expect(persona.temperature).to eq(nil)
     end
 
+    it "supports updating vision params" do
+      persona = Fabricate(:ai_persona, name: "test_bot2")
+      put "/admin/plugins/discourse-ai/ai-personas/#{persona.id}.json",
+          params: {
+            ai_persona: {
+              vision_enabled: true,
+              vision_max_pixels: 512 * 512,
+            },
+          }
+
+      expect(response).to have_http_status(:ok)
+      persona.reload
+
+      expect(persona.vision_enabled).to eq(true)
+      expect(persona.vision_max_pixels).to eq(512 * 512)
+    end
+
     it "does not allow temperature and top p changes on stock personas" do
       put "/admin/plugins/discourse-ai/ai-personas/#{DiscourseAi::AiBot::Personas::Persona.system_personas.values.first}.json",
           params: {
@@ -46,6 +46,8 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
       user: null,
       user_id: null,
       max_context_posts: 5,
+      vision_enabled: true,
+      vision_max_pixels: 100,
     };
 
     const aiPersona = AiPersona.create({ ...properties });
@@ -77,6 +79,8 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
       default_llm: "Default LLM",
       mentionable: false,
       max_context_posts: 5,
+      vision_enabled: true,
+      vision_max_pixels: 100,
     };
 
     const aiPersona = AiPersona.create({ ...properties });