diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml
index b7535d12..21e6c91d 100644
--- a/config/locales/client.en.yml
+++ b/config/locales/client.en.yml
@@ -6,22 +6,6 @@ en:
           discourse_ai: "Discourse AI"
   js:
     discourse_automation:
-      ai_models:
-        gpt_4_turbo: GPT 4 Turbo
-        gpt_4: GPT 4
-        gpt_3_5_turbo: GPT 3.5 Turbo
-        claude_2: Claude 2
-        gemini_pro: Gemini Pro
-        gemini_1_5_pro: Gemini 1.5 Pro
-        gemini_1_5_flash: Gemini 1.5 Flash
-        claude_3_opus: Claude 3 Opus
-        claude_3_sonnet: Claude 3 Sonnet
-        claude_3_haiku: Claude 3 Haiku
-        mixtral_8x7b_instruct_v0_1: Mixtral 8x7B Instruct V0.1
-        mistral_7b_instruct_v0_2: Mistral 7B Instruct V0.2
-        command_r: Cohere Command R
-        command_r_plus: Cohere Command R+
-        gpt_4o: GPT 4 Omni
       scriptables:
         llm_report:
           fields:
diff --git a/db/post_migrate/20240619211337_update_automation_script_models.rb b/db/post_migrate/20240619211337_update_automation_script_models.rb
new file mode 100644
index 00000000..53fe59be
--- /dev/null
+++ b/db/post_migrate/20240619211337_update_automation_script_models.rb
@@ -0,0 +1,56 @@
+# frozen_string_literal: true
+
+class UpdateAutomationScriptModels < ActiveRecord::Migration[7.0]
+  def up
+    script_names = %w[llm_triage llm_report]
+
+    fields_to_update = DB.query(<<~SQL, script_names: script_names)
+      SELECT fields.id, fields.metadata
+      FROM discourse_automation_fields fields
+      INNER JOIN discourse_automation_automations automations ON automations.id = fields.automation_id
+      WHERE fields.name = 'model'
+        AND automations.script IN (:script_names)
+    SQL
+
+    return if fields_to_update.empty?
+
+    updated_fields =
+      fields_to_update
+        .map do |field|
+          new_metadata = { "value" => translate_model(field.metadata["value"]) }.to_json
+
+          "(#{field.id}, '#{new_metadata}')" if new_metadata.present?
+        end
+        .compact
+
+    return if updated_fields.empty?
+
+    DB.exec(<<~SQL)
+      UPDATE discourse_automation_fields AS fields
+      SET metadata = new_fields.metadata::jsonb
+      FROM (VALUES #{updated_fields.join(", ")}) AS new_fields(id, metadata)
+      WHERE new_fields.id::bigint = fields.id
+    SQL
+  end
+
+  def translate_model(current_model)
+    options = DB.query(<<~SQL, name: current_model.to_s).to_a
+      SELECT id, provider
+      FROM llm_models
+      WHERE name = :name
+    SQL
+
+    return if options.empty?
+    return "custom:#{options.first.id}" if options.length == 1
+
+    priority_provider = options.find { |o| o.provider == "aws_bedrock" || o.provider == "vllm" }
+
+    return "custom:#{priority_provider.id}" if priority_provider
+
+    "custom:#{options.first.id}"
+  end
+
+  def down
+    raise ActiveRecord::IrreversibleMigration
+  end
+end
diff --git a/discourse_automation/llm_report.rb b/discourse_automation/llm_report.rb
index 4ca4a983..c190af0f 100644
--- a/discourse_automation/llm_report.rb
+++ b/discourse_automation/llm_report.rb
@@ -25,7 +25,7 @@ if defined?(DiscourseAutomation)
           component: :choices,
           required: true,
           extra: {
-            content: DiscourseAi::Automation::AVAILABLE_MODELS,
+            content: DiscourseAi::Automation.available_models,
           }
 
     field :priority_group, component: :group
diff --git a/discourse_automation/llm_triage.rb b/discourse_automation/llm_triage.rb
index 84f78a9e..9309bded 100644
--- a/discourse_automation/llm_triage.rb
+++ b/discourse_automation/llm_triage.rb
@@ -25,7 +25,7 @@ if defined?(DiscourseAutomation)
           component: :choices,
           required: true,
           extra: {
-            content: DiscourseAi::Automation::AVAILABLE_MODELS,
+            content: DiscourseAi::Automation.available_models,
           }
     field :category, component: :category
     field :tags, component: :tags
diff --git a/lib/automation.rb b/lib/automation.rb
index 5f775c67..3cdc3515 100644
--- a/lib/automation.rb
+++ b/lib/automation.rb
@@ -2,57 +2,15 @@
 
 module DiscourseAi
   module Automation
-    AVAILABLE_MODELS = [
-      { id: "gpt-4o", name: "discourse_automation.ai_models.gpt_4o" },
-      { id: "gpt-4-turbo", name: "discourse_automation.ai_models.gpt_4_turbo" },
-      { id: "gpt-4", name: "discourse_automation.ai_models.gpt_4" },
-      { id: "gpt-3.5-turbo", name: "discourse_automation.ai_models.gpt_3_5_turbo" },
-      { id: "gemini-pro", name: "discourse_automation.ai_models.gemini_pro" },
-      { id: "gemini-1.5-pro", name: "discourse_automation.ai_models.gemini_1_5_pro" },
-      { id: "gemini-1.5-flash", name: "discourse_automation.ai_models.gemini_1_5_flash" },
-      { id: "claude-2", name: "discourse_automation.ai_models.claude_2" },
-      { id: "claude-3-sonnet", name: "discourse_automation.ai_models.claude_3_sonnet" },
-      { id: "claude-3-opus", name: "discourse_automation.ai_models.claude_3_opus" },
-      { id: "claude-3-haiku", name: "discourse_automation.ai_models.claude_3_haiku" },
-      {
-        id: "mistralai/Mixtral-8x7B-Instruct-v0.1",
-        name: "discourse_automation.ai_models.mixtral_8x7b_instruct_v0_1",
-      },
-      {
-        id: "mistralai/Mistral-7B-Instruct-v0.2",
-        name: "discourse_automation.ai_models.mistral_7b_instruct_v0_2",
-      },
-      { id: "command-r", name: "discourse_automation.ai_models.command_r" },
-      { id: "command-r-plus", name: "discourse_automation.ai_models.command_r_plus" },
-    ]
+    def self.available_models
+      values = DB.query_hash(<<~SQL)
+        SELECT display_name AS translated_name, id AS id
+        FROM llm_models
+      SQL
 
-    def self.translate_model(model)
-      llm_model = LlmModel.find_by(name: model)
-      return "custom:#{llm_model.id}" if llm_model
+      values.each { |value_h| value_h["id"] = "custom:#{value_h["id"]}" }
 
-      return "google:#{model}" if model.start_with? "gemini"
-      return "open_ai:#{model}" if model.start_with? "gpt"
-      return "cohere:#{model}" if model.start_with? "command"
-
-      if model.start_with? "claude"
-        if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(model)
-          return "aws_bedrock:#{model}"
-        else
-          return "anthropic:#{model}"
-        end
-      end
-
-      if model.start_with?("mistral")
-        if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(model)
-          return "vllm:#{model}"
-        elsif DiscourseAi::Completions::Endpoints::HuggingFace.correctly_configured?(model)
-          "hugging_face:#{model}"
-        else
-          "ollama:mistral"
-        end
-      end
-
-      raise "Unknown model #{model}"
+      values
     end
   end
 end
diff --git a/lib/automation/llm_triage.rb b/lib/automation/llm_triage.rb
index c116043b..e7d6a024 100644
--- a/lib/automation/llm_triage.rb
+++ b/lib/automation/llm_triage.rb
@@ -32,8 +32,7 @@ module DiscourseAi
 
         result = nil
 
-        translated_model = DiscourseAi::Automation.translate_model(model)
-        llm = DiscourseAi::Completions::Llm.proxy(translated_model)
+        llm = DiscourseAi::Completions::Llm.proxy(model)
 
         result =
           llm.generate(
diff --git a/lib/automation/report_runner.rb b/lib/automation/report_runner.rb
index 5205e701..842de0bf 100644
--- a/lib/automation/report_runner.rb
+++ b/lib/automation/report_runner.rb
@@ -65,9 +65,7 @@ module DiscourseAi
             I18n.t("discourse_automation.scriptables.llm_report.title")
           end
         @model = model
-
-        translated_model = DiscourseAi::Automation.translate_model(model)
-        @llm = DiscourseAi::Completions::Llm.proxy(translated_model)
+        @llm = DiscourseAi::Completions::Llm.proxy(model)
         @category_ids = category_ids
         @tags = tags
         @allow_secure_categories = allow_secure_categories