FEATURE: Add vision support to AI personas (Claude 3) (#546)

This commit adds the ability to enable vision for AI personas, allowing them to understand images that are posted in the conversation.

For personas with vision enabled, any images the user has posted will be resized to be within the configured max_pixels limit, base64 encoded and included in the prompt sent to the AI provider.

The persona editor allows enabling/disabling vision and has a dropdown to select the max supported image size (low, medium, high). Vision is disabled by default.

This initial vision support has been tested and implemented with Anthropic's claude-3 models which accept images in a special format as part of the prompt.

Other integrations will need to be updated to support images.

Several specs were added to test the new functionality at the persona, prompt building and API layers.

 - Gemini is omitted, pending API support for Gemini 1.5. Current Gemini bot is not performing well, adding images is unlikely to make it perform any better.

 - Open AI is omitted, vision support on GPT-4 is limited in that the API has no tool support when images are enabled, so we would need to fall back to a different prompting technique, something that would add lots of complexity


---------

Co-authored-by: Martin Brennan <martin@discourse.org>
This commit is contained in:
Sam 2024-03-27 14:30:11 +11:00 committed by GitHub
parent 82387cc51d
commit 61e4c56e1a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 422 additions and 16 deletions

View File

@ -79,6 +79,8 @@ module DiscourseAi
:user_id, :user_id,
:mentionable, :mentionable,
:max_context_posts, :max_context_posts,
:vision_enabled,
:vision_max_pixels,
allowed_group_ids: [], allowed_group_ids: [],
) )

View File

@ -9,6 +9,9 @@ class AiPersona < ActiveRecord::Base
validates :system_prompt, presence: true, length: { maximum: 10_000_000 } validates :system_prompt, presence: true, length: { maximum: 10_000_000 }
validate :system_persona_unchangeable, on: :update, if: :system validate :system_persona_unchangeable, on: :update, if: :system
validates :max_context_posts, numericality: { greater_than: 0 }, allow_nil: true validates :max_context_posts, numericality: { greater_than: 0 }, allow_nil: true
# leaves some room for growth but sets a maximum to avoid memory issues
# we may want to revisit this in the future
validates :vision_max_pixels, numericality: { greater_than: 0, maximum: 4_000_000 }
belongs_to :created_by, class_name: "User" belongs_to :created_by, class_name: "User"
belongs_to :user belongs_to :user
@ -98,6 +101,8 @@ class AiPersona < ActiveRecord::Base
mentionable = self.mentionable mentionable = self.mentionable
default_llm = self.default_llm default_llm = self.default_llm
max_context_posts = self.max_context_posts max_context_posts = self.max_context_posts
vision_enabled = self.vision_enabled
vision_max_pixels = self.vision_max_pixels
persona_class = DiscourseAi::AiBot::Personas::Persona.system_personas_by_id[self.id] persona_class = DiscourseAi::AiBot::Personas::Persona.system_personas_by_id[self.id]
if persona_class if persona_class
@ -129,6 +134,14 @@ class AiPersona < ActiveRecord::Base
max_context_posts max_context_posts
end end
persona_class.define_singleton_method :vision_enabled do
vision_enabled
end
persona_class.define_singleton_method :vision_max_pixels do
vision_max_pixels
end
return persona_class return persona_class
end end
@ -204,6 +217,14 @@ class AiPersona < ActiveRecord::Base
max_context_posts max_context_posts
end end
define_singleton_method :vision_enabled do
vision_enabled
end
define_singleton_method :vision_max_pixels do
vision_max_pixels
end
define_singleton_method :to_s do define_singleton_method :to_s do
"#<DiscourseAi::AiBot::Personas::Persona::Custom @name=#{self.name} @allowed_group_ids=#{self.allowed_group_ids.join(",")}>" "#<DiscourseAi::AiBot::Personas::Persona::Custom @name=#{self.name} @allowed_group_ids=#{self.allowed_group_ids.join(",")}>"
end end

View File

@ -17,7 +17,9 @@ class LocalizedAiPersonaSerializer < ApplicationSerializer
:mentionable, :mentionable,
:default_llm, :default_llm,
:user_id, :user_id,
:max_context_posts :max_context_posts,
:vision_enabled,
:vision_max_pixels
has_one :user, serializer: BasicUserSerializer, embed: :object has_one :user, serializer: BasicUserSerializer, embed: :object

View File

@ -18,6 +18,8 @@ const ATTRIBUTES = [
"default_llm", "default_llm",
"user", "user",
"max_context_posts", "max_context_posts",
"vision_enabled",
"vision_max_pixels",
]; ];
const SYSTEM_ATTRIBUTES = [ const SYSTEM_ATTRIBUTES = [
@ -30,6 +32,8 @@ const SYSTEM_ATTRIBUTES = [
"default_llm", "default_llm",
"user", "user",
"max_context_posts", "max_context_posts",
"vision_enabled",
"vision_max_pixels",
]; ];
class CommandOption { class CommandOption {

View File

@ -1,5 +1,5 @@
import Component from "@glimmer/component"; import Component from "@glimmer/component";
import { tracked } from "@glimmer/tracking"; import { cached, tracked } from "@glimmer/tracking";
import { Input } from "@ember/component"; import { Input } from "@ember/component";
import { on } from "@ember/modifier"; import { on } from "@ember/modifier";
import { action } from "@ember/object"; import { action } from "@ember/object";
@ -17,6 +17,7 @@ import { popupAjaxError } from "discourse/lib/ajax-error";
import Group from "discourse/models/group"; import Group from "discourse/models/group";
import I18n from "discourse-i18n"; import I18n from "discourse-i18n";
import AdminUser from "admin/models/admin-user"; import AdminUser from "admin/models/admin-user";
import ComboBox from "select-kit/components/combo-box";
import GroupChooser from "select-kit/components/group-chooser"; import GroupChooser from "select-kit/components/group-chooser";
import DTooltip from "float-kit/components/d-tooltip"; import DTooltip from "float-kit/components/d-tooltip";
import AiCommandSelector from "./ai-command-selector"; import AiCommandSelector from "./ai-command-selector";
@ -34,10 +35,36 @@ export default class PersonaEditor extends Component {
@tracked editingModel = null; @tracked editingModel = null;
@tracked showDelete = false; @tracked showDelete = false;
@tracked maxPixelsValue = null;
@action @action
updateModel() { updateModel() {
this.editingModel = this.args.model.workingCopy(); this.editingModel = this.args.model.workingCopy();
this.showDelete = !this.args.model.isNew && !this.args.model.system; this.showDelete = !this.args.model.isNew && !this.args.model.system;
this.maxPixelsValue = this.findClosestPixelValue(
this.editingModel.vision_max_pixels
);
}
findClosestPixelValue(pixels) {
let value = "high";
this.maxPixelValues.forEach((info) => {
if (pixels === info.pixels) {
value = info.id;
}
});
return value;
}
@cached
get maxPixelValues() {
const l = (key) =>
I18n.t(`discourse_ai.ai_persona.vision_max_pixel_sizes.${key}`);
return [
{ id: "low", name: l("low"), pixels: 65536 },
{ id: "medium", name: l("medium"), pixels: 262144 },
{ id: "high", name: l("high"), pixels: 1048576 },
];
} }
@action @action
@ -102,6 +129,16 @@ export default class PersonaEditor extends Component {
} }
} }
@action
onChangeMaxPixels(value) {
const entry = this.maxPixelValues.findBy("id", value);
if (!entry) {
return;
}
this.maxPixelsValue = value;
this.editingModel.vision_max_pixels = entry.pixels;
}
@action @action
delete() { delete() {
return this.dialog.confirm({ return this.dialog.confirm({
@ -137,6 +174,11 @@ export default class PersonaEditor extends Component {
await this.toggleField("mentionable"); await this.toggleField("mentionable");
} }
@action
async toggleVisionEnabled() {
await this.toggleField("vision_enabled");
}
@action @action
async createUser() { async createUser() {
try { try {
@ -225,6 +267,17 @@ export default class PersonaEditor extends Component {
/> />
</div> </div>
{{/if}} {{/if}}
<div class="control-group ai-persona-editor__vision_enabled">
<DToggleSwitch
@state={{@model.vision_enabled}}
@label="discourse_ai.ai_persona.vision_enabled"
{{on "click" this.toggleVisionEnabled}}
/>
<DTooltip
@icon="question-circle"
@content={{I18n.t "discourse_ai.ai_persona.vision_enabled_help"}}
/>
</div>
<div class="control-group"> <div class="control-group">
<label>{{I18n.t "discourse_ai.ai_persona.name"}}</label> <label>{{I18n.t "discourse_ai.ai_persona.name"}}</label>
<Input <Input
@ -329,6 +382,16 @@ export default class PersonaEditor extends Component {
@content={{I18n.t "discourse_ai.ai_persona.max_context_posts_help"}} @content={{I18n.t "discourse_ai.ai_persona.max_context_posts_help"}}
/> />
</div> </div>
{{#if @model.vision_enabled}}
<div class="control-group">
<label>{{I18n.t "discourse_ai.ai_persona.vision_max_pixels"}}</label>
<ComboBox
@value={{this.maxPixelsValue}}
@content={{this.maxPixelValues}}
@onChange={{this.onChangeMaxPixels}}
/>
</div>
{{/if}}
{{#if this.showTemperature}} {{#if this.showTemperature}}
<div class="control-group"> <div class="control-group">
<label>{{I18n.t "discourse_ai.ai_persona.temperature"}}</label> <label>{{I18n.t "discourse_ai.ai_persona.temperature"}}</label>

View File

@ -71,4 +71,9 @@
display: flex; display: flex;
align-items: center; align-items: center;
} }
&__vision_enabled {
display: flex;
align-items: center;
}
} }

View File

@ -121,6 +121,13 @@ en:
no_llm_selected: "No language model selected" no_llm_selected: "No language model selected"
max_context_posts: "Max Context Posts" max_context_posts: "Max Context Posts"
max_context_posts_help: "The maximum number of posts to use as context for the AI when responding to a user. (empty for default)" max_context_posts_help: "The maximum number of posts to use as context for the AI when responding to a user. (empty for default)"
vision_enabled: Vision Enabled
vision_enabled_help: If enabled, the AI will attempt to understand images users post in the topic, depends on the model being used supporting vision. Anthropic Claude 3 models support vision.
vision_max_pixels: Supported image size
vision_max_pixel_sizes:
low: Low Quality - cheapest (256x256)
medium: Medium Quality (512x512)
high: High Quality - slowest (1024x1024)
mentionable: Mentionable mentionable: Mentionable
mentionable_help: If enabled, users in allowed groups can mention this user in posts and messages, the AI will respond as this persona. mentionable_help: If enabled, users in allowed groups can mention this user in posts and messages, the AI will respond as this persona.
user: User user: User

View File

@ -0,0 +1,10 @@
# frozen_string_literal: true
# Adds per-persona vision settings:
# - vision_enabled: whether the persona should receive images users post
# - vision_max_pixels: max pixel count an image is resized to before being
#   encoded into the prompt (default 1_048_576 = 1024x1024)
class AddImagesToAiPersonas < ActiveRecord::Migration[7.0]
  def change
    # Call add_column directly: the previous change_table wrapper was a no-op
    # (its block argument `t` was never used — add_column was invoked on the
    # migration itself), so the wrapper only obscured intent.
    add_column :ai_personas, :vision_enabled, :boolean, default: false, null: false
    add_column :ai_personas, :vision_max_pixels, :integer, default: 1_048_576, null: false
  end
end

View File

@ -5,6 +5,14 @@ module DiscourseAi
module Personas module Personas
class Persona class Persona
class << self class << self
def vision_enabled
false
end
def vision_max_pixels
1_048_576
end
def system_personas def system_personas
@system_personas ||= { @system_personas ||= {
Personas::General => -1, Personas::General => -1,
@ -126,6 +134,7 @@ module DiscourseAi
post_id: context[:post_id], post_id: context[:post_id],
) )
prompt.max_pixels = self.class.vision_max_pixels if self.class.vision_enabled
prompt.tools = available_tools.map(&:signature) if available_tools prompt.tools = available_tools.map(&:signature) if available_tools
prompt prompt

View File

@ -111,17 +111,26 @@ module DiscourseAi
post post
.topic .topic
.posts .posts
.includes(:user) .joins(:user)
.joins("LEFT JOIN post_custom_prompts ON post_custom_prompts.post_id = posts.id") .joins("LEFT JOIN post_custom_prompts ON post_custom_prompts.post_id = posts.id")
.where("post_number <= ?", post.post_number) .where("post_number <= ?", post.post_number)
.order("post_number desc") .order("post_number desc")
.where("post_type in (?)", post_types) .where("post_type in (?)", post_types)
.limit(max_posts) .limit(max_posts)
.pluck(:raw, :username, "post_custom_prompts.custom_prompt") .pluck(
"posts.raw",
"users.username",
"post_custom_prompts.custom_prompt",
"(
SELECT array_agg(ref.upload_id)
FROM upload_references ref
WHERE ref.target_type = 'Post' AND ref.target_id = posts.id
) as upload_ids",
)
result = [] result = []
context.reverse_each do |raw, username, custom_prompt| context.reverse_each do |raw, username, custom_prompt, upload_ids|
custom_prompt_translation = custom_prompt_translation =
Proc.new do |message| Proc.new do |message|
# We can't keep backwards-compatibility for stored functions. # We can't keep backwards-compatibility for stored functions.
@ -149,6 +158,10 @@ module DiscourseAi
context[:id] = username if context[:type] == :user context[:id] = username if context[:type] == :user
if upload_ids.present? && context[:type] == :user && bot.persona.class.vision_enabled
context[:upload_ids] = upload_ids.compact
end
result << context result << context
end end
end end

View File

@ -47,6 +47,7 @@ module DiscourseAi
content = +"" content = +""
content << "#{msg[:id]}: " if msg[:id] content << "#{msg[:id]}: " if msg[:id]
content << msg[:content] content << msg[:content]
content = inline_images(content, msg)
{ role: "user", content: content } { role: "user", content: content }
end end
@ -80,6 +81,33 @@ module DiscourseAi
# Longer term it will have over 1 million # Longer term it will have over 1 million
200_000 # Claude-3 has a 200k context window for now 200_000 # Claude-3 has a 200k context window for now
end end
private
def inline_images(content, message)
if model_name.include?("claude-3")
encoded_uploads = prompt.encoded_uploads(message)
if encoded_uploads.present?
new_content = []
new_content.concat(
encoded_uploads.map do |details|
{
source: {
type: "base64",
data: details[:base64],
media_type: details[:mime_type],
},
type: "image",
}
end,
)
new_content << { type: "text", text: content }
content = new_content
end
end
content
end
end end
end end
end end

View File

@ -66,12 +66,18 @@ module DiscourseAi
def with_prepared_responses(responses, llm: nil) def with_prepared_responses(responses, llm: nil)
@canned_response = DiscourseAi::Completions::Endpoints::CannedResponse.new(responses) @canned_response = DiscourseAi::Completions::Endpoints::CannedResponse.new(responses)
@canned_llm = llm @canned_llm = llm
@prompts = []
yield(@canned_response, llm) yield(@canned_response, llm, @prompts)
ensure ensure
# Don't leak prepared response if there's an exception. # Don't leak prepared response if there's an exception.
@canned_response = nil @canned_response = nil
@canned_llm = nil @canned_llm = nil
@prompts = nil
end
def record_prompt(prompt)
@prompts << prompt if @prompts
end end
def proxy(model_name) def proxy(model_name)
@ -138,6 +144,8 @@ module DiscourseAi
user:, user:,
&partial_read_blk &partial_read_blk
) )
self.class.record_prompt(prompt)
model_params = { max_tokens: max_tokens, stop_sequences: stop_sequences } model_params = { max_tokens: max_tokens, stop_sequences: stop_sequences }
model_params[:temperature] = temperature if temperature model_params[:temperature] = temperature if temperature

View File

@ -6,7 +6,7 @@ module DiscourseAi
INVALID_TURN = Class.new(StandardError) INVALID_TURN = Class.new(StandardError)
attr_reader :messages attr_reader :messages
attr_accessor :tools, :topic_id, :post_id attr_accessor :tools, :topic_id, :post_id, :max_pixels
def initialize( def initialize(
system_message_text = nil, system_message_text = nil,
@ -14,11 +14,14 @@ module DiscourseAi
tools: [], tools: [],
skip_validations: false, skip_validations: false,
topic_id: nil, topic_id: nil,
post_id: nil post_id: nil,
max_pixels: nil
) )
raise ArgumentError, "messages must be an array" if !messages.is_a?(Array) raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
raise ArgumentError, "tools must be an array" if !tools.is_a?(Array) raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)
@max_pixels = max_pixels || 1_048_576
@topic_id = topic_id @topic_id = topic_id
@post_id = post_id @post_id = post_id
@ -38,11 +41,12 @@ module DiscourseAi
@tools = tools @tools = tools
end end
def push(type:, content:, id: nil, name: nil) def push(type:, content:, id: nil, name: nil, upload_ids: nil)
return if type == :system return if type == :system
new_message = { type: type, content: content } new_message = { type: type, content: content }
new_message[:name] = name.to_s if name new_message[:name] = name.to_s if name
new_message[:id] = id.to_s if id new_message[:id] = id.to_s if id
new_message[:upload_ids] = upload_ids if upload_ids
validate_message(new_message) validate_message(new_message)
validate_turn(messages.last, new_message) validate_turn(messages.last, new_message)
@ -54,6 +58,13 @@ module DiscourseAi
tools.present? tools.present?
end end
# helper method to get base64 encoded uploads
  # at the correct dimensions
def encoded_uploads(message)
return [] if message[:upload_ids].blank?
UploadEncoder.encode(upload_ids: message[:upload_ids], max_pixels: max_pixels)
end
private private
def validate_message(message) def validate_message(message)
@ -63,11 +74,19 @@ module DiscourseAi
raise ArgumentError, "message type must be one of #{valid_types}" raise ArgumentError, "message type must be one of #{valid_types}"
end end
valid_keys = %i[type content id name] valid_keys = %i[type content id name upload_ids]
if (invalid_keys = message.keys - valid_keys).any? if (invalid_keys = message.keys - valid_keys).any?
raise ArgumentError, "message contains invalid keys: #{invalid_keys}" raise ArgumentError, "message contains invalid keys: #{invalid_keys}"
end end
if message[:type] == :upload_ids && !message[:upload_ids].is_a?(Array)
raise ArgumentError, "upload_ids must be an array of ids"
end
if message[:upload_ids].present? && message[:type] != :user
raise ArgumentError, "upload_ids are only supported for users"
end
raise ArgumentError, "message content must be a string" if !message[:content].is_a?(String) raise ArgumentError, "message content must be a string" if !message[:content].is_a?(String)
end end

View File

@ -0,0 +1,45 @@
# frozen_string_literal: true
module DiscourseAi
  module Completions
    # Resizes and base64-encodes uploads so they can be inlined into an LLM
    # prompt (used for vision-enabled personas, e.g. Anthropic claude-3).
    class UploadEncoder
      # @param upload_ids [Array<Integer>] ids of Upload records to encode
      # @param max_pixels [Integer] maximum total pixel count; larger images
      #   are downscaled proportionally before encoding
      # @return [Array<Hash>] one { base64:, mime_type: } hash per upload that
      #   could be resolved, sized and read; un-encodable uploads are skipped
      def self.encode(upload_ids:, max_pixels:)
        uploads = []

        upload_ids.each do |upload_id|
          # find_by, not find: a stale or deleted upload id should be skipped,
          # not raise ActiveRecord::RecordNotFound. (With the previous
          # `Upload.find`, the blank? guard below was unreachable.)
          upload = Upload.find_by(id: upload_id)
          next if upload.blank?
          # dimensions are required to compute the resize ratio
          next if upload.width.to_i == 0 || upload.height.to_i == 0

          original_pixels = upload.width * upload.height

          image = upload

          if original_pixels > max_pixels
            # scale both dimensions by the same ratio so total pixel count
            # lands at (approximately) max_pixels
            ratio = max_pixels.to_f / original_pixels

            new_width = (ratio * upload.width).to_i
            new_height = (ratio * upload.height).to_i

            image = upload.get_optimized_image(new_width, new_height)
          end

          next if !image

          # skip files whose extension can't be mapped to a mime type rather
          # than sending a nil media type to the API
          mime_type = MiniMime.lookup_by_filename(upload.original_filename)&.content_type
          next if mime_type.blank?

          path = Discourse.store.path_for(image)

          if path.blank?
            # download is protected with a DistributedMutex
            external_copy = Discourse.store.download_safe(image)
            path = external_copy&.path
          end

          # download_safe may fail and return nil; previously this fell
          # through to File.read(nil) and raised TypeError
          next if path.blank?

          # binread: image bytes must not go through encoding conversion
          encoded = Base64.strict_encode64(File.binread(path))

          uploads << { base64: encoded, mime_type: mime_type }
        end

        uploads
      end
    end
  end
end

BIN
spec/fixtures/images/100x100.jpg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 415 B

View File

@ -2,6 +2,10 @@
RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
let(:llm) { DiscourseAi::Completions::Llm.proxy("anthropic:claude-3-opus") } let(:llm) { DiscourseAi::Completions::Llm.proxy("anthropic:claude-3-opus") }
let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
let(:upload100x100) do
UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
end
let(:prompt) do let(:prompt) do
DiscourseAi::Completions::Prompt.new( DiscourseAi::Completions::Prompt.new(
@ -340,6 +344,73 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
expect(result.strip).to eq(expected) expect(result.strip).to eq(expected)
end end
it "can send images via a completion prompt" do
prompt =
DiscourseAi::Completions::Prompt.new(
"You are image bot",
messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]],
)
encoded = prompt.encoded_uploads(prompt.messages.last)
request_body = {
model: "claude-3-opus-20240229",
max_tokens: 3000,
messages: [
{
role: "user",
content: [
{
type: "image",
source: {
type: "base64",
media_type: "image/jpeg",
data: encoded[0][:base64],
},
},
{ type: "text", text: "user1: hello" },
],
},
],
system: "You are image bot",
}
response_body = <<~STRING
{
"content": [
{
"text": "What a cool image",
"type": "text"
}
],
"id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
"model": "claude-3-opus-20240229",
"role": "assistant",
"stop_reason": "end_turn",
"stop_sequence": null,
"type": "message",
"usage": {
"input_tokens": 10,
"output_tokens": 25
}
}
STRING
requested_body = nil
stub_request(:post, "https://api.anthropic.com/v1/messages").with(
body:
proc do |req_body|
requested_body = JSON.parse(req_body, symbolize_names: true)
true
end,
).to_return(status: 200, body: response_body)
result = llm.generate(prompt, user: Discourse.system_user)
expect(result).to eq("What a cool image")
expect(requested_body).to eq(request_body)
end
it "can operate in regular mode" do it "can operate in regular mode" do
body = <<~STRING body = <<~STRING
{ {

View File

@ -6,6 +6,7 @@ RSpec.describe DiscourseAi::Completions::Prompt do
let(:system_insts) { "These are the system instructions." } let(:system_insts) { "These are the system instructions." }
let(:user_msg) { "Write something nice" } let(:user_msg) { "Write something nice" }
let(:username) { "username1" } let(:username) { "username1" }
let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
describe ".new" do describe ".new" do
it "raises for invalid attributes" do it "raises for invalid attributes" do
@ -23,6 +24,38 @@ RSpec.describe DiscourseAi::Completions::Prompt do
end end
end end
describe "image support" do
it "allows adding uploads to messages" do
upload = UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
prompt.max_pixels = 300
prompt.push(type: :user, content: "hello", upload_ids: [upload.id])
expect(prompt.messages.last[:upload_ids]).to eq([upload.id])
expect(prompt.max_pixels).to eq(300)
encoded = prompt.encoded_uploads(prompt.messages.last)
expect(encoded.length).to eq(1)
expect(encoded[0][:mime_type]).to eq("image/jpeg")
old_base64 = encoded[0][:base64]
prompt.max_pixels = 1_000_000
encoded = prompt.encoded_uploads(prompt.messages.last)
expect(encoded.length).to eq(1)
expect(encoded[0][:mime_type]).to eq("image/jpeg")
new_base64 = encoded[0][:base64]
expect(new_base64.length).to be > old_base64.length
expect(new_base64.length).to be > 0
expect(old_base64.length).to be > 0
end
end
describe "#push" do describe "#push" do
describe "turn validations" do describe "turn validations" do
it "validates that tool messages have a previous tool_call message" do it "validates that tool messages have a previous tool_call message" do

View File

@ -61,6 +61,55 @@ RSpec.describe DiscourseAi::AiBot::Playground do
end end
end end
describe "image support" do
before do
Jobs.run_immediately!
SiteSetting.ai_bot_allowed_groups = "#{Group::AUTO_GROUPS[:trust_level_0]}"
end
fab!(:persona) do
AiPersona.create!(
name: "Test Persona",
description: "A test persona",
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
enabled: true,
system_prompt: "You are a helpful bot",
vision_enabled: true,
vision_max_pixels: 1_000,
default_llm: "anthropic:claude-3-opus",
mentionable: true,
)
end
fab!(:upload)
it "sends images to llm" do
post = nil
persona.create_user!
image = "![image](upload://#{upload.base62_sha1}.jpg)"
body = "Hey @#{persona.user.username}, can you help me with this image? #{image}"
prompts = nil
DiscourseAi::Completions::Llm.with_prepared_responses(
["I understood image"],
) do |_, _, inner_prompts|
post = create_post(title: "some new topic I created", raw: body)
prompts = inner_prompts
end
expect(prompts[0].messages[1][:upload_ids]).to eq([upload.id])
expect(prompts[0].max_pixels).to eq(1000)
post.topic.reload
last_post = post.topic.posts.order(:post_number).last
expect(last_post.raw).to eq("I understood image")
end
end
describe "persona with user support" do describe "persona with user support" do
before do before do
Jobs.run_immediately! Jobs.run_immediately!

View File

@ -9,13 +9,9 @@ RSpec.describe DiscourseAi::AiBot::Tools::DiscourseMetaSearch do
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") } let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
let(:progress_blk) { Proc.new {} } let(:progress_blk) { Proc.new {} }
let(:mock_search_json) do let(:mock_search_json) { plugin_file_from_fixtures("search.json", "search_meta").read }
File.read(File.expand_path("../../../../../fixtures/search_meta/search.json", __FILE__))
end
let(:mock_site_json) do let(:mock_site_json) { plugin_file_from_fixtures("site.json", "search_meta").read }
File.read(File.expand_path("../../../../../fixtures/search_meta/site.json", __FILE__))
end
before do before do
stub_request(:get, "https://meta.discourse.org/site.json").to_return( stub_request(:get, "https://meta.discourse.org/site.json").to_return(

View File

@ -204,6 +204,23 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
expect(persona.temperature).to eq(nil) expect(persona.temperature).to eq(nil)
end end
it "supports updating vision params" do
persona = Fabricate(:ai_persona, name: "test_bot2")
put "/admin/plugins/discourse-ai/ai-personas/#{persona.id}.json",
params: {
ai_persona: {
vision_enabled: true,
vision_max_pixels: 512 * 512,
},
}
expect(response).to have_http_status(:ok)
persona.reload
expect(persona.vision_enabled).to eq(true)
expect(persona.vision_max_pixels).to eq(512 * 512)
end
it "does not allow temperature and top p changes on stock personas" do it "does not allow temperature and top p changes on stock personas" do
put "/admin/plugins/discourse-ai/ai-personas/#{DiscourseAi::AiBot::Personas::Persona.system_personas.values.first}.json", put "/admin/plugins/discourse-ai/ai-personas/#{DiscourseAi::AiBot::Personas::Persona.system_personas.values.first}.json",
params: { params: {

View File

@ -46,6 +46,8 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
user: null, user: null,
user_id: null, user_id: null,
max_context_posts: 5, max_context_posts: 5,
vision_enabled: true,
vision_max_pixels: 100,
}; };
const aiPersona = AiPersona.create({ ...properties }); const aiPersona = AiPersona.create({ ...properties });
@ -77,6 +79,8 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
default_llm: "Default LLM", default_llm: "Default LLM",
mentionable: false, mentionable: false,
max_context_posts: 5, max_context_posts: 5,
vision_enabled: true,
vision_max_pixels: 100,
}; };
const aiPersona = AiPersona.create({ ...properties }); const aiPersona = AiPersona.create({ ...properties });