Removed files not to be included in CM 3.0.
git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1295533 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
52feb7c331
commit
f12bb6ddd5
2
pom.xml
2
pom.xml
|
@ -24,7 +24,7 @@
|
||||||
<modelVersion>4.0.0</modelVersion>
|
<modelVersion>4.0.0</modelVersion>
|
||||||
<groupId>org.apache.commons</groupId>
|
<groupId>org.apache.commons</groupId>
|
||||||
<artifactId>commons-math3</artifactId>
|
<artifactId>commons-math3</artifactId>
|
||||||
<version>3.0-SNAPSHOT</version>
|
<version>3.0</version>
|
||||||
<name>Commons Math</name>
|
<name>Commons Math</name>
|
||||||
|
|
||||||
<inceptionYear>2003</inceptionYear>
|
<inceptionYear>2003</inceptionYear>
|
||||||
|
|
|
@ -1,420 +0,0 @@
|
||||||
/*
|
|
||||||
* Copyright 2011 The Apache Software Foundation.
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
package org.apache.commons.math3.linear;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
import org.apache.commons.math3.util.MathArrays;
|
|
||||||
import org.apache.commons.math3.exception.ConvergenceException;
|
|
||||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
|
||||||
import org.apache.commons.math3.exception.util.LocalizedFormats;
|
|
||||||
import org.apache.commons.math3.util.FastMath;
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @author gregsterijevski
|
|
||||||
*/
|
|
||||||
public class PivotingQRDecomposition {
|
|
||||||
|
|
||||||
private double[][] qr;
|
|
||||||
/** The diagonal elements of R. */
|
|
||||||
private double[] rDiag;
|
|
||||||
/** Cached value of Q. */
|
|
||||||
private RealMatrix cachedQ;
|
|
||||||
/** Cached value of QT. */
|
|
||||||
private RealMatrix cachedQT;
|
|
||||||
/** Cached value of R. */
|
|
||||||
private RealMatrix cachedR;
|
|
||||||
/** Cached value of H. */
|
|
||||||
private RealMatrix cachedH;
|
|
||||||
/** permutation info */
|
|
||||||
private int[] permutation;
|
|
||||||
/** the rank **/
|
|
||||||
private int rank;
|
|
||||||
/** vector of column multipliers */
|
|
||||||
private double[] beta;
|
|
||||||
|
|
||||||
public boolean isSingular() {
|
|
||||||
return rank != qr[0].length;
|
|
||||||
}
|
|
||||||
|
|
||||||
public int getRank() {
|
|
||||||
return rank;
|
|
||||||
}
|
|
||||||
|
|
||||||
public int[] getOrder() {
|
|
||||||
return MathArrays.copyOf(permutation);
|
|
||||||
}
|
|
||||||
|
|
||||||
public PivotingQRDecomposition(RealMatrix matrix) throws ConvergenceException {
|
|
||||||
this(matrix, 1.0e-16, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
public PivotingQRDecomposition(RealMatrix matrix, boolean allowPivot) throws ConvergenceException {
|
|
||||||
this(matrix, 1.0e-16, allowPivot);
|
|
||||||
}
|
|
||||||
|
|
||||||
public PivotingQRDecomposition(RealMatrix matrix, double qrRankingThreshold,
|
|
||||||
boolean allowPivot) throws ConvergenceException {
|
|
||||||
final int rows = matrix.getRowDimension();
|
|
||||||
final int cols = matrix.getColumnDimension();
|
|
||||||
qr = matrix.getData();
|
|
||||||
rDiag = new double[cols];
|
|
||||||
//final double[] norms = new double[cols];
|
|
||||||
this.beta = new double[cols];
|
|
||||||
this.permutation = new int[cols];
|
|
||||||
cachedQ = null;
|
|
||||||
cachedQT = null;
|
|
||||||
cachedR = null;
|
|
||||||
cachedH = null;
|
|
||||||
|
|
||||||
/*- initialize the permutation vector and calculate the norms */
|
|
||||||
for (int k = 0; k < cols; ++k) {
|
|
||||||
permutation[k] = k;
|
|
||||||
}
|
|
||||||
// transform the matrix column after column
|
|
||||||
for (int k = 0; k < cols; ++k) {
|
|
||||||
// select the column with the greatest norm on active components
|
|
||||||
int nextColumn = -1;
|
|
||||||
double ak2 = Double.NEGATIVE_INFINITY;
|
|
||||||
if (allowPivot) {
|
|
||||||
for (int i = k; i < cols; ++i) {
|
|
||||||
double norm2 = 0;
|
|
||||||
for (int j = k; j < rows; ++j) {
|
|
||||||
final double aki = qr[j][permutation[i]];
|
|
||||||
norm2 += aki * aki;
|
|
||||||
}
|
|
||||||
if (Double.isInfinite(norm2) || Double.isNaN(norm2)) {
|
|
||||||
throw new ConvergenceException(LocalizedFormats.UNABLE_TO_PERFORM_QR_DECOMPOSITION_ON_JACOBIAN,
|
|
||||||
rows, cols);
|
|
||||||
}
|
|
||||||
if (norm2 > ak2) {
|
|
||||||
nextColumn = i;
|
|
||||||
ak2 = norm2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
nextColumn = k;
|
|
||||||
ak2 = 0.0;
|
|
||||||
for (int j = k; j < rows; ++j) {
|
|
||||||
final double aki = qr[j][k];
|
|
||||||
ak2 += aki * aki;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (ak2 <= qrRankingThreshold) {
|
|
||||||
rank = k;
|
|
||||||
for (int i = rank; i < rows; i++) {
|
|
||||||
for (int j = i + 1; j < cols; j++) {
|
|
||||||
qr[i][permutation[j]] = 0.0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
final int pk = permutation[nextColumn];
|
|
||||||
permutation[nextColumn] = permutation[k];
|
|
||||||
permutation[k] = pk;
|
|
||||||
|
|
||||||
// choose alpha such that Hk.u = alpha ek
|
|
||||||
final double akk = qr[k][pk];
|
|
||||||
final double alpha = (akk > 0) ? -FastMath.sqrt(ak2) : FastMath.sqrt(ak2);
|
|
||||||
final double betak = 1.0 / (ak2 - akk * alpha);
|
|
||||||
beta[pk] = betak;
|
|
||||||
|
|
||||||
// transform the current column
|
|
||||||
rDiag[pk] = alpha;
|
|
||||||
qr[k][pk] -= alpha;
|
|
||||||
|
|
||||||
// transform the remaining columns
|
|
||||||
for (int dk = cols - 1 - k; dk > 0; --dk) {
|
|
||||||
double gamma = 0;
|
|
||||||
for (int j = k; j < rows; ++j) {
|
|
||||||
gamma += qr[j][pk] * qr[j][permutation[k + dk]];
|
|
||||||
}
|
|
||||||
gamma *= betak;
|
|
||||||
for (int j = k; j < rows; ++j) {
|
|
||||||
qr[j][permutation[k + dk]] -= gamma * qr[j][pk];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rank = cols;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns the matrix Q of the decomposition.
|
|
||||||
* <p>Q is an orthogonal matrix</p>
|
|
||||||
* @return the Q matrix
|
|
||||||
*/
|
|
||||||
public RealMatrix getQ() {
|
|
||||||
if (cachedQ == null) {
|
|
||||||
cachedQ = getQT().transpose();
|
|
||||||
}
|
|
||||||
return cachedQ;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns the transpose of the matrix Q of the decomposition.
|
|
||||||
* <p>Q is an orthogonal matrix</p>
|
|
||||||
* @return the Q matrix
|
|
||||||
*/
|
|
||||||
public RealMatrix getQT() {
|
|
||||||
if (cachedQT == null) {
|
|
||||||
|
|
||||||
// QT is supposed to be m x m
|
|
||||||
final int m = qr.length;
|
|
||||||
cachedQT = MatrixUtils.createRealMatrix(m, m);
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Q = Q1 Q2 ... Q_m, so Q is formed by first constructing Q_m and then
|
|
||||||
* applying the Householder transformations Q_(m-1),Q_(m-2),...,Q1 in
|
|
||||||
* succession to the result
|
|
||||||
*/
|
|
||||||
for (int minor = m - 1; minor >= rank; minor--) {
|
|
||||||
cachedQT.setEntry(minor, minor, 1.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int minor = rank - 1; minor >= 0; minor--) {
|
|
||||||
//final double[] qrtMinor = qrt[minor];
|
|
||||||
final int p_minor = permutation[minor];
|
|
||||||
cachedQT.setEntry(minor, minor, 1.0);
|
|
||||||
//if (qrtMinor[minor] != 0.0) {
|
|
||||||
for (int col = minor; col < m; col++) {
|
|
||||||
double alpha = 0.0;
|
|
||||||
for (int row = minor; row < m; row++) {
|
|
||||||
alpha -= cachedQT.getEntry(col, row) * qr[row][p_minor];
|
|
||||||
}
|
|
||||||
alpha /= rDiag[p_minor] * qr[minor][p_minor];
|
|
||||||
for (int row = minor; row < m; row++) {
|
|
||||||
cachedQT.addToEntry(col, row, -alpha * qr[row][p_minor]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
//}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// return the cached matrix
|
|
||||||
return cachedQT;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns the matrix R of the decomposition.
|
|
||||||
* <p>R is an upper-triangular matrix</p>
|
|
||||||
* @return the R matrix
|
|
||||||
*/
|
|
||||||
public RealMatrix getR() {
|
|
||||||
if (cachedR == null) {
|
|
||||||
// R is supposed to be m x n
|
|
||||||
final int n = qr[0].length;
|
|
||||||
final int m = qr.length;
|
|
||||||
cachedR = MatrixUtils.createRealMatrix(m, n);
|
|
||||||
// copy the diagonal from rDiag and the upper triangle of qr
|
|
||||||
for (int row = rank - 1; row >= 0; row--) {
|
|
||||||
cachedR.setEntry(row, row, rDiag[permutation[row]]);
|
|
||||||
for (int col = row + 1; col < n; col++) {
|
|
||||||
cachedR.setEntry(row, col, qr[row][permutation[col]]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// return the cached matrix
|
|
||||||
return cachedR;
|
|
||||||
}
|
|
||||||
|
|
||||||
public RealMatrix getH() {
|
|
||||||
if (cachedH == null) {
|
|
||||||
final int n = qr[0].length;
|
|
||||||
final int m = qr.length;
|
|
||||||
cachedH = MatrixUtils.createRealMatrix(m, n);
|
|
||||||
for (int i = 0; i < m; ++i) {
|
|
||||||
for (int j = 0; j < FastMath.min(i + 1, n); ++j) {
|
|
||||||
final int p_j = permutation[j];
|
|
||||||
cachedH.setEntry(i, j, qr[i][p_j] / -rDiag[p_j]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// return the cached matrix
|
|
||||||
return cachedH;
|
|
||||||
}
|
|
||||||
|
|
||||||
public RealMatrix getPermutationMatrix() {
|
|
||||||
RealMatrix rm = MatrixUtils.createRealMatrix(qr[0].length, qr[0].length);
|
|
||||||
for (int i = 0; i < this.qr[0].length; i++) {
|
|
||||||
rm.setEntry(permutation[i], i, 1.0);
|
|
||||||
}
|
|
||||||
return rm;
|
|
||||||
}
|
|
||||||
|
|
||||||
public DecompositionSolver getSolver() {
|
|
||||||
return new Solver(qr, rDiag, permutation, rank);
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Specialized solver. */
|
|
||||||
private static class Solver implements DecompositionSolver {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A packed TRANSPOSED representation of the QR decomposition.
|
|
||||||
* <p>The elements BELOW the diagonal are the elements of the UPPER triangular
|
|
||||||
* matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
|
|
||||||
* from which an explicit form of Q can be recomputed if desired.</p>
|
|
||||||
*/
|
|
||||||
private final double[][] qr;
|
|
||||||
/** The diagonal elements of R. */
|
|
||||||
private final double[] rDiag;
|
|
||||||
/** The rank of the matrix */
|
|
||||||
private final int rank;
|
|
||||||
/** The permutation matrix */
|
|
||||||
private final int[] perm;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Build a solver from decomposed matrix.
|
|
||||||
* @param qrt packed TRANSPOSED representation of the QR decomposition
|
|
||||||
* @param rDiag diagonal elements of R
|
|
||||||
*/
|
|
||||||
private Solver(final double[][] qr, final double[] rDiag, int[] perm, int rank) {
|
|
||||||
this.qr = qr;
|
|
||||||
this.rDiag = rDiag;
|
|
||||||
this.perm = perm;
|
|
||||||
this.rank = rank;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
|
||||||
public boolean isNonSingular() {
|
|
||||||
if (qr.length >= qr[0].length) {
|
|
||||||
return rank == qr[0].length;
|
|
||||||
} else { //qr.length < qr[0].length
|
|
||||||
return rank == qr.length;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
|
||||||
public RealVector solve(RealVector b) {
|
|
||||||
final int n = qr[0].length;
|
|
||||||
final int m = qr.length;
|
|
||||||
if (b.getDimension() != m) {
|
|
||||||
throw new DimensionMismatchException(b.getDimension(), m);
|
|
||||||
}
|
|
||||||
if (!isNonSingular()) {
|
|
||||||
throw new SingularMatrixException();
|
|
||||||
}
|
|
||||||
|
|
||||||
final double[] x = new double[n];
|
|
||||||
final double[] y = b.toArray();
|
|
||||||
|
|
||||||
// apply Householder transforms to solve Q.y = b
|
|
||||||
for (int minor = 0; minor < rank; minor++) {
|
|
||||||
final int m_idx = perm[minor];
|
|
||||||
double dotProduct = 0;
|
|
||||||
for (int row = minor; row < m; row++) {
|
|
||||||
dotProduct += y[row] * qr[row][m_idx];
|
|
||||||
}
|
|
||||||
dotProduct /= rDiag[m_idx] * qr[minor][m_idx];
|
|
||||||
for (int row = minor; row < m; row++) {
|
|
||||||
y[row] += dotProduct * qr[row][m_idx];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// solve triangular system R.x = y
|
|
||||||
for (int row = rank - 1; row >= 0; --row) {
|
|
||||||
final int m_row = perm[row];
|
|
||||||
y[row] /= rDiag[m_row];
|
|
||||||
final double yRow = y[row];
|
|
||||||
//final double[] qrtRow = qrt[row];
|
|
||||||
x[perm[row]] = yRow;
|
|
||||||
for (int i = 0; i < row; i++) {
|
|
||||||
y[i] -= yRow * qr[i][m_row];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return new ArrayRealVector(x, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
|
||||||
public RealMatrix solve(RealMatrix b) {
|
|
||||||
final int cols = qr[0].length;
|
|
||||||
final int rows = qr.length;
|
|
||||||
if (b.getRowDimension() != rows) {
|
|
||||||
throw new DimensionMismatchException(b.getRowDimension(), rows);
|
|
||||||
}
|
|
||||||
if (!isNonSingular()) {
|
|
||||||
throw new SingularMatrixException();
|
|
||||||
}
|
|
||||||
|
|
||||||
final int columns = b.getColumnDimension();
|
|
||||||
final int blockSize = BlockRealMatrix.BLOCK_SIZE;
|
|
||||||
final int cBlocks = (columns + blockSize - 1) / blockSize;
|
|
||||||
final double[][] xBlocks = BlockRealMatrix.createBlocksLayout(cols, columns);
|
|
||||||
final double[][] y = new double[b.getRowDimension()][blockSize];
|
|
||||||
final double[] alpha = new double[blockSize];
|
|
||||||
//final BlockRealMatrix result = new BlockRealMatrix(cols, columns, xBlocks, false);
|
|
||||||
for (int kBlock = 0; kBlock < cBlocks; ++kBlock) {
|
|
||||||
final int kStart = kBlock * blockSize;
|
|
||||||
final int kEnd = FastMath.min(kStart + blockSize, columns);
|
|
||||||
final int kWidth = kEnd - kStart;
|
|
||||||
// get the right hand side vector
|
|
||||||
b.copySubMatrix(0, rows - 1, kStart, kEnd - 1, y);
|
|
||||||
|
|
||||||
// apply Householder transforms to solve Q.y = b
|
|
||||||
for (int minor = 0; minor < rank; minor++) {
|
|
||||||
final int m_idx = perm[minor];
|
|
||||||
final double factor = 1.0 / (rDiag[m_idx] * qr[minor][m_idx]);
|
|
||||||
|
|
||||||
Arrays.fill(alpha, 0, kWidth, 0.0);
|
|
||||||
for (int row = minor; row < rows; ++row) {
|
|
||||||
final double d = qr[row][m_idx];
|
|
||||||
final double[] yRow = y[row];
|
|
||||||
for (int k = 0; k < kWidth; ++k) {
|
|
||||||
alpha[k] += d * yRow[k];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (int k = 0; k < kWidth; ++k) {
|
|
||||||
alpha[k] *= factor;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int row = minor; row < rows; ++row) {
|
|
||||||
final double d = qr[row][m_idx];
|
|
||||||
final double[] yRow = y[row];
|
|
||||||
for (int k = 0; k < kWidth; ++k) {
|
|
||||||
yRow[k] += alpha[k] * d;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// solve triangular system R.x = y
|
|
||||||
for (int j = rank - 1; j >= 0; --j) {
|
|
||||||
final int jBlock = perm[j] / blockSize; //which block
|
|
||||||
final int jStart = jBlock * blockSize; // idx of top corner of block in my coord
|
|
||||||
final double factor = 1.0 / rDiag[perm[j]];
|
|
||||||
final double[] yJ = y[j];
|
|
||||||
final double[] xBlock = xBlocks[jBlock * cBlocks + kBlock];
|
|
||||||
int index = (perm[j] - jStart) * kWidth; //to local (block) coordinates
|
|
||||||
for (int k = 0; k < kWidth; ++k) {
|
|
||||||
yJ[k] *= factor;
|
|
||||||
xBlock[index++] = yJ[k];
|
|
||||||
}
|
|
||||||
for (int i = 0; i < j; ++i) {
|
|
||||||
final double rIJ = qr[i][perm[j]];
|
|
||||||
final double[] yI = y[i];
|
|
||||||
for (int k = 0; k < kWidth; ++k) {
|
|
||||||
yI[k] -= yJ[k] * rIJ;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
//return result;
|
|
||||||
return new BlockRealMatrix(cols, columns, xBlocks, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
|
||||||
public RealMatrix getInverse() {
|
|
||||||
return solve(MatrixUtils.createRealIdentityMatrix(rDiag.length));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,257 +0,0 @@
|
||||||
/*
|
|
||||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
|
||||||
* contributor license agreements. See the NOTICE file distributed with
|
|
||||||
* this work for additional information regarding copyright ownership.
|
|
||||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
|
||||||
* (the "License"); you may not use this file except in compliance with
|
|
||||||
* the License. You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.apache.commons.math3.linear;
|
|
||||||
|
|
||||||
import java.util.Random;
|
|
||||||
|
|
||||||
|
|
||||||
import org.apache.commons.math3.exception.ConvergenceException;
|
|
||||||
import org.junit.Assert;
|
|
||||||
import org.junit.Test;
|
|
||||||
|
|
||||||
|
|
||||||
public class PivotingQRDecompositionTest {
|
|
||||||
double[][] testData3x3NonSingular = {
|
|
||||||
{ 12, -51, 4 },
|
|
||||||
{ 6, 167, -68 },
|
|
||||||
{ -4, 24, -41 }, };
|
|
||||||
|
|
||||||
double[][] testData3x3Singular = {
|
|
||||||
{ 1, 4, 7, },
|
|
||||||
{ 2, 5, 8, },
|
|
||||||
{ 3, 6, 9, }, };
|
|
||||||
|
|
||||||
double[][] testData3x4 = {
|
|
||||||
{ 12, -51, 4, 1 },
|
|
||||||
{ 6, 167, -68, 2 },
|
|
||||||
{ -4, 24, -41, 3 }, };
|
|
||||||
|
|
||||||
double[][] testData4x3 = {
|
|
||||||
{ 12, -51, 4, },
|
|
||||||
{ 6, 167, -68, },
|
|
||||||
{ -4, 24, -41, },
|
|
||||||
{ -5, 34, 7, }, };
|
|
||||||
|
|
||||||
private static final double entryTolerance = 10e-16;
|
|
||||||
|
|
||||||
private static final double normTolerance = 10e-14;
|
|
||||||
|
|
||||||
/** test dimensions */
|
|
||||||
@Test
|
|
||||||
public void testDimensions() throws ConvergenceException {
|
|
||||||
checkDimension(MatrixUtils.createRealMatrix(testData3x3NonSingular));
|
|
||||||
|
|
||||||
checkDimension(MatrixUtils.createRealMatrix(testData4x3));
|
|
||||||
|
|
||||||
checkDimension(MatrixUtils.createRealMatrix(testData3x4));
|
|
||||||
|
|
||||||
Random r = new Random(643895747384642l);
|
|
||||||
int p = (5 * BlockRealMatrix.BLOCK_SIZE) / 4;
|
|
||||||
int q = (7 * BlockRealMatrix.BLOCK_SIZE) / 4;
|
|
||||||
checkDimension(createTestMatrix(r, p, q));
|
|
||||||
checkDimension(createTestMatrix(r, q, p));
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
private void checkDimension(RealMatrix m) throws ConvergenceException {
|
|
||||||
int rows = m.getRowDimension();
|
|
||||||
int columns = m.getColumnDimension();
|
|
||||||
PivotingQRDecomposition qr = new PivotingQRDecomposition(m);
|
|
||||||
Assert.assertEquals(rows, qr.getQ().getRowDimension());
|
|
||||||
Assert.assertEquals(rows, qr.getQ().getColumnDimension());
|
|
||||||
Assert.assertEquals(rows, qr.getR().getRowDimension());
|
|
||||||
Assert.assertEquals(columns, qr.getR().getColumnDimension());
|
|
||||||
}
|
|
||||||
|
|
||||||
/** test A = QR */
|
|
||||||
@Test
|
|
||||||
public void testAEqualQR() throws ConvergenceException {
|
|
||||||
checkAEqualQR(MatrixUtils.createRealMatrix(testData3x3NonSingular));
|
|
||||||
|
|
||||||
checkAEqualQR(MatrixUtils.createRealMatrix(testData3x3Singular));
|
|
||||||
|
|
||||||
checkAEqualQR(MatrixUtils.createRealMatrix(testData3x4));
|
|
||||||
|
|
||||||
checkAEqualQR(MatrixUtils.createRealMatrix(testData4x3));
|
|
||||||
|
|
||||||
Random r = new Random(643895747384642l);
|
|
||||||
int p = (5 * BlockRealMatrix.BLOCK_SIZE) / 4;
|
|
||||||
int q = (7 * BlockRealMatrix.BLOCK_SIZE) / 4;
|
|
||||||
checkAEqualQR(createTestMatrix(r, p, q));
|
|
||||||
|
|
||||||
checkAEqualQR(createTestMatrix(r, q, p));
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
private void checkAEqualQR(RealMatrix m) throws ConvergenceException {
|
|
||||||
PivotingQRDecomposition qr = new PivotingQRDecomposition(m);
|
|
||||||
RealMatrix prod = qr.getQ().multiply(qr.getR()).multiply(qr.getPermutationMatrix().transpose());
|
|
||||||
double norm = prod.subtract(m).getNorm();
|
|
||||||
Assert.assertEquals(0, norm, normTolerance);
|
|
||||||
}
|
|
||||||
|
|
||||||
/** test the orthogonality of Q */
|
|
||||||
@Test
|
|
||||||
public void testQOrthogonal() throws ConvergenceException{
|
|
||||||
checkQOrthogonal(MatrixUtils.createRealMatrix(testData3x3NonSingular));
|
|
||||||
|
|
||||||
checkQOrthogonal(MatrixUtils.createRealMatrix(testData3x3Singular));
|
|
||||||
|
|
||||||
checkQOrthogonal(MatrixUtils.createRealMatrix(testData3x4));
|
|
||||||
|
|
||||||
checkQOrthogonal(MatrixUtils.createRealMatrix(testData4x3));
|
|
||||||
|
|
||||||
Random r = new Random(643895747384642l);
|
|
||||||
int p = (5 * BlockRealMatrix.BLOCK_SIZE) / 4;
|
|
||||||
int q = (7 * BlockRealMatrix.BLOCK_SIZE) / 4;
|
|
||||||
checkQOrthogonal(createTestMatrix(r, p, q));
|
|
||||||
|
|
||||||
checkQOrthogonal(createTestMatrix(r, q, p));
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
private void checkQOrthogonal(RealMatrix m) throws ConvergenceException{
|
|
||||||
PivotingQRDecomposition qr = new PivotingQRDecomposition(m);
|
|
||||||
RealMatrix eye = MatrixUtils.createRealIdentityMatrix(m.getRowDimension());
|
|
||||||
double norm = qr.getQT().multiply(qr.getQ()).subtract(eye).getNorm();
|
|
||||||
Assert.assertEquals(0, norm, normTolerance);
|
|
||||||
}
|
|
||||||
//
|
|
||||||
/** test that R is upper triangular */
|
|
||||||
@Test
|
|
||||||
public void testRUpperTriangular() throws ConvergenceException{
|
|
||||||
RealMatrix matrix = MatrixUtils.createRealMatrix(testData3x3NonSingular);
|
|
||||||
checkUpperTriangular(new PivotingQRDecomposition(matrix).getR());
|
|
||||||
|
|
||||||
matrix = MatrixUtils.createRealMatrix(testData3x3Singular);
|
|
||||||
checkUpperTriangular(new PivotingQRDecomposition(matrix).getR());
|
|
||||||
|
|
||||||
matrix = MatrixUtils.createRealMatrix(testData3x4);
|
|
||||||
checkUpperTriangular(new PivotingQRDecomposition(matrix).getR());
|
|
||||||
|
|
||||||
matrix = MatrixUtils.createRealMatrix(testData4x3);
|
|
||||||
checkUpperTriangular(new PivotingQRDecomposition(matrix).getR());
|
|
||||||
|
|
||||||
Random r = new Random(643895747384642l);
|
|
||||||
int p = (5 * BlockRealMatrix.BLOCK_SIZE) / 4;
|
|
||||||
int q = (7 * BlockRealMatrix.BLOCK_SIZE) / 4;
|
|
||||||
matrix = createTestMatrix(r, p, q);
|
|
||||||
checkUpperTriangular(new PivotingQRDecomposition(matrix).getR());
|
|
||||||
|
|
||||||
matrix = createTestMatrix(r, p, q);
|
|
||||||
checkUpperTriangular(new PivotingQRDecomposition(matrix).getR());
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
private void checkUpperTriangular(RealMatrix m) {
|
|
||||||
m.walkInOptimizedOrder(new DefaultRealMatrixPreservingVisitor() {
|
|
||||||
@Override
|
|
||||||
public void visit(int row, int column, double value) {
|
|
||||||
if (column < row) {
|
|
||||||
Assert.assertEquals(0.0, value, entryTolerance);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
/** test that H is trapezoidal */
|
|
||||||
@Test
|
|
||||||
public void testHTrapezoidal() throws ConvergenceException{
|
|
||||||
RealMatrix matrix = MatrixUtils.createRealMatrix(testData3x3NonSingular);
|
|
||||||
checkTrapezoidal(new PivotingQRDecomposition(matrix).getH());
|
|
||||||
|
|
||||||
matrix = MatrixUtils.createRealMatrix(testData3x3Singular);
|
|
||||||
checkTrapezoidal(new PivotingQRDecomposition(matrix).getH());
|
|
||||||
|
|
||||||
matrix = MatrixUtils.createRealMatrix(testData3x4);
|
|
||||||
checkTrapezoidal(new PivotingQRDecomposition(matrix).getH());
|
|
||||||
|
|
||||||
matrix = MatrixUtils.createRealMatrix(testData4x3);
|
|
||||||
checkTrapezoidal(new PivotingQRDecomposition(matrix).getH());
|
|
||||||
|
|
||||||
Random r = new Random(643895747384642l);
|
|
||||||
int p = (5 * BlockRealMatrix.BLOCK_SIZE) / 4;
|
|
||||||
int q = (7 * BlockRealMatrix.BLOCK_SIZE) / 4;
|
|
||||||
matrix = createTestMatrix(r, p, q);
|
|
||||||
checkTrapezoidal(new PivotingQRDecomposition(matrix).getH());
|
|
||||||
|
|
||||||
matrix = createTestMatrix(r, p, q);
|
|
||||||
checkTrapezoidal(new PivotingQRDecomposition(matrix).getH());
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
private void checkTrapezoidal(RealMatrix m) {
|
|
||||||
m.walkInOptimizedOrder(new DefaultRealMatrixPreservingVisitor() {
|
|
||||||
@Override
|
|
||||||
public void visit(int row, int column, double value) {
|
|
||||||
if (column > row) {
|
|
||||||
Assert.assertEquals(0.0, value, entryTolerance);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
/** test matrices values */
|
|
||||||
@Test
|
|
||||||
public void testMatricesValues() throws ConvergenceException{
|
|
||||||
PivotingQRDecomposition qr =
|
|
||||||
new PivotingQRDecomposition(MatrixUtils.createRealMatrix(testData3x3NonSingular),false);
|
|
||||||
RealMatrix qRef = MatrixUtils.createRealMatrix(new double[][] {
|
|
||||||
{ -12.0 / 14.0, 69.0 / 175.0, -58.0 / 175.0 },
|
|
||||||
{ -6.0 / 14.0, -158.0 / 175.0, 6.0 / 175.0 },
|
|
||||||
{ 4.0 / 14.0, -30.0 / 175.0, -165.0 / 175.0 }
|
|
||||||
});
|
|
||||||
RealMatrix rRef = MatrixUtils.createRealMatrix(new double[][] {
|
|
||||||
{ -14.0, -21.0, 14.0 },
|
|
||||||
{ 0.0, -175.0, 70.0 },
|
|
||||||
{ 0.0, 0.0, 35.0 }
|
|
||||||
});
|
|
||||||
RealMatrix hRef = MatrixUtils.createRealMatrix(new double[][] {
|
|
||||||
{ 26.0 / 14.0, 0.0, 0.0 },
|
|
||||||
{ 6.0 / 14.0, 648.0 / 325.0, 0.0 },
|
|
||||||
{ -4.0 / 14.0, 36.0 / 325.0, 2.0 }
|
|
||||||
});
|
|
||||||
|
|
||||||
// check values against known references
|
|
||||||
RealMatrix q = qr.getQ();
|
|
||||||
Assert.assertEquals(0, q.subtract(qRef).getNorm(), 1.0e-13);
|
|
||||||
RealMatrix qT = qr.getQT();
|
|
||||||
Assert.assertEquals(0, qT.subtract(qRef.transpose()).getNorm(), 1.0e-13);
|
|
||||||
RealMatrix r = qr.getR();
|
|
||||||
Assert.assertEquals(0, r.subtract(rRef).getNorm(), 1.0e-13);
|
|
||||||
RealMatrix h = qr.getH();
|
|
||||||
Assert.assertEquals(0, h.subtract(hRef).getNorm(), 1.0e-13);
|
|
||||||
|
|
||||||
// check the same cached instance is returned the second time
|
|
||||||
Assert.assertTrue(q == qr.getQ());
|
|
||||||
Assert.assertTrue(r == qr.getR());
|
|
||||||
Assert.assertTrue(h == qr.getH());
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
private RealMatrix createTestMatrix(final Random r, final int rows, final int columns) {
|
|
||||||
RealMatrix m = MatrixUtils.createRealMatrix(rows, columns);
|
|
||||||
m.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor(){
|
|
||||||
@Override
|
|
||||||
public double visit(int row, int column, double value) {
|
|
||||||
return 2.0 * r.nextDouble() - 1.0;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
return m;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,201 +0,0 @@
|
||||||
/*
|
|
||||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
|
||||||
* contributor license agreements. See the NOTICE file distributed with
|
|
||||||
* this work for additional information regarding copyright ownership.
|
|
||||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
|
||||||
* (the "License"); you may not use this file except in compliance with
|
|
||||||
* the License. You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.apache.commons.math3.linear;
|
|
||||||
|
|
||||||
import java.util.Random;
|
|
||||||
|
|
||||||
import org.apache.commons.math3.exception.ConvergenceException;
|
|
||||||
import org.apache.commons.math3.exception.MathIllegalArgumentException;
|
|
||||||
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.junit.Assert;
|
|
||||||
|
|
||||||
public class PivotingQRSolverTest {
|
|
||||||
double[][] testData3x3NonSingular = {
|
|
||||||
{ 12, -51, 4 },
|
|
||||||
{ 6, 167, -68 },
|
|
||||||
{ -4, 24, -41 }
|
|
||||||
};
|
|
||||||
|
|
||||||
double[][] testData3x3Singular = {
|
|
||||||
{ 1, 2, 2 },
|
|
||||||
{ 2, 4, 6 },
|
|
||||||
{ 4, 8, 12 }
|
|
||||||
};
|
|
||||||
|
|
||||||
double[][] testData3x4 = {
|
|
||||||
{ 12, -51, 4, 1 },
|
|
||||||
{ 6, 167, -68, 2 },
|
|
||||||
{ -4, 24, -41, 3 }
|
|
||||||
};
|
|
||||||
|
|
||||||
double[][] testData4x3 = {
|
|
||||||
{ 12, -51, 4 },
|
|
||||||
{ 6, 167, -68 },
|
|
||||||
{ -4, 24, -41 },
|
|
||||||
{ -5, 34, 7 }
|
|
||||||
};
|
|
||||||
|
|
||||||
/** test rank */
|
|
||||||
@Test
|
|
||||||
public void testRank() throws ConvergenceException {
|
|
||||||
DecompositionSolver solver =
|
|
||||||
new PivotingQRDecomposition(MatrixUtils.createRealMatrix(testData3x3NonSingular)).getSolver();
|
|
||||||
Assert.assertTrue(solver.isNonSingular());
|
|
||||||
|
|
||||||
solver = new PivotingQRDecomposition(MatrixUtils.createRealMatrix(testData3x3Singular)).getSolver();
|
|
||||||
Assert.assertFalse(solver.isNonSingular());
|
|
||||||
|
|
||||||
solver = new PivotingQRDecomposition(MatrixUtils.createRealMatrix(testData3x4)).getSolver();
|
|
||||||
Assert.assertTrue(solver.isNonSingular());
|
|
||||||
|
|
||||||
solver = new PivotingQRDecomposition(MatrixUtils.createRealMatrix(testData4x3)).getSolver();
|
|
||||||
Assert.assertTrue(solver.isNonSingular());
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/** test solve dimension errors */
|
|
||||||
@Test
|
|
||||||
public void testSolveDimensionErrors() throws ConvergenceException {
|
|
||||||
DecompositionSolver solver =
|
|
||||||
new PivotingQRDecomposition(MatrixUtils.createRealMatrix(testData3x3NonSingular)).getSolver();
|
|
||||||
RealMatrix b = MatrixUtils.createRealMatrix(new double[2][2]);
|
|
||||||
try {
|
|
||||||
solver.solve(b);
|
|
||||||
Assert.fail("an exception should have been thrown");
|
|
||||||
} catch (MathIllegalArgumentException iae) {
|
|
||||||
// expected behavior
|
|
||||||
}
|
|
||||||
try {
|
|
||||||
solver.solve(b.getColumnVector(0));
|
|
||||||
Assert.fail("an exception should have been thrown");
|
|
||||||
} catch (MathIllegalArgumentException iae) {
|
|
||||||
// expected behavior
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/** test solve rank errors */
|
|
||||||
@Test
|
|
||||||
public void testSolveRankErrors() throws ConvergenceException {
|
|
||||||
DecompositionSolver solver =
|
|
||||||
new PivotingQRDecomposition(MatrixUtils.createRealMatrix(testData3x3Singular)).getSolver();
|
|
||||||
RealMatrix b = MatrixUtils.createRealMatrix(new double[3][2]);
|
|
||||||
try {
|
|
||||||
solver.solve(b);
|
|
||||||
Assert.fail("an exception should have been thrown");
|
|
||||||
} catch (SingularMatrixException iae) {
|
|
||||||
// expected behavior
|
|
||||||
}
|
|
||||||
try {
|
|
||||||
solver.solve(b.getColumnVector(0));
|
|
||||||
Assert.fail("an exception should have been thrown");
|
|
||||||
} catch (SingularMatrixException iae) {
|
|
||||||
// expected behavior
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/** test solve */
|
|
||||||
@Test
|
|
||||||
public void testSolve() throws ConvergenceException {
|
|
||||||
PivotingQRDecomposition decomposition =
|
|
||||||
new PivotingQRDecomposition(MatrixUtils.createRealMatrix(testData3x3NonSingular));
|
|
||||||
DecompositionSolver solver = decomposition.getSolver();
|
|
||||||
RealMatrix b = MatrixUtils.createRealMatrix(new double[][] {
|
|
||||||
{ -102, 12250 }, { 544, 24500 }, { 167, -36750 }
|
|
||||||
});
|
|
||||||
|
|
||||||
RealMatrix xRef = MatrixUtils.createRealMatrix(new double[][] {
|
|
||||||
{ 1, 2515 }, { 2, 422 }, { -3, 898 }
|
|
||||||
});
|
|
||||||
|
|
||||||
// using RealMatrix
|
|
||||||
Assert.assertEquals(0, solver.solve(b).subtract(xRef).getNorm(), 2.0e-14 * xRef.getNorm());
|
|
||||||
|
|
||||||
// using ArrayRealVector
|
|
||||||
for (int i = 0; i < b.getColumnDimension(); ++i) {
|
|
||||||
final RealVector x = solver.solve(b.getColumnVector(i));
|
|
||||||
final double error = x.subtract(xRef.getColumnVector(i)).getNorm();
|
|
||||||
Assert.assertEquals(0, error, 3.0e-14 * xRef.getColumnVector(i).getNorm());
|
|
||||||
}
|
|
||||||
|
|
||||||
// using RealVector with an alternate implementation
|
|
||||||
for (int i = 0; i < b.getColumnDimension(); ++i) {
|
|
||||||
ArrayRealVectorTest.RealVectorTestImpl v =
|
|
||||||
new ArrayRealVectorTest.RealVectorTestImpl(b.getColumn(i));
|
|
||||||
final RealVector x = solver.solve(v);
|
|
||||||
final double error = x.subtract(xRef.getColumnVector(i)).getNorm();
|
|
||||||
Assert.assertEquals(0, error, 3.0e-14 * xRef.getColumnVector(i).getNorm());
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testOverdetermined() throws ConvergenceException {
|
|
||||||
final Random r = new Random(5559252868205245l);
|
|
||||||
int p = (7 * BlockRealMatrix.BLOCK_SIZE) / 4;
|
|
||||||
int q = (5 * BlockRealMatrix.BLOCK_SIZE) / 4;
|
|
||||||
RealMatrix a = createTestMatrix(r, p, q);
|
|
||||||
RealMatrix xRef = createTestMatrix(r, q, BlockRealMatrix.BLOCK_SIZE + 3);
|
|
||||||
|
|
||||||
// build a perturbed system: A.X + noise = B
|
|
||||||
RealMatrix b = a.multiply(xRef);
|
|
||||||
final double noise = 0.001;
|
|
||||||
b.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() {
|
|
||||||
@Override
|
|
||||||
public double visit(int row, int column, double value) {
|
|
||||||
return value * (1.0 + noise * (2 * r.nextDouble() - 1));
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// despite perturbation, the least square solution should be pretty good
|
|
||||||
RealMatrix x = new PivotingQRDecomposition(a).getSolver().solve(b);
|
|
||||||
Assert.assertEquals(0, x.subtract(xRef).getNorm(), 0.01 * noise * p * q);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testUnderdetermined() throws ConvergenceException {
|
|
||||||
final Random r = new Random(42185006424567123l);
|
|
||||||
int p = (5 * BlockRealMatrix.BLOCK_SIZE) / 4;
|
|
||||||
int q = (7 * BlockRealMatrix.BLOCK_SIZE) / 4;
|
|
||||||
RealMatrix a = createTestMatrix(r, p, q);
|
|
||||||
RealMatrix xRef = createTestMatrix(r, q, BlockRealMatrix.BLOCK_SIZE + 3);
|
|
||||||
RealMatrix b = a.multiply(xRef);
|
|
||||||
PivotingQRDecomposition pqr = new PivotingQRDecomposition(a);
|
|
||||||
RealMatrix x = pqr.getSolver().solve(b);
|
|
||||||
Assert.assertTrue(x.subtract(xRef).getNorm() / (p * q) > 0.01);
|
|
||||||
int count=0;
|
|
||||||
for( int i = 0 ; i < q; i++){
|
|
||||||
if( x.getRowVector(i).getNorm() == 0.0 ){
|
|
||||||
++count;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Assert.assertEquals("Zeroed rows", q-p, count);
|
|
||||||
}
|
|
||||||
|
|
||||||
private RealMatrix createTestMatrix(final Random r, final int rows, final int columns) {
|
|
||||||
RealMatrix m = MatrixUtils.createRealMatrix(rows, columns);
|
|
||||||
m.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() {
|
|
||||||
@Override
|
|
||||||
public double visit(int row, int column, double value) {
|
|
||||||
return 2.0 * r.nextDouble() - 1.0;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
return m;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,915 +0,0 @@
|
||||||
/*
|
|
||||||
* Copyright 2011 The Apache Software Foundation.
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
package org.apache.commons.math3.optimization;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
import org.apache.commons.math3.TestUtils;
|
|
||||||
import org.apache.commons.math3.analysis.DifferentiableMultivariateFunction;
|
|
||||||
import org.apache.commons.math3.analysis.MultivariateFunction;
|
|
||||||
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
|
|
||||||
import org.apache.commons.math3.exception.MathIllegalArgumentException;
|
|
||||||
import org.apache.commons.math3.exception.util.LocalizedFormats;
|
|
||||||
import org.apache.commons.math3.optimization.direct.BOBYQAOptimizer;
|
|
||||||
import org.apache.commons.math3.optimization.direct.PowellOptimizer;
|
|
||||||
import org.apache.commons.math3.optimization.general.AbstractScalarDifferentiableOptimizer;
|
|
||||||
import org.apache.commons.math3.optimization.general.ConjugateGradientFormula;
|
|
||||||
import org.apache.commons.math3.optimization.general.NonLinearConjugateGradientOptimizer;
|
|
||||||
import org.apache.commons.math3.util.FastMath;
|
|
||||||
import org.junit.Assert;
|
|
||||||
import org.junit.Test;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* an ever growing set of tests from NIST
|
|
||||||
* http://www.itl.nist.gov/div898/strd/nls/nls_main.shtml
|
|
||||||
* @author gregs
|
|
||||||
*/
|
|
||||||
public class BatteryNISTTest {
|
|
||||||
|
|
||||||
public static double[] lanczosNIST = {
|
|
||||||
2.5134, 0.00000,
|
|
||||||
2.0443, 5.00000e-2,
|
|
||||||
1.6684, 1.00000e-1,
|
|
||||||
1.3664, 1.50000e-1,
|
|
||||||
1.1232, 2.00000e-1,
|
|
||||||
0.9269, 2.50000e-1,
|
|
||||||
0.7679, 3.00000e-1,
|
|
||||||
0.6389, 3.50000e-1,
|
|
||||||
0.5338, 4.00000e-1,
|
|
||||||
0.4479, 4.50000e-1,
|
|
||||||
0.3776, 5.00000e-1,
|
|
||||||
0.3197, 5.50000e-1,
|
|
||||||
0.2720, 6.00000e-1,
|
|
||||||
0.2325, 6.50000e-1,
|
|
||||||
0.1997, 7.00000e-1,
|
|
||||||
0.1723, 7.50000e-1,
|
|
||||||
0.1493, 8.00000e-1,
|
|
||||||
0.1301, 8.50000e-1,
|
|
||||||
0.1138, 9.00000e-1,
|
|
||||||
0.1000, 9.50000e-1,
|
|
||||||
0.0883, 1.00000,
|
|
||||||
0.0783, 1.05000,
|
|
||||||
0.0698, 1.10000,
|
|
||||||
0.0624, 1.15000};
|
|
||||||
/* the lanzcos objective function -------------------------------*/
|
|
||||||
private final nistMVRF lanczosObjectFunc = new nistMVRF(lanczosNIST, 1, 24, 6) {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected double partialDeriv(double[] point, int idx) {
|
|
||||||
double cy, cx, r, ret = 0.0, d;
|
|
||||||
int ptr = 0, ptr1;
|
|
||||||
for (int i = 0; i < this.nobs; i++) {
|
|
||||||
cy = data[ptr++];
|
|
||||||
cx = data[ptr++];
|
|
||||||
ptr1 = 0;
|
|
||||||
d = 0.0;
|
|
||||||
for (int j = 0; j < 3; j++) {
|
|
||||||
d += point[ptr1++] * FastMath.exp(-cx * point[ptr1++]);
|
|
||||||
}
|
|
||||||
r = cy - d;
|
|
||||||
if (idx == 0) {
|
|
||||||
ret -= (2.0 * r) * FastMath.exp(-cx * point[1]);
|
|
||||||
} else if (idx == 1) {
|
|
||||||
ret += (2.0 * r) * FastMath.exp(-cx * point[1]) * cx * point[0];
|
|
||||||
} else if (idx == 2) {
|
|
||||||
ret -= (2.0 * r) * FastMath.exp(-cx * point[3]);
|
|
||||||
} else if (idx == 3) {
|
|
||||||
ret += (2.0 * r) * FastMath.exp(-cx * point[3]) * cx * point[2];
|
|
||||||
} else if (idx == 4) {
|
|
||||||
ret -= (2.0 * r) * FastMath.exp(-cx * point[5]);
|
|
||||||
} else {
|
|
||||||
ret += (2.0 * r) * FastMath.exp(-cx * point[5]) * cx * point[4];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return (ret);
|
|
||||||
}
|
|
||||||
|
|
||||||
public double value(double[] point) {
|
|
||||||
double ret = 0.0, err, d, cx, cy;
|
|
||||||
int ptr = 0, ptr1 = 0;
|
|
||||||
for (int i = 0; i < this.nobs; i++) {
|
|
||||||
cy = data[ptr++];
|
|
||||||
cx = data[ptr++];
|
|
||||||
d = 0.0;
|
|
||||||
ptr1 = 0;
|
|
||||||
for (int j = 0; j < 3; j++) {
|
|
||||||
d += point[ptr1++] * FastMath.exp(-cx * point[ptr1++]);
|
|
||||||
}
|
|
||||||
err = cy - d;
|
|
||||||
ret += err * err;
|
|
||||||
}
|
|
||||||
return (ret);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected double[] getGradient(double[] point) {
|
|
||||||
Arrays.fill(gradient, 0.0);
|
|
||||||
double cy, cx, r, d = 0;
|
|
||||||
int ptr = 0, ptr1;
|
|
||||||
for (int i = 0; i < this.nobs; i++) {
|
|
||||||
cy = data[ptr++];
|
|
||||||
cx = data[ptr++];
|
|
||||||
ptr1 = 0;
|
|
||||||
d = 0.0;
|
|
||||||
for (int j = 0; j < 3; j++) {
|
|
||||||
d += point[ptr1++] * FastMath.exp(-cx * point[ptr1++]);
|
|
||||||
}
|
|
||||||
r = cy - d;
|
|
||||||
gradient[0] -= (2.0 * r) * FastMath.exp(-cx * point[1]);
|
|
||||||
gradient[1] += (2.0 * r) * FastMath.exp(-cx * point[1]) * cx * point[0];
|
|
||||||
|
|
||||||
gradient[2] -= (2.0 * r) * FastMath.exp(-cx * point[3]);
|
|
||||||
gradient[3] += (2.0 * r) * FastMath.exp(-cx * point[3]) * cx * point[2];
|
|
||||||
|
|
||||||
gradient[4] -= (2.0 * r) * FastMath.exp(-cx * point[5]);
|
|
||||||
gradient[5] += (2.0 * r) * FastMath.exp(-cx * point[5]) * cx * point[4];
|
|
||||||
}
|
|
||||||
return this.gradient;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/* chwirut1 data ------------------------*/
|
|
||||||
public static double[] chwirut1NIST = {
|
|
||||||
92.9000, 0.5000,
|
|
||||||
78.7000, 0.6250,
|
|
||||||
64.2000, 0.7500,
|
|
||||||
64.9000, 0.8750,
|
|
||||||
57.1000, 1.0000,
|
|
||||||
43.3000, 1.2500,
|
|
||||||
31.1000, 1.7500,
|
|
||||||
23.6000, 2.2500,
|
|
||||||
31.0500, 1.7500,
|
|
||||||
23.7750, 2.2500,
|
|
||||||
17.7375, 2.7500,
|
|
||||||
13.8000, 3.2500,
|
|
||||||
11.5875, 3.7500,
|
|
||||||
9.4125, 4.2500,
|
|
||||||
7.7250, 4.7500,
|
|
||||||
7.3500, 5.2500,
|
|
||||||
8.0250, 5.7500,
|
|
||||||
90.6000, 0.5000,
|
|
||||||
76.9000, 0.6250,
|
|
||||||
71.6000, 0.7500,
|
|
||||||
63.6000, 0.8750,
|
|
||||||
54.0000, 1.0000,
|
|
||||||
39.2000, 1.2500,
|
|
||||||
29.3000, 1.7500,
|
|
||||||
21.4000, 2.2500,
|
|
||||||
29.1750, 1.7500,
|
|
||||||
22.1250, 2.2500,
|
|
||||||
17.5125, 2.7500,
|
|
||||||
14.2500, 3.2500,
|
|
||||||
9.4500, 3.7500,
|
|
||||||
9.1500, 4.2500,
|
|
||||||
7.9125, 4.7500,
|
|
||||||
8.4750, 5.2500,
|
|
||||||
6.1125, 5.7500,
|
|
||||||
80.0000, 0.5000,
|
|
||||||
79.0000, 0.6250,
|
|
||||||
63.8000, 0.7500,
|
|
||||||
57.2000, 0.8750,
|
|
||||||
53.2000, 1.0000,
|
|
||||||
42.5000, 1.2500,
|
|
||||||
26.8000, 1.7500,
|
|
||||||
20.4000, 2.2500,
|
|
||||||
26.8500, 1.7500,
|
|
||||||
21.0000, 2.2500,
|
|
||||||
16.4625, 2.7500,
|
|
||||||
12.5250, 3.2500,
|
|
||||||
10.5375, 3.7500,
|
|
||||||
8.5875, 4.2500,
|
|
||||||
7.1250, 4.7500,
|
|
||||||
6.1125, 5.2500,
|
|
||||||
5.9625, 5.7500,
|
|
||||||
74.1000, 0.5000,
|
|
||||||
67.3000, 0.6250,
|
|
||||||
60.8000, 0.7500,
|
|
||||||
55.5000, 0.8750,
|
|
||||||
50.3000, 1.0000,
|
|
||||||
41.0000, 1.2500,
|
|
||||||
29.4000, 1.7500,
|
|
||||||
20.4000, 2.2500,
|
|
||||||
29.3625, 1.7500,
|
|
||||||
21.1500, 2.2500,
|
|
||||||
16.7625, 2.7500,
|
|
||||||
13.2000, 3.2500,
|
|
||||||
10.8750, 3.7500,
|
|
||||||
8.1750, 4.2500,
|
|
||||||
7.3500, 4.7500,
|
|
||||||
5.9625, 5.2500,
|
|
||||||
5.6250, 5.7500,
|
|
||||||
81.5000, .5000,
|
|
||||||
62.4000, .7500,
|
|
||||||
32.5000, 1.5000,
|
|
||||||
12.4100, 3.0000,
|
|
||||||
13.1200, 3.0000,
|
|
||||||
15.5600, 3.0000,
|
|
||||||
5.6300, 6.0000,
|
|
||||||
78.0000, .5000,
|
|
||||||
59.9000, .7500,
|
|
||||||
33.2000, 1.5000,
|
|
||||||
13.8400, 3.0000,
|
|
||||||
12.7500, 3.0000,
|
|
||||||
14.6200, 3.0000,
|
|
||||||
3.9400, 6.0000,
|
|
||||||
76.8000, .5000,
|
|
||||||
61.0000, .7500,
|
|
||||||
32.9000, 1.5000,
|
|
||||||
13.8700, 3.0000,
|
|
||||||
11.8100, 3.0000,
|
|
||||||
13.3100, 3.0000,
|
|
||||||
5.4400, 6.0000,
|
|
||||||
78.0000, .5000,
|
|
||||||
63.5000, .7500,
|
|
||||||
33.8000, 1.5000,
|
|
||||||
12.5600, 3.0000,
|
|
||||||
5.6300, 6.0000,
|
|
||||||
12.7500, 3.0000,
|
|
||||||
13.1200, 3.0000,
|
|
||||||
5.4400, 6.0000,
|
|
||||||
76.8000, .5000,
|
|
||||||
60.0000, .7500,
|
|
||||||
47.8000, 1.0000,
|
|
||||||
32.0000, 1.5000,
|
|
||||||
22.2000, 2.0000,
|
|
||||||
22.5700, 2.0000,
|
|
||||||
18.8200, 2.5000,
|
|
||||||
13.9500, 3.0000,
|
|
||||||
11.2500, 4.0000,
|
|
||||||
9.0000, 5.0000,
|
|
||||||
6.6700, 6.0000,
|
|
||||||
75.8000, .5000,
|
|
||||||
62.0000, .7500,
|
|
||||||
48.8000, 1.0000,
|
|
||||||
35.2000, 1.5000,
|
|
||||||
20.0000, 2.0000,
|
|
||||||
20.3200, 2.0000,
|
|
||||||
19.3100, 2.5000,
|
|
||||||
12.7500, 3.0000,
|
|
||||||
10.4200, 4.0000,
|
|
||||||
7.3100, 5.0000,
|
|
||||||
7.4200, 6.0000,
|
|
||||||
70.5000, .5000,
|
|
||||||
59.5000, .7500,
|
|
||||||
48.5000, 1.0000,
|
|
||||||
35.8000, 1.5000,
|
|
||||||
21.0000, 2.0000,
|
|
||||||
21.6700, 2.0000,
|
|
||||||
21.0000, 2.5000,
|
|
||||||
15.6400, 3.0000,
|
|
||||||
8.1700, 4.0000,
|
|
||||||
8.5500, 5.0000,
|
|
||||||
10.1200, 6.0000,
|
|
||||||
78.0000, .5000,
|
|
||||||
66.0000, .6250,
|
|
||||||
62.0000, .7500,
|
|
||||||
58.0000, .8750,
|
|
||||||
47.7000, 1.0000,
|
|
||||||
37.8000, 1.2500,
|
|
||||||
20.2000, 2.2500,
|
|
||||||
21.0700, 2.2500,
|
|
||||||
13.8700, 2.7500,
|
|
||||||
9.6700, 3.2500,
|
|
||||||
7.7600, 3.7500,
|
|
||||||
5.4400, 4.2500,
|
|
||||||
4.8700, 4.7500,
|
|
||||||
4.0100, 5.2500,
|
|
||||||
3.7500, 5.7500,
|
|
||||||
24.1900, 3.0000,
|
|
||||||
25.7600, 3.0000,
|
|
||||||
18.0700, 3.0000,
|
|
||||||
11.8100, 3.0000,
|
|
||||||
12.0700, 3.0000,
|
|
||||||
16.1200, 3.0000,
|
|
||||||
70.8000, .5000,
|
|
||||||
54.7000, .7500,
|
|
||||||
48.0000, 1.0000,
|
|
||||||
39.8000, 1.5000,
|
|
||||||
29.8000, 2.0000,
|
|
||||||
23.7000, 2.5000,
|
|
||||||
29.6200, 2.0000,
|
|
||||||
23.8100, 2.5000,
|
|
||||||
17.7000, 3.0000,
|
|
||||||
11.5500, 4.0000,
|
|
||||||
12.0700, 5.0000,
|
|
||||||
8.7400, 6.0000,
|
|
||||||
80.7000, .5000,
|
|
||||||
61.3000, .7500,
|
|
||||||
47.5000, 1.0000,
|
|
||||||
29.0000, 1.5000,
|
|
||||||
24.0000, 2.0000,
|
|
||||||
17.7000, 2.5000,
|
|
||||||
24.5600, 2.0000,
|
|
||||||
18.6700, 2.5000,
|
|
||||||
16.2400, 3.0000,
|
|
||||||
8.7400, 4.0000,
|
|
||||||
7.8700, 5.0000,
|
|
||||||
8.5100, 6.0000,
|
|
||||||
66.7000, .5000,
|
|
||||||
59.2000, .7500,
|
|
||||||
40.8000, 1.0000,
|
|
||||||
30.7000, 1.5000,
|
|
||||||
25.7000, 2.0000,
|
|
||||||
16.3000, 2.5000,
|
|
||||||
25.9900, 2.0000,
|
|
||||||
16.9500, 2.5000,
|
|
||||||
13.3500, 3.0000,
|
|
||||||
8.6200, 4.0000,
|
|
||||||
7.2000, 5.0000,
|
|
||||||
6.6400, 6.0000,
|
|
||||||
13.6900, 3.0000,
|
|
||||||
81.0000, .5000,
|
|
||||||
64.5000, .7500,
|
|
||||||
35.5000, 1.5000,
|
|
||||||
13.3100, 3.0000,
|
|
||||||
4.8700, 6.0000,
|
|
||||||
12.9400, 3.0000,
|
|
||||||
5.0600, 6.0000,
|
|
||||||
15.1900, 3.0000,
|
|
||||||
14.6200, 3.0000,
|
|
||||||
15.6400, 3.0000,
|
|
||||||
25.5000, 1.7500,
|
|
||||||
25.9500, 1.7500,
|
|
||||||
81.7000, .5000,
|
|
||||||
61.6000, .7500,
|
|
||||||
29.8000, 1.7500,
|
|
||||||
29.8100, 1.7500,
|
|
||||||
17.1700, 2.7500,
|
|
||||||
10.3900, 3.7500,
|
|
||||||
28.4000, 1.7500,
|
|
||||||
28.6900, 1.7500,
|
|
||||||
81.3000, .5000,
|
|
||||||
60.9000, .7500,
|
|
||||||
16.6500, 2.7500,
|
|
||||||
10.0500, 3.7500,
|
|
||||||
28.9000, 1.7500,
|
|
||||||
28.9500, 1.7500
|
|
||||||
};
|
|
||||||
|
|
||||||
/* the chwirut1 objective function */
|
|
||||||
private final nistMVRF chwirut1ObjectFunc = new chwirut(chwirut1NIST, 1, 214, 3);
|
|
||||||
|
|
||||||
//http://www.itl.nist.gov/div898/strd/nls/data/LINKS/DATA/Chwirut2.dat
|
|
||||||
public static double[] chwirut2NIST = {
|
|
||||||
92.9000, 0.500,
|
|
||||||
57.1000, 1.000,
|
|
||||||
31.0500, 1.750,
|
|
||||||
11.5875, 3.750,
|
|
||||||
8.0250, 5.750,
|
|
||||||
63.6000, 0.875,
|
|
||||||
21.4000, 2.250,
|
|
||||||
14.2500, 3.250,
|
|
||||||
8.4750, 5.250,
|
|
||||||
63.8000, 0.750,
|
|
||||||
26.8000, 1.750,
|
|
||||||
16.4625, 2.750,
|
|
||||||
7.1250, 4.750,
|
|
||||||
67.3000, 0.625,
|
|
||||||
41.0000, 1.250,
|
|
||||||
21.1500, 2.250,
|
|
||||||
8.1750, 4.250,
|
|
||||||
81.5000, .500,
|
|
||||||
13.1200, 3.000,
|
|
||||||
59.9000, .750,
|
|
||||||
14.6200, 3.000,
|
|
||||||
32.9000, 1.500,
|
|
||||||
5.4400, 6.000,
|
|
||||||
12.5600, 3.000,
|
|
||||||
5.4400, 6.000,
|
|
||||||
32.0000, 1.500,
|
|
||||||
13.9500, 3.000,
|
|
||||||
75.8000, .500,
|
|
||||||
20.0000, 2.000,
|
|
||||||
10.4200, 4.000,
|
|
||||||
59.5000, .750,
|
|
||||||
21.6700, 2.000,
|
|
||||||
8.5500, 5.000,
|
|
||||||
62.0000, .750,
|
|
||||||
20.2000, 2.250,
|
|
||||||
7.7600, 3.750,
|
|
||||||
3.7500, 5.750,
|
|
||||||
11.8100, 3.000,
|
|
||||||
54.7000, .750,
|
|
||||||
23.7000, 2.500,
|
|
||||||
11.5500, 4.000,
|
|
||||||
61.3000, .750,
|
|
||||||
17.7000, 2.500,
|
|
||||||
8.7400, 4.000,
|
|
||||||
59.2000, .750,
|
|
||||||
16.3000, 2.500,
|
|
||||||
8.6200, 4.000,
|
|
||||||
81.0000, .500,
|
|
||||||
4.8700, 6.000,
|
|
||||||
14.6200, 3.000,
|
|
||||||
81.7000, .500,
|
|
||||||
17.1700, 2.750,
|
|
||||||
81.3000, .500,
|
|
||||||
28.9000, 1.750
|
|
||||||
};
|
|
||||||
|
|
||||||
/* the chwirut 2 objective --------------------------------------------------*/
|
|
||||||
private final nistMVRF chwirut2ObjectFunc = new chwirut(chwirut2NIST, 1, 54, 3);
|
|
||||||
|
|
||||||
//http://www.itl.nist.gov/div898/strd/nls/data/LINKS/DATA/Misra1a.dat
|
|
||||||
//y x
|
|
||||||
private static double[] misra1aNIST = {
|
|
||||||
10.07, 77.6,
|
|
||||||
14.73, 114.9,
|
|
||||||
17.94, 141.1,
|
|
||||||
23.93, 190.8,
|
|
||||||
29.61, 239.9,
|
|
||||||
35.18, 289.0,
|
|
||||||
40.02, 332.8,
|
|
||||||
44.82, 378.4,
|
|
||||||
50.76, 434.8,
|
|
||||||
55.05, 477.3,
|
|
||||||
61.01, 536.8,
|
|
||||||
66.40, 593.1,
|
|
||||||
75.47, 689.1,
|
|
||||||
81.78, 760.0
|
|
||||||
};
|
|
||||||
|
|
||||||
/* the misra1a objective function */
|
|
||||||
private final nistMVRF misra1aObjectFunc = new nistMVRF(misra1aNIST, 1, 14, 2) {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected double partialDeriv(double[] point, int idx) {
|
|
||||||
double cy, cx, r, ret = 0.0;
|
|
||||||
int ptr = 0;
|
|
||||||
for (int i = 0; i < this.nobs; i++) {
|
|
||||||
cy = data[ptr++];
|
|
||||||
cx = data[ptr++];
|
|
||||||
r = cy - point[0] * (1.0 - FastMath.exp(-cx * point[1]));
|
|
||||||
if (idx == 0) {
|
|
||||||
ret -= (2.0 * r) * (1.0 - FastMath.exp(-cx * point[1]));
|
|
||||||
} else {
|
|
||||||
ret -= (2.0 * r) * cx * point[0] * FastMath.exp(-cx * point[1]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return (ret);
|
|
||||||
}
|
|
||||||
|
|
||||||
public double value(double[] point) {
|
|
||||||
double ret = 0.0, err;
|
|
||||||
int ptr = 0;
|
|
||||||
for (int i = 0; i < this.nobs; i++) {
|
|
||||||
err = data[ptr++] - point[0] * (1.0 - FastMath.exp(-data[ptr++] * point[1]));
|
|
||||||
ret += err * err;
|
|
||||||
}
|
|
||||||
return (ret);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected double[] getGradient(double[] point) {
|
|
||||||
Arrays.fill(gradient, 0.0);
|
|
||||||
double cy, cx, r;
|
|
||||||
int ptr = 0;
|
|
||||||
for (int i = 0; i < this.nobs; i++) {
|
|
||||||
cy = data[ptr++];
|
|
||||||
cx = data[ptr++];
|
|
||||||
r = cy - point[0] * (1.0 - FastMath.exp(-cx * point[1]));
|
|
||||||
gradient[0] -= (2.0 * r) * (1.0 - FastMath.exp(-cx * point[1]));
|
|
||||||
gradient[1] -= (2.0 * r) * cx * point[0] * FastMath.exp(-cx * point[1]);
|
|
||||||
}
|
|
||||||
return this.gradient;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
private static double[] correctParamMisra1a = {2.3894212918e2, 5.5015643181E-4};
|
|
||||||
private static double[] correctParamChwirut2 = {1.6657666537e-1, 5.1653291286e-3, 1.2150007096e-2};
|
|
||||||
private static double[] correctParamChwirut1 = {1.9027818370e-1, 6.1314004477e-3, 1.0530908399e-2};
|
|
||||||
private static double[] correctParamLanczos = {8.6816414977e-2, 9.5498101505e-01, 8.4400777463E-01, 2.9515951832, 1.5825685901, 4.9863565084};
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void lanczosTest() {
|
|
||||||
//first check to see that the NIST Object function is being replicated correctly
|
|
||||||
double obj = this.lanczosObjectFunc.value(correctParamLanczos);
|
|
||||||
Assert.assertEquals(1.6117193594E-08, obj, 1.0e-8);
|
|
||||||
|
|
||||||
double[] grad = this.lanczosObjectFunc.getGradient(correctParamLanczos);
|
|
||||||
double[] grad2 = new double[6];
|
|
||||||
grad2[0] = this.lanczosObjectFunc.partialDeriv(correctParamLanczos, 0);
|
|
||||||
grad2[1] = this.lanczosObjectFunc.partialDeriv(correctParamLanczos, 1);
|
|
||||||
grad2[2] = this.lanczosObjectFunc.partialDeriv(correctParamLanczos, 2);
|
|
||||||
grad2[3] = this.lanczosObjectFunc.partialDeriv(correctParamLanczos, 3);
|
|
||||||
grad2[4] = this.lanczosObjectFunc.partialDeriv(correctParamLanczos, 4);
|
|
||||||
grad2[5] = this.lanczosObjectFunc.partialDeriv(correctParamLanczos, 5);
|
|
||||||
TestUtils.assertEquals("Grads...", grad, grad2, 1.0e-12);
|
|
||||||
|
|
||||||
double[] n_grad = this.getGradient(lanczosObjectFunc, correctParamLanczos, 1.0e-5);
|
|
||||||
//System.out.println("g = " + grad[0] + " ng = " + n_grad[0]);
|
|
||||||
//System.out.println("g = " + grad[1] + " ng = " + n_grad[1]);
|
|
||||||
if (FastMath.abs(grad[0] - n_grad[0]) > FastMath.max(1.0e-6, 1.0e-6 * (grad[0] + n_grad[0]) / 2.0)) {
|
|
||||||
Assert.fail("Check gradient at 1");
|
|
||||||
}
|
|
||||||
if (FastMath.abs(grad[1] - n_grad[1]) > FastMath.max(1.0e-6, 1.0e-6 * (grad[1] + n_grad[1]) / 2.0)) {
|
|
||||||
Assert.fail("Check gradient at 2");
|
|
||||||
}
|
|
||||||
if (FastMath.abs(grad[2] - n_grad[2]) > FastMath.max(1.0e-6, 1.0e-6 * (grad[2] + n_grad[2]) / 2.0)) {
|
|
||||||
Assert.fail("Check gradient at 2");
|
|
||||||
}
|
|
||||||
if (FastMath.abs(grad[3] - n_grad[3]) > FastMath.max(1.0e-6, 1.0e-6 * (grad[3] + n_grad[3]) / 2.0)) {
|
|
||||||
Assert.fail("Check gradient at 2");
|
|
||||||
}
|
|
||||||
if (FastMath.abs(grad[4] - n_grad[4]) > FastMath.max(1.0e-6, 1.0e-6 * (grad[4] + n_grad[4]) / 2.0)) {
|
|
||||||
Assert.fail("Check gradient at 2");
|
|
||||||
}
|
|
||||||
if (FastMath.abs(grad[5] - n_grad[5]) > FastMath.max(1.0e-6, 1.0e-6 * (grad[5] + n_grad[5]) / 2.0)) {
|
|
||||||
Assert.fail("Check gradient at 2");
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void lanczos_BOBYQA() {
|
|
||||||
double[] bobyqa = run(new BOBYQAOptimizer(10),
|
|
||||||
lanczosObjectFunc, new double[]{1.2,0.3,5.6,5.5,6.5,7.6});
|
|
||||||
TestUtils.assertEquals(correctParamLanczos, bobyqa, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void lanczosTest_cgPolakRibiere() {
|
|
||||||
double[] cgPolakRibiere = run(new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE),
|
|
||||||
lanczosObjectFunc, new double[]{1.2,0.3,5.6,5.5,6.5,7.6});
|
|
||||||
TestUtils.assertEquals(correctParamLanczos, cgPolakRibiere, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void lanczosTest_cgPolakRibiere2() {
|
|
||||||
double[] cgPolakRibiere2 = run(new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE),
|
|
||||||
lanczosObjectFunc, new double[]{0.5,0.7,3.6,4.2,4,6.3});
|
|
||||||
TestUtils.assertEquals(correctParamLanczos, cgPolakRibiere2, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void lanczosTest_cgFletcherReeves() {
|
|
||||||
double[] cgFletcherReeves = run(new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.FLETCHER_REEVES),
|
|
||||||
lanczosObjectFunc, new double[]{1.2,0.3,5.6,5.5,6.5,7.6});
|
|
||||||
TestUtils.assertEquals(correctParamLanczos, cgFletcherReeves, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void lanczosTest_cgFletcherReeves2() {
|
|
||||||
double[] cgFletcherReeves2 = run(new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.FLETCHER_REEVES),
|
|
||||||
lanczosObjectFunc, new double[]{0.5,0.7,3.6,4.2,4,6.3});
|
|
||||||
TestUtils.assertEquals(correctParamLanczos, cgFletcherReeves2, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void lanczosTest_powell() {
|
|
||||||
double[] resPowell = run(new PowellOptimizer(1.0e-8, 1.0e-8), lanczosObjectFunc,
|
|
||||||
new double[]{1.2,0.3,5.6,5.5,6.5,7.6});
|
|
||||||
TestUtils.assertEquals(correctParamLanczos, resPowell, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void lanczosTest_powell2() {
|
|
||||||
double[] resPowell2 = run(new PowellOptimizer(1.0e-8, 1.0e-8), lanczosObjectFunc,
|
|
||||||
new double[]{0.5,0.7,3.6,4.2,4,6.3});
|
|
||||||
TestUtils.assertEquals(correctParamLanczos, resPowell2, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void chwirut1Test() {
|
|
||||||
//first check to see that the NIST Object function is being replicated correctly
|
|
||||||
double obj = this.chwirut1ObjectFunc.value(correctParamChwirut1);
|
|
||||||
Assert.assertEquals(2.3844771393e3, obj, 1.0e-8);
|
|
||||||
|
|
||||||
double[] grad = this.chwirut1ObjectFunc.getGradient(correctParamChwirut1);
|
|
||||||
double[] grad2 = new double[3];
|
|
||||||
grad2[0] = this.chwirut1ObjectFunc.partialDeriv(correctParamChwirut1, 0);
|
|
||||||
grad2[1] = this.chwirut1ObjectFunc.partialDeriv(correctParamChwirut1, 1);
|
|
||||||
grad2[2] = this.chwirut1ObjectFunc.partialDeriv(correctParamChwirut1, 2);
|
|
||||||
TestUtils.assertEquals("Grads...", grad, grad2, 1.0e-12);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void chwirut1_BOBYQA() {
|
|
||||||
double[] bobyqa = run(new BOBYQAOptimizer(5),
|
|
||||||
chwirut1ObjectFunc, new double[]{0.1, 0.01, 0.02});
|
|
||||||
TestUtils.assertEquals(correctParamChwirut1, bobyqa, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void chwirut1Test_cgPolakRibiere() {
|
|
||||||
double[] cgPolakRibiere = run(new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE),
|
|
||||||
chwirut1ObjectFunc, new double[]{0.1, 0.01, 0.02});
|
|
||||||
TestUtils.assertEquals(correctParamChwirut1, cgPolakRibiere, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void chwirut1Test_cgPolakRibiere2() {
|
|
||||||
double[] cgPolakRibiere2 = run(new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE),
|
|
||||||
chwirut1ObjectFunc, new double[]{0.15, 0.008, 0.01});
|
|
||||||
TestUtils.assertEquals(correctParamChwirut1, cgPolakRibiere2, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void chwirut1Test_cgFletcherReeves() {
|
|
||||||
double[] cgFletcherReeves = run(new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.FLETCHER_REEVES),
|
|
||||||
chwirut1ObjectFunc, new double[]{0.1, 0.01, 0.02});
|
|
||||||
TestUtils.assertEquals(correctParamChwirut1, cgFletcherReeves, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void chwirut1Test_cgFletcherReeves2() {
|
|
||||||
double[] cgFletcherReeves2 = run(new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.FLETCHER_REEVES),
|
|
||||||
chwirut1ObjectFunc, new double[]{0.15, 0.008, 0.01});
|
|
||||||
TestUtils.assertEquals(correctParamChwirut1, cgFletcherReeves2, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void chwirut1Test_powell() {
|
|
||||||
double[] resPowell = run(new PowellOptimizer(1.0e-8, 1.0e-8), chwirut1ObjectFunc, new double[]{0.1, 0.01, 0.02});
|
|
||||||
TestUtils.assertEquals(correctParamChwirut1, resPowell, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void chwirut1Test_powell2() {
|
|
||||||
double[] resPowell2 = run(new PowellOptimizer(1.0e-8, 1.0e-8), chwirut1ObjectFunc, new double[]{0.15, 0.08, 0.01});
|
|
||||||
TestUtils.assertEquals(correctParamChwirut1, resPowell2, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void chwirut2Test() {
|
|
||||||
//first check to see that the NIST Object function is being replicated correctly
|
|
||||||
double obj = this.chwirut2ObjectFunc.value(correctParamChwirut2);
|
|
||||||
Assert.assertEquals(5.1304802941e02, obj, 1.0e-8);
|
|
||||||
|
|
||||||
double[] grad = this.chwirut2ObjectFunc.getGradient(correctParamChwirut2);
|
|
||||||
double[] grad2 = new double[3];
|
|
||||||
grad2[0] = this.chwirut2ObjectFunc.partialDeriv(correctParamChwirut2, 0);
|
|
||||||
grad2[1] = this.chwirut2ObjectFunc.partialDeriv(correctParamChwirut2, 1);
|
|
||||||
grad2[2] = this.chwirut2ObjectFunc.partialDeriv(correctParamChwirut2, 2);
|
|
||||||
TestUtils.assertEquals("Grads...", grad, grad2, 1.0e-12);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void chwirut2_BOBYQA() {
|
|
||||||
double[] bobyqa = run(new BOBYQAOptimizer(5),
|
|
||||||
chwirut2ObjectFunc, new double[]{0.1, 0.01, 0.02});
|
|
||||||
TestUtils.assertEquals(correctParamChwirut2, bobyqa, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void chwirut2Test_cgPolakRibiere() {
|
|
||||||
double[] cgPolakRibiere = run(new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE),
|
|
||||||
chwirut2ObjectFunc, new double[]{0.1, 0.01, 0.02});
|
|
||||||
TestUtils.assertEquals(correctParamChwirut2, cgPolakRibiere, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void chwirut2Test_cgPolakRibiere2() {
|
|
||||||
double[] cgPolakRibiere2 = run(new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE),
|
|
||||||
chwirut2ObjectFunc, new double[]{0.15, 0.008, 0.01});
|
|
||||||
TestUtils.assertEquals(correctParamChwirut2, cgPolakRibiere2, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void chwirut2Test_cgFletcherReeves() {
|
|
||||||
double[] cgFletcherReeves = run(new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.FLETCHER_REEVES),
|
|
||||||
chwirut2ObjectFunc, new double[]{0.1, 0.01, 0.02});
|
|
||||||
TestUtils.assertEquals(correctParamChwirut2, cgFletcherReeves, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void chwirut2Test_cgFletcherReeves2() {
|
|
||||||
double[] cgFletcherReeves2 = run(new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.FLETCHER_REEVES),
|
|
||||||
chwirut2ObjectFunc, new double[]{0.15, 0.008, 0.01});
|
|
||||||
TestUtils.assertEquals(correctParamChwirut2, cgFletcherReeves2, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void chwirut2Test_powell() {
|
|
||||||
double[] resPowell = run(new PowellOptimizer(1.0e-8, 1.0e-8), chwirut2ObjectFunc, new double[]{0.1, 0.01, 0.02});
|
|
||||||
TestUtils.assertEquals(correctParamChwirut2, resPowell, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void chwirut2Test_powell2() {
|
|
||||||
double[] resPowell2 = run(new PowellOptimizer(1.0e-8, 1.0e-8), chwirut2ObjectFunc, new double[]{0.15, 0.08, 0.01});
|
|
||||||
TestUtils.assertEquals(correctParamChwirut2, resPowell2, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void misra1aTest() {
|
|
||||||
//first check to see that the NIST Object function is being replicated correctly
|
|
||||||
double obj = this.misra1aObjectFunc.value(correctParamMisra1a);
|
|
||||||
Assert.assertEquals(1.2455138894e-01, obj, 1.0e-8);
|
|
||||||
|
|
||||||
double[] grad = this.misra1aObjectFunc.getGradient(correctParamMisra1a);
|
|
||||||
double[] grad2 = new double[2];
|
|
||||||
grad2[0] = this.misra1aObjectFunc.partialDeriv(correctParamMisra1a, 0);
|
|
||||||
grad2[1] = this.misra1aObjectFunc.partialDeriv(correctParamMisra1a, 1);
|
|
||||||
|
|
||||||
TestUtils.assertEquals("Grads...", grad, grad2, 1.0e-12);
|
|
||||||
|
|
||||||
// double[] n_grad = this.getGradient(misra1aObjectFunc, correctParamMisra1a, 1.0e-5);
|
|
||||||
// System.out.println("g = " + grad[0] + " ng = " + n_grad[0]);
|
|
||||||
// System.out.println("g = " + grad[1] + " ng = " + n_grad[1]);
|
|
||||||
// if( FastMath.abs(grad[0] - n_grad[0] ) > FastMath.max(1.0e-6, 1.0e-6 * (grad[0]+n_grad[0])/2.0) ){
|
|
||||||
// Assert.fail("Check gradient at 1");
|
|
||||||
// }
|
|
||||||
// if( FastMath.abs(grad[1] - n_grad[1] ) > FastMath.max(1.0e-6, 1.0e-6 * (grad[1]+n_grad[1])/2.0) ){
|
|
||||||
// Assert.fail("Check gradient at 2");
|
|
||||||
// }
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void misra1a_BOBYQA() {
|
|
||||||
double[] bobyqa = run(new BOBYQAOptimizer(4),
|
|
||||||
misra1aObjectFunc, new double[]{500.0, 0.0001});
|
|
||||||
TestUtils.assertEquals(correctParamMisra1a, bobyqa, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void misra1aTest_cgPolakRibiere() {
|
|
||||||
double[] cgPolakRibiere = run(new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE),
|
|
||||||
misra1aObjectFunc, new double[]{500.0, 0.0001});
|
|
||||||
TestUtils.assertEquals(correctParamMisra1a, cgPolakRibiere, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void misra1aTest_cgPolakRibiere2() {
|
|
||||||
double[] cgPolakRibiere2 = run(new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE),
|
|
||||||
misra1aObjectFunc, new double[]{250.0, 0.0005});
|
|
||||||
TestUtils.assertEquals(correctParamMisra1a, cgPolakRibiere2, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void misra1aTest_cgFletcherReeves() {
|
|
||||||
double[] cgFletcherReeves = run(new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.FLETCHER_REEVES),
|
|
||||||
misra1aObjectFunc, new double[]{500.0, 0.0001});
|
|
||||||
TestUtils.assertEquals(correctParamMisra1a, cgFletcherReeves, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void misra1aTest_cgFletcherReeves2() {
|
|
||||||
double[] cgFletcherReeves2 = run(new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.FLETCHER_REEVES),
|
|
||||||
misra1aObjectFunc, new double[]{250.0, 0.0005});
|
|
||||||
TestUtils.assertEquals(correctParamMisra1a, cgFletcherReeves2, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void misra1aTest_powell() {
|
|
||||||
double[] resPowell = run(new PowellOptimizer(1.0e-8, 1.0e-8), misra1aObjectFunc, new double[]{500.0, 0.0001});
|
|
||||||
TestUtils.assertEquals(correctParamMisra1a, resPowell, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
//@Test
|
|
||||||
public void misra1aTest_powell2() {
|
|
||||||
double[] resPowell2 = run(new PowellOptimizer(1.0e-8, 1.0e-8), misra1aObjectFunc, new double[]{250.0, 0.0005});
|
|
||||||
TestUtils.assertEquals(correctParamMisra1a, resPowell2, 1.0e-8);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* numerical gradients */
|
|
||||||
private double[] getGradient(nistMVRF func, double[] xo, double eps) {
|
|
||||||
double[] ret = new double[func.getNumberOfParameters()];
|
|
||||||
for (int i = 0; i < ret.length; i++) {
|
|
||||||
final double tmp = xo[i];
|
|
||||||
xo[i] += eps;
|
|
||||||
ret[i] = func.value(xo);
|
|
||||||
xo[i] = tmp - eps;
|
|
||||||
ret[i] -= func.value(xo);
|
|
||||||
ret[i] /= (2.0 * eps);
|
|
||||||
xo[i] = tmp;
|
|
||||||
}
|
|
||||||
return (ret);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* generic test runner */
|
|
||||||
private double[] run(MultivariateOptimizer optim, DifferentiableMultivariateFunction func, double[] start) {
|
|
||||||
return (optim.optimize(1000000, func, GoalType.MINIMIZE, start).getPointRef());
|
|
||||||
}
|
|
||||||
/* generic test runner for AbstractScalarDifferentiableOptimizer */
|
|
||||||
private double[] run(AbstractScalarDifferentiableOptimizer optim, DifferentiableMultivariateFunction func, double[] start) {
|
|
||||||
return (optim.optimize(1000000, func, GoalType.MINIMIZE, start).getPointRef());
|
|
||||||
}
|
|
||||||
|
|
||||||
/* base objective function class for these tests */
|
|
||||||
private abstract static class nistMVRF implements DifferentiableMultivariateFunction {
|
|
||||||
protected final MultivariateFunction[] mrf;
|
|
||||||
protected final MultivariateVectorFunction mvf = new MultivariateVectorFunction() {
|
|
||||||
|
|
||||||
public double[] value(double[] point) throws IllegalArgumentException {
|
|
||||||
return getGradient(point);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
protected double[] gradient;
|
|
||||||
protected double[] data;
|
|
||||||
protected int nvars;
|
|
||||||
protected int nobs;
|
|
||||||
protected int nparams;
|
|
||||||
|
|
||||||
public int getNumberOfParameters() {
|
|
||||||
return nparams;
|
|
||||||
}
|
|
||||||
|
|
||||||
public nistMVRF(double[] data, int nvars, int nobs, int nparams) {
|
|
||||||
if ((nvars + 1) * nobs != data.length) {
|
|
||||||
throw new MathIllegalArgumentException(LocalizedFormats.INVALID_REGRESSION_ARRAY,
|
|
||||||
data.length, nobs, nvars);
|
|
||||||
}
|
|
||||||
this.nobs = nobs;
|
|
||||||
this.nvars = nvars;
|
|
||||||
this.gradient = new double[nparams];
|
|
||||||
this.nparams = nparams;
|
|
||||||
this.data = data;
|
|
||||||
mrf = new MultivariateFunction[nvars];
|
|
||||||
for (int i = 0; i < nvars; i++) {
|
|
||||||
final int idx = i;
|
|
||||||
mrf[i] = new MultivariateFunction() {
|
|
||||||
|
|
||||||
private int myIdx = idx;
|
|
||||||
|
|
||||||
public double value(double[] point) {
|
|
||||||
return partialDeriv(point, myIdx);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public MultivariateVectorFunction gradient() {
|
|
||||||
return mvf;
|
|
||||||
}
|
|
||||||
|
|
||||||
public MultivariateFunction partialDerivative(int k) {
|
|
||||||
return mrf[k];
|
|
||||||
}
|
|
||||||
|
|
||||||
protected abstract double partialDeriv(double[] point, int idx);
|
|
||||||
|
|
||||||
protected abstract double[] getGradient(double[] point);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* since there are multiple chwirut tests create an object */
|
|
||||||
private static class chwirut extends nistMVRF {
|
|
||||||
|
|
||||||
public chwirut(double[] data, int nvars, int nobs, int nparams) {
|
|
||||||
super(data, nvars, nobs, nparams);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected double partialDeriv(double[] point, int idx) {
|
|
||||||
double cy, cx, r, ret = 0.0, d;
|
|
||||||
int ptr = 0;
|
|
||||||
for (int i = 0; i < this.nobs; i++) {
|
|
||||||
cy = data[ptr++];
|
|
||||||
cx = data[ptr++];
|
|
||||||
d = (point[1] + point[2] * cx);
|
|
||||||
r = cy - FastMath.exp(-cx * point[0]) / d;
|
|
||||||
if (idx == 0) {
|
|
||||||
ret -= (2.0 * r * r) * cx;
|
|
||||||
} else if (idx == 1) {
|
|
||||||
ret += (2.0 * r * r) / d;
|
|
||||||
} else {
|
|
||||||
ret += (2.0 * r * r) * cx / d;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return (ret);
|
|
||||||
}
|
|
||||||
|
|
||||||
public double value(double[] point) {
|
|
||||||
double ret = 0.0, err, cx, cy;
|
|
||||||
int ptr = 0;
|
|
||||||
for (int i = 0; i < this.nobs; i++) {
|
|
||||||
cy = data[ptr++];
|
|
||||||
cx = data[ptr++];
|
|
||||||
err = cy - (FastMath.exp(-cx * point[0]) / (point[1] + point[2] * cx));
|
|
||||||
ret += err * err;
|
|
||||||
}
|
|
||||||
return (ret);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected double[] getGradient(double[] point) {
|
|
||||||
Arrays.fill(gradient, 0.0);
|
|
||||||
double cy, cx, r, d;
|
|
||||||
int ptr = 0;
|
|
||||||
for (int i = 0; i < this.nobs; i++) {
|
|
||||||
cy = data[ptr++];
|
|
||||||
cx = data[ptr++];
|
|
||||||
d = (point[1] + point[2] * cx);
|
|
||||||
r = cy - FastMath.exp(-cx * point[0]) / d;
|
|
||||||
gradient[0] -= (2.0 * r * r) * cx;
|
|
||||||
gradient[1] += (2.0 * r * r) / d;
|
|
||||||
gradient[2] += (2.0 * r * r) * cx / d;
|
|
||||||
}
|
|
||||||
return this.gradient;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue