FEATURE: Add vision support to AI personas (Claude 3) (#546)

This commit adds the ability to enable vision for AI personas, allowing them to understand images that are posted in the conversation.

For personas with vision enabled, any images the user has posted will be resized to be within the configured max_pixels limit, base64 encoded and included in the prompt sent to the AI provider.

The persona editor allows enabling/disabling vision and has a dropdown to select the max supported image size (low, medium, high). Vision is disabled by default.

This initial vision support has been tested and implemented with Anthropic's claude-3 models which accept images in a special format as part of the prompt.

Other integrations will need to be updated to support images.

Several specs were added to test the new functionality at the persona, prompt building and API layers.

 - Gemini is omitted, pending API support for Gemini 1.5. The current Gemini bot is not performing well; adding images is unlikely to make it perform any better.

 - Open AI is omitted: vision support on GPT-4 is limited in that the API has no tool support when images are enabled, so we would need to fall back to a different prompting technique — something that would add lots of complexity.


---------

Co-authored-by: Martin Brennan <martin@discourse.org>
This commit is contained in:
Sam 2024-03-27 14:30:11 +11:00 committed by GitHub
parent 82387cc51d
commit 61e4c56e1a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 422 additions and 16 deletions

View File

@ -79,6 +79,8 @@ module DiscourseAi
:user_id,
:mentionable,
:max_context_posts,
:vision_enabled,
:vision_max_pixels,
allowed_group_ids: [],
)

View File

@ -9,6 +9,9 @@ class AiPersona < ActiveRecord::Base
validates :system_prompt, presence: true, length: { maximum: 10_000_000 }
validate :system_persona_unchangeable, on: :update, if: :system
validates :max_context_posts, numericality: { greater_than: 0 }, allow_nil: true
# leaves some room for growth but sets a maximum to avoid memory issues
# we may want to revisit this in the future
validates :vision_max_pixels, numericality: { greater_than: 0, maximum: 4_000_000 }
belongs_to :created_by, class_name: "User"
belongs_to :user
@ -98,6 +101,8 @@ class AiPersona < ActiveRecord::Base
mentionable = self.mentionable
default_llm = self.default_llm
max_context_posts = self.max_context_posts
vision_enabled = self.vision_enabled
vision_max_pixels = self.vision_max_pixels
persona_class = DiscourseAi::AiBot::Personas::Persona.system_personas_by_id[self.id]
if persona_class
@ -129,6 +134,14 @@ class AiPersona < ActiveRecord::Base
max_context_posts
end
persona_class.define_singleton_method :vision_enabled do
vision_enabled
end
persona_class.define_singleton_method :vision_max_pixels do
vision_max_pixels
end
return persona_class
end
@ -204,6 +217,14 @@ class AiPersona < ActiveRecord::Base
max_context_posts
end
define_singleton_method :vision_enabled do
vision_enabled
end
define_singleton_method :vision_max_pixels do
vision_max_pixels
end
define_singleton_method :to_s do
"#<DiscourseAi::AiBot::Personas::Persona::Custom @name=#{self.name} @allowed_group_ids=#{self.allowed_group_ids.join(",")}>"
end

View File

@ -17,7 +17,9 @@ class LocalizedAiPersonaSerializer < ApplicationSerializer
:mentionable,
:default_llm,
:user_id,
:max_context_posts
:max_context_posts,
:vision_enabled,
:vision_max_pixels
has_one :user, serializer: BasicUserSerializer, embed: :object

View File

@ -18,6 +18,8 @@ const ATTRIBUTES = [
"default_llm",
"user",
"max_context_posts",
"vision_enabled",
"vision_max_pixels",
];
const SYSTEM_ATTRIBUTES = [
@ -30,6 +32,8 @@ const SYSTEM_ATTRIBUTES = [
"default_llm",
"user",
"max_context_posts",
"vision_enabled",
"vision_max_pixels",
];
class CommandOption {

View File

@ -1,5 +1,5 @@
import Component from "@glimmer/component";
import { tracked } from "@glimmer/tracking";
import { cached, tracked } from "@glimmer/tracking";
import { Input } from "@ember/component";
import { on } from "@ember/modifier";
import { action } from "@ember/object";
@ -17,6 +17,7 @@ import { popupAjaxError } from "discourse/lib/ajax-error";
import Group from "discourse/models/group";
import I18n from "discourse-i18n";
import AdminUser from "admin/models/admin-user";
import ComboBox from "select-kit/components/combo-box";
import GroupChooser from "select-kit/components/group-chooser";
import DTooltip from "float-kit/components/d-tooltip";
import AiCommandSelector from "./ai-command-selector";
@ -34,10 +35,36 @@ export default class PersonaEditor extends Component {
@tracked editingModel = null;
@tracked showDelete = false;
@tracked maxPixelsValue = null;
@action
updateModel() {
// Edit a working copy so changes can be discarded without touching the
// persisted model; delete is only offered for saved, non-system personas.
this.editingModel = this.args.model.workingCopy();
this.showDelete = !this.args.model.isNew && !this.args.model.system;
// Seed the image-size dropdown from the persisted pixel count.
this.maxPixelsValue = this.findClosestPixelValue(
this.editingModel.vision_max_pixels
);
}
findClosestPixelValue(pixels) {
let value = "high";
this.maxPixelValues.forEach((info) => {
if (pixels === info.pixels) {
value = info.id;
}
});
return value;
}
@cached
get maxPixelValues() {
// Preset image-size options: `id`/`name` drive the dropdown, `pixels` is
// what is persisted as vision_max_pixels (256^2, 512^2, 1024^2).
const l = (key) =>
I18n.t(`discourse_ai.ai_persona.vision_max_pixel_sizes.${key}`);
return [
{ id: "low", name: l("low"), pixels: 65536 },
{ id: "medium", name: l("medium"), pixels: 262144 },
{ id: "high", name: l("high"), pixels: 1048576 },
];
}
@action
@ -102,6 +129,16 @@ export default class PersonaEditor extends Component {
}
}
@action
onChangeMaxPixels(value) {
// Translate the selected preset id into a raw pixel count on the editing
// model; unknown ids are ignored so the model is never corrupted.
const entry = this.maxPixelValues.findBy("id", value);
if (!entry) {
return;
}
this.maxPixelsValue = value;
this.editingModel.vision_max_pixels = entry.pixels;
}
@action
delete() {
return this.dialog.confirm({
@ -137,6 +174,11 @@ export default class PersonaEditor extends Component {
await this.toggleField("mentionable");
}
@action
async toggleVisionEnabled() {
// Delegates to toggleField — presumably persists immediately for saved
// personas, as with "mentionable" above; confirm against toggleField.
await this.toggleField("vision_enabled");
}
@action
async createUser() {
try {
@ -225,6 +267,17 @@ export default class PersonaEditor extends Component {
/>
</div>
{{/if}}
<div class="control-group ai-persona-editor__vision_enabled">
<DToggleSwitch
@state={{@model.vision_enabled}}
@label="discourse_ai.ai_persona.vision_enabled"
{{on "click" this.toggleVisionEnabled}}
/>
<DTooltip
@icon="question-circle"
@content={{I18n.t "discourse_ai.ai_persona.vision_enabled_help"}}
/>
</div>
<div class="control-group">
<label>{{I18n.t "discourse_ai.ai_persona.name"}}</label>
<Input
@ -329,6 +382,16 @@ export default class PersonaEditor extends Component {
@content={{I18n.t "discourse_ai.ai_persona.max_context_posts_help"}}
/>
</div>
{{#if @model.vision_enabled}}
<div class="control-group">
<label>{{I18n.t "discourse_ai.ai_persona.vision_max_pixels"}}</label>
<ComboBox
@value={{this.maxPixelsValue}}
@content={{this.maxPixelValues}}
@onChange={{this.onChangeMaxPixels}}
/>
</div>
{{/if}}
{{#if this.showTemperature}}
<div class="control-group">
<label>{{I18n.t "discourse_ai.ai_persona.temperature"}}</label>

View File

@ -71,4 +71,9 @@
display: flex;
align-items: center;
}
&__vision_enabled {
display: flex;
align-items: center;
}
}

View File

@ -121,6 +121,13 @@ en:
no_llm_selected: "No language model selected"
max_context_posts: "Max Context Posts"
max_context_posts_help: "The maximum number of posts to use as context for the AI when responding to a user. (empty for default)"
vision_enabled: Vision Enabled
vision_enabled_help: If enabled, the AI will attempt to understand images users post in the topic, depends on the model being used supporting vision. Anthropic Claude 3 models support vision.
vision_max_pixels: Supported image size
vision_max_pixel_sizes:
low: Low Quality - cheapest (256x256)
medium: Medium Quality (512x512)
high: High Quality - slowest (1024x1024)
mentionable: Mentionable
mentionable_help: If enabled, users in allowed groups can mention this user in posts and messages, the AI will respond as this persona.
user: User

View File

@ -0,0 +1,10 @@
# frozen_string_literal: true
# Adds vision support columns to ai_personas:
# - vision_enabled:    whether the persona should receive images (off by default)
# - vision_max_pixels: resize budget for encoded images (defaults to 1024x1024)
class AddImagesToAiPersonas < ActiveRecord::Migration[7.0]
  def change
    # Use the table block's own helpers; the previous top-level
    # `add_column :ai_personas, ...` calls silently ignored the
    # `change_table` block they were nested in.
    change_table :ai_personas do |t|
      t.boolean :vision_enabled, default: false, null: false
      t.integer :vision_max_pixels, default: 1_048_576, null: false
    end
  end
end

View File

@ -5,6 +5,14 @@ module DiscourseAi
module Personas
class Persona
class << self
# Vision is off by default on the base Persona; AiPersona records
# override these class-level defaults (defined inside `class << self`).
def vision_enabled
false
end
# Default resize budget for encoded images (1024 x 1024 pixels).
def vision_max_pixels
1_048_576
end
def system_personas
@system_personas ||= {
Personas::General => -1,
@ -126,6 +134,7 @@ module DiscourseAi
post_id: context[:post_id],
)
prompt.max_pixels = self.class.vision_max_pixels if self.class.vision_enabled
prompt.tools = available_tools.map(&:signature) if available_tools
prompt

View File

@ -111,17 +111,26 @@ module DiscourseAi
post
.topic
.posts
.includes(:user)
.joins(:user)
.joins("LEFT JOIN post_custom_prompts ON post_custom_prompts.post_id = posts.id")
.where("post_number <= ?", post.post_number)
.order("post_number desc")
.where("post_type in (?)", post_types)
.limit(max_posts)
.pluck(:raw, :username, "post_custom_prompts.custom_prompt")
.pluck(
"posts.raw",
"users.username",
"post_custom_prompts.custom_prompt",
"(
SELECT array_agg(ref.upload_id)
FROM upload_references ref
WHERE ref.target_type = 'Post' AND ref.target_id = posts.id
) as upload_ids",
)
result = []
context.reverse_each do |raw, username, custom_prompt|
context.reverse_each do |raw, username, custom_prompt, upload_ids|
custom_prompt_translation =
Proc.new do |message|
# We can't keep backwards-compatibility for stored functions.
@ -149,6 +158,10 @@ module DiscourseAi
context[:id] = username if context[:type] == :user
if upload_ids.present? && context[:type] == :user && bot.persona.class.vision_enabled
context[:upload_ids] = upload_ids.compact
end
result << context
end
end

View File

@ -47,6 +47,7 @@ module DiscourseAi
content = +""
content << "#{msg[:id]}: " if msg[:id]
content << msg[:content]
content = inline_images(content, msg)
{ role: "user", content: content }
end
@ -80,6 +81,33 @@ module DiscourseAi
# Longer term it will have over 1 million
200_000 # Claude-3 has a 200k context window for now
end
private
# For Claude 3 models, expand a user message's text into Anthropic's
# multi-part content format: one base64 "image" block per upload followed
# by a single "text" block. For other models, or when the message has no
# encodable uploads, the original string content is returned unchanged.
def inline_images(content, message)
  return content if !model_name.include?("claude-3")

  uploads = prompt.encoded_uploads(message)
  return content if !uploads.present?

  image_parts =
    uploads.map do |upload|
      {
        source: {
          type: "base64",
          data: upload[:base64],
          media_type: upload[:mime_type],
        },
        type: "image",
      }
    end

  image_parts + [{ type: "text", text: content }]
end
end
end
end

View File

@ -66,12 +66,18 @@ module DiscourseAi
def with_prepared_responses(responses, llm: nil)
@canned_response = DiscourseAi::Completions::Endpoints::CannedResponse.new(responses)
@canned_llm = llm
@prompts = []
yield(@canned_response, llm)
yield(@canned_response, llm, @prompts)
ensure
# Don't leak prepared response if there's an exception.
@canned_response = nil
@canned_llm = nil
@prompts = nil
end
def record_prompt(prompt)
@prompts << prompt if @prompts
end
def proxy(model_name)
@ -138,6 +144,8 @@ module DiscourseAi
user:,
&partial_read_blk
)
self.class.record_prompt(prompt)
model_params = { max_tokens: max_tokens, stop_sequences: stop_sequences }
model_params[:temperature] = temperature if temperature

View File

@ -6,7 +6,7 @@ module DiscourseAi
INVALID_TURN = Class.new(StandardError)
attr_reader :messages
attr_accessor :tools, :topic_id, :post_id
attr_accessor :tools, :topic_id, :post_id, :max_pixels
def initialize(
system_message_text = nil,
@ -14,11 +14,14 @@ module DiscourseAi
tools: [],
skip_validations: false,
topic_id: nil,
post_id: nil
post_id: nil,
max_pixels: nil
)
raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)
@max_pixels = max_pixels || 1_048_576
@topic_id = topic_id
@post_id = post_id
@ -38,11 +41,12 @@ module DiscourseAi
@tools = tools
end
def push(type:, content:, id: nil, name: nil)
def push(type:, content:, id: nil, name: nil, upload_ids: nil)
return if type == :system
new_message = { type: type, content: content }
new_message[:name] = name.to_s if name
new_message[:id] = id.to_s if id
new_message[:upload_ids] = upload_ids if upload_ids
validate_message(new_message)
validate_turn(messages.last, new_message)
@ -54,6 +58,13 @@ module DiscourseAi
tools.present?
end
# Helper to fetch a message's uploads as base64-encoded images, resized
# to fit within this prompt's max_pixels budget (delegates to
# UploadEncoder). Returns [] when the message carries no upload_ids.
def encoded_uploads(message)
return [] if message[:upload_ids].blank?
UploadEncoder.encode(upload_ids: message[:upload_ids], max_pixels: max_pixels)
end
private
def validate_message(message)
@ -63,11 +74,19 @@ module DiscourseAi
raise ArgumentError, "message type must be one of #{valid_types}"
end
valid_keys = %i[type content id name]
valid_keys = %i[type content id name upload_ids]
if (invalid_keys = message.keys - valid_keys).any?
raise ArgumentError, "message contains invalid keys: #{invalid_keys}"
end
if message[:type] == :upload_ids && !message[:upload_ids].is_a?(Array)
raise ArgumentError, "upload_ids must be an array of ids"
end
if message[:upload_ids].present? && message[:type] != :user
raise ArgumentError, "upload_ids are only supported for users"
end
raise ArgumentError, "message content must be a string" if !message[:content].is_a?(String)
end

View File

@ -0,0 +1,45 @@
# frozen_string_literal: true
module DiscourseAi
  module Completions
    # Turns upload ids into base64-encoded image payloads suitable for
    # vision-capable LLM APIs (e.g. Anthropic Claude 3).
    class UploadEncoder
      # upload_ids: array of Upload ids referenced by a prompt message.
      # max_pixels: maximum width * height allowed; larger images are
      #             downscaled (preserving aspect ratio) before encoding.
      #
      # Returns an array of { base64:, mime_type: } hashes. Uploads that
      # are missing, dimensionless, fail to optimize, or cannot be fetched
      # are skipped rather than raising.
      def self.encode(upload_ids:, max_pixels:)
        uploads = []

        upload_ids.each do |upload_id|
          # find_by: a deleted upload should be skipped, not raise.
          # (Upload.find raises RecordNotFound, which made the blank?
          # guard below unreachable.)
          upload = Upload.find_by(id: upload_id)
          next if upload.blank?
          next if upload.width.to_i == 0 || upload.height.to_i == 0

          original_pixels = upload.width * upload.height

          image = upload

          if original_pixels > max_pixels
            ratio = max_pixels.to_f / original_pixels

            new_width = (ratio * upload.width).to_i
            new_height = (ratio * upload.height).to_i

            image = upload.get_optimized_image(new_width, new_height)
          end

          next if !image

          mime_type = MiniMime.lookup_by_filename(upload.original_filename).content_type

          path = Discourse.store.path_for(image)
          if path.blank?
            # download is protected with a DistributedMutex
            external_copy = Discourse.store.download_safe(image)
            path = external_copy&.path
          end
          # Skip (rather than crash on File.binread(nil)) if the external
          # copy could not be fetched.
          next if path.blank?

          # binread: image bytes must not pass through text-mode/encoding
          # conversion before being base64 encoded.
          encoded = Base64.strict_encode64(File.binread(path))

          uploads << { base64: encoded, mime_type: mime_type }
        end

        uploads
      end
    end
  end
end

BIN
spec/fixtures/images/100x100.jpg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 415 B

View File

@ -2,6 +2,10 @@
RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
let(:llm) { DiscourseAi::Completions::Llm.proxy("anthropic:claude-3-opus") }
let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
let(:upload100x100) do
UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
end
let(:prompt) do
DiscourseAi::Completions::Prompt.new(
@ -340,6 +344,73 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
expect(result.strip).to eq(expected)
end
it "can send images via a completion prompt" do
prompt =
DiscourseAi::Completions::Prompt.new(
"You are image bot",
messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]],
)
encoded = prompt.encoded_uploads(prompt.messages.last)
request_body = {
model: "claude-3-opus-20240229",
max_tokens: 3000,
messages: [
{
role: "user",
content: [
{
type: "image",
source: {
type: "base64",
media_type: "image/jpeg",
data: encoded[0][:base64],
},
},
{ type: "text", text: "user1: hello" },
],
},
],
system: "You are image bot",
}
response_body = <<~STRING
{
"content": [
{
"text": "What a cool image",
"type": "text"
}
],
"id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
"model": "claude-3-opus-20240229",
"role": "assistant",
"stop_reason": "end_turn",
"stop_sequence": null,
"type": "message",
"usage": {
"input_tokens": 10,
"output_tokens": 25
}
}
STRING
requested_body = nil
stub_request(:post, "https://api.anthropic.com/v1/messages").with(
body:
proc do |req_body|
requested_body = JSON.parse(req_body, symbolize_names: true)
true
end,
).to_return(status: 200, body: response_body)
result = llm.generate(prompt, user: Discourse.system_user)
expect(result).to eq("What a cool image")
expect(requested_body).to eq(request_body)
end
it "can operate in regular mode" do
body = <<~STRING
{

View File

@ -6,6 +6,7 @@ RSpec.describe DiscourseAi::Completions::Prompt do
let(:system_insts) { "These are the system instructions." }
let(:user_msg) { "Write something nice" }
let(:username) { "username1" }
let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
describe ".new" do
it "raises for invalid attributes" do
@ -23,6 +24,38 @@ RSpec.describe DiscourseAi::Completions::Prompt do
end
end
describe "image support" do
it "allows adding uploads to messages" do
upload = UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
prompt.max_pixels = 300
prompt.push(type: :user, content: "hello", upload_ids: [upload.id])
expect(prompt.messages.last[:upload_ids]).to eq([upload.id])
expect(prompt.max_pixels).to eq(300)
encoded = prompt.encoded_uploads(prompt.messages.last)
expect(encoded.length).to eq(1)
expect(encoded[0][:mime_type]).to eq("image/jpeg")
old_base64 = encoded[0][:base64]
prompt.max_pixels = 1_000_000
encoded = prompt.encoded_uploads(prompt.messages.last)
expect(encoded.length).to eq(1)
expect(encoded[0][:mime_type]).to eq("image/jpeg")
new_base64 = encoded[0][:base64]
expect(new_base64.length).to be > old_base64.length
expect(new_base64.length).to be > 0
expect(old_base64.length).to be > 0
end
end
describe "#push" do
describe "turn validations" do
it "validates that tool messages have a previous tool_call message" do

View File

@ -61,6 +61,55 @@ RSpec.describe DiscourseAi::AiBot::Playground do
end
end
describe "image support" do
before do
Jobs.run_immediately!
SiteSetting.ai_bot_allowed_groups = "#{Group::AUTO_GROUPS[:trust_level_0]}"
end
fab!(:persona) do
AiPersona.create!(
name: "Test Persona",
description: "A test persona",
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
enabled: true,
system_prompt: "You are a helpful bot",
vision_enabled: true,
vision_max_pixels: 1_000,
default_llm: "anthropic:claude-3-opus",
mentionable: true,
)
end
fab!(:upload)
it "sends images to llm" do
post = nil
persona.create_user!
image = "![image](upload://#{upload.base62_sha1}.jpg)"
body = "Hey @#{persona.user.username}, can you help me with this image? #{image}"
prompts = nil
DiscourseAi::Completions::Llm.with_prepared_responses(
["I understood image"],
) do |_, _, inner_prompts|
post = create_post(title: "some new topic I created", raw: body)
prompts = inner_prompts
end
expect(prompts[0].messages[1][:upload_ids]).to eq([upload.id])
expect(prompts[0].max_pixels).to eq(1000)
post.topic.reload
last_post = post.topic.posts.order(:post_number).last
expect(last_post.raw).to eq("I understood image")
end
end
describe "persona with user support" do
before do
Jobs.run_immediately!

View File

@ -9,13 +9,9 @@ RSpec.describe DiscourseAi::AiBot::Tools::DiscourseMetaSearch do
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
let(:progress_blk) { Proc.new {} }
let(:mock_search_json) do
File.read(File.expand_path("../../../../../fixtures/search_meta/search.json", __FILE__))
end
let(:mock_search_json) { plugin_file_from_fixtures("search.json", "search_meta").read }
let(:mock_site_json) do
File.read(File.expand_path("../../../../../fixtures/search_meta/site.json", __FILE__))
end
let(:mock_site_json) { plugin_file_from_fixtures("site.json", "search_meta").read }
before do
stub_request(:get, "https://meta.discourse.org/site.json").to_return(

View File

@ -204,6 +204,23 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
expect(persona.temperature).to eq(nil)
end
it "supports updating vision params" do
persona = Fabricate(:ai_persona, name: "test_bot2")
put "/admin/plugins/discourse-ai/ai-personas/#{persona.id}.json",
params: {
ai_persona: {
vision_enabled: true,
vision_max_pixels: 512 * 512,
},
}
expect(response).to have_http_status(:ok)
persona.reload
expect(persona.vision_enabled).to eq(true)
expect(persona.vision_max_pixels).to eq(512 * 512)
end
it "does not allow temperature and top p changes on stock personas" do
put "/admin/plugins/discourse-ai/ai-personas/#{DiscourseAi::AiBot::Personas::Persona.system_personas.values.first}.json",
params: {

View File

@ -46,6 +46,8 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
user: null,
user_id: null,
max_context_posts: 5,
vision_enabled: true,
vision_max_pixels: 100,
};
const aiPersona = AiPersona.create({ ...properties });
@ -77,6 +79,8 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
default_llm: "Default LLM",
mentionable: false,
max_context_posts: 5,
vision_enabled: true,
vision_max_pixels: 100,
};
const aiPersona = AiPersona.create({ ...properties });