[MATH-1079] Improve performance of SimplexSolver.
git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1551735 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
68fa81123a
commit
af858a6ca2
|
@ -51,6 +51,10 @@ If the output is not quite correct, check for invisible trailing spaces!
|
|||
</properties>
|
||||
<body>
|
||||
<release version="3.3" date="TBD" description="TBD">
|
||||
<action dev="tn" type="fix" issue="MATH-1079">
|
||||
Improved performance of "SimplexSolver" in package o.a.c.math3.optim.linear by
|
||||
directly performing row operations and keeping track of the current basic variables.
|
||||
</action>
|
||||
<action dev="tn" type="update" issue="MATH-1080">
|
||||
The "LinearConstraintSet" will now return the enclosed collection of "LinearConstraint"
|
||||
objects in the same order as they have been added.
|
||||
|
|
|
@ -258,7 +258,7 @@ public class SimplexSolver extends LinearOptimizer {
|
|||
minRatioPositions.add(i);
|
||||
} else if (cmp < 0) {
|
||||
minRatio = ratio;
|
||||
minRatioPositions = new ArrayList<Integer>();
|
||||
minRatioPositions.clear();
|
||||
minRatioPositions.add(i);
|
||||
}
|
||||
}
|
||||
|
@ -290,17 +290,13 @@ public class SimplexSolver extends LinearOptimizer {
|
|||
|
||||
Integer minRow = null;
|
||||
int minIndex = tableau.getWidth();
|
||||
final int varStart = tableau.getNumObjectiveFunctions();
|
||||
final int varEnd = tableau.getWidth() - 1;
|
||||
for (Integer row : minRatioPositions) {
|
||||
for (int i = varStart; i < varEnd && !row.equals(minRow); i++) {
|
||||
final Integer basicRow = tableau.getBasicRow(i);
|
||||
if (basicRow != null && basicRow.equals(row) && i < minIndex) {
|
||||
minIndex = i;
|
||||
final int basicVar = tableau.getBasicVariable(row);
|
||||
if (basicVar < minIndex) {
|
||||
minIndex = basicVar;
|
||||
minRow = row;
|
||||
}
|
||||
}
|
||||
}
|
||||
return minRow;
|
||||
}
|
||||
return minRatioPositions.get(0);
|
||||
|
@ -325,17 +321,7 @@ public class SimplexSolver extends LinearOptimizer {
|
|||
throw new UnboundedSolutionException();
|
||||
}
|
||||
|
||||
// set the pivot element to 1
|
||||
double pivotVal = tableau.getEntry(pivotRow, pivotCol);
|
||||
tableau.divideRow(pivotRow, pivotVal);
|
||||
|
||||
// set the rest of the pivot column to 0
|
||||
for (int i = 0; i < tableau.getHeight(); i++) {
|
||||
if (i != pivotRow) {
|
||||
final double multiplier = tableau.getEntry(i, pivotCol);
|
||||
tableau.subtractRow(i, pivotRow, multiplier);
|
||||
}
|
||||
}
|
||||
tableau.performRowOperations(pivotCol, pivotRow);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -21,6 +21,7 @@ import java.io.ObjectInputStream;
|
|||
import java.io.ObjectOutputStream;
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
|
@ -29,7 +30,6 @@ import java.util.TreeSet;
|
|||
|
||||
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
|
||||
import org.apache.commons.math3.linear.MatrixUtils;
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
import org.apache.commons.math3.linear.RealVector;
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
|
||||
import org.apache.commons.math3.optim.PointValuePair;
|
||||
|
@ -82,7 +82,7 @@ class SimplexTableau implements Serializable {
|
|||
private final List<String> columnLabels = new ArrayList<String>();
|
||||
|
||||
/** Simple tableau. */
|
||||
private transient RealMatrix tableau;
|
||||
private transient Array2DRowRealMatrix tableau;
|
||||
|
||||
/** Number of decision variables. */
|
||||
private final int numDecisionVariables;
|
||||
|
@ -102,6 +102,12 @@ class SimplexTableau implements Serializable {
|
|||
/** Cut-off value for entries in the tableau. */
|
||||
private final double cutOff;
|
||||
|
||||
/** Maps basic variables to row they are basic in. */
|
||||
private int[] basicVariables;
|
||||
|
||||
/** Maps rows to their corresponding basic variables. */
|
||||
private int[] basicRows;
|
||||
|
||||
/**
|
||||
* Builds a tableau for a linear problem.
|
||||
*
|
||||
|
@ -162,13 +168,15 @@ class SimplexTableau implements Serializable {
|
|||
this.epsilon = epsilon;
|
||||
this.maxUlps = maxUlps;
|
||||
this.cutOff = cutOff;
|
||||
this.numDecisionVariables = f.getCoefficients().getDimension() +
|
||||
(restrictToNonNegative ? 0 : 1);
|
||||
this.numDecisionVariables = f.getCoefficients().getDimension() + (restrictToNonNegative ? 0 : 1);
|
||||
this.numSlackVariables = getConstraintTypeCounts(Relationship.LEQ) +
|
||||
getConstraintTypeCounts(Relationship.GEQ);
|
||||
this.numArtificialVariables = getConstraintTypeCounts(Relationship.EQ) +
|
||||
getConstraintTypeCounts(Relationship.GEQ);
|
||||
this.tableau = createTableau(goalType == GoalType.MAXIMIZE);
|
||||
// initialize the basic variables for phase 1:
|
||||
// we know that only slack or artificial variables can be basic
|
||||
initializeBasicVariables(getSlackVariableOffset());
|
||||
initializeColumnLabels();
|
||||
}
|
||||
|
||||
|
@ -200,7 +208,7 @@ class SimplexTableau implements Serializable {
|
|||
* @param maximize if true, goal is to maximize the objective function
|
||||
* @return created tableau
|
||||
*/
|
||||
protected RealMatrix createTableau(final boolean maximize) {
|
||||
protected Array2DRowRealMatrix createTableau(final boolean maximize) {
|
||||
|
||||
// create a matrix of the correct size
|
||||
int width = numDecisionVariables + numSlackVariables +
|
||||
|
@ -212,13 +220,12 @@ class SimplexTableau implements Serializable {
|
|||
if (getNumObjectiveFunctions() == 2) {
|
||||
matrix.setEntry(0, 0, -1);
|
||||
}
|
||||
|
||||
int zIndex = (getNumObjectiveFunctions() == 1) ? 0 : 1;
|
||||
matrix.setEntry(zIndex, zIndex, maximize ? 1 : -1);
|
||||
RealVector objectiveCoefficients =
|
||||
maximize ? f.getCoefficients().mapMultiply(-1) : f.getCoefficients();
|
||||
RealVector objectiveCoefficients = maximize ? f.getCoefficients().mapMultiply(-1) : f.getCoefficients();
|
||||
copyArray(objectiveCoefficients.toArray(), matrix.getDataRef()[zIndex]);
|
||||
matrix.setEntry(zIndex, width - 1,
|
||||
maximize ? f.getConstantTerm() : -1 * f.getConstantTerm());
|
||||
matrix.setEntry(zIndex, width - 1, maximize ? f.getConstantTerm() : -1 * f.getConstantTerm());
|
||||
|
||||
if (!restrictToNonNegative) {
|
||||
matrix.setEntry(zIndex, getSlackVariableOffset() - 1,
|
||||
|
@ -333,6 +340,44 @@ class SimplexTableau implements Serializable {
|
|||
* @return the row that the variable is basic in. null if the column is not basic
|
||||
*/
|
||||
protected Integer getBasicRow(final int col) {
|
||||
final int row = basicVariables[col];
|
||||
return row == -1 ? null : row;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the variable that is basic in this row.
|
||||
* @param row the index of the row to check
|
||||
* @return the variable that is basic for this row.
|
||||
*/
|
||||
protected int getBasicVariable(final int row) {
|
||||
return basicRows[row];
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes the basic variable / row mapping.
|
||||
* @param startColumn the column to start
|
||||
*/
|
||||
private void initializeBasicVariables(final int startColumn) {
|
||||
basicVariables = new int[getWidth() - 1];
|
||||
basicRows = new int[getHeight()];
|
||||
|
||||
Arrays.fill(basicVariables, -1);
|
||||
|
||||
for (int i = startColumn; i < getWidth() - 1; i++) {
|
||||
Integer row = findBasicRow(i);
|
||||
if (row != null) {
|
||||
basicVariables[i] = row;
|
||||
basicRows[row] = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the row in which the given column is basic.
|
||||
* @param col index of the column
|
||||
* @return the row that the variable is basic in, or {@code null} if the variable is not basic.
|
||||
*/
|
||||
private Integer findBasicRow(final int col) {
|
||||
Integer row = null;
|
||||
for (int i = 0; i < getHeight(); i++) {
|
||||
final double entry = getEntry(i, col);
|
||||
|
@ -354,12 +399,12 @@ class SimplexTableau implements Serializable {
|
|||
return;
|
||||
}
|
||||
|
||||
Set<Integer> columnsToDrop = new TreeSet<Integer>();
|
||||
final Set<Integer> columnsToDrop = new TreeSet<Integer>();
|
||||
columnsToDrop.add(0);
|
||||
|
||||
// positive cost non-artificial variables
|
||||
for (int i = getNumObjectiveFunctions(); i < getArtificialVariableOffset(); i++) {
|
||||
final double entry = tableau.getEntry(0, i);
|
||||
final double entry = getEntry(0, i);
|
||||
if (Precision.compareTo(entry, 0d, epsilon) > 0) {
|
||||
columnsToDrop.add(i);
|
||||
}
|
||||
|
@ -373,12 +418,12 @@ class SimplexTableau implements Serializable {
|
|||
}
|
||||
}
|
||||
|
||||
double[][] matrix = new double[getHeight() - 1][getWidth() - columnsToDrop.size()];
|
||||
final double[][] matrix = new double[getHeight() - 1][getWidth() - columnsToDrop.size()];
|
||||
for (int i = 1; i < getHeight(); i++) {
|
||||
int col = 0;
|
||||
for (int j = 0; j < getWidth(); j++) {
|
||||
if (!columnsToDrop.contains(j)) {
|
||||
matrix[i - 1][col++] = tableau.getEntry(i, j);
|
||||
matrix[i - 1][col++] = getEntry(i, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -391,6 +436,8 @@ class SimplexTableau implements Serializable {
|
|||
|
||||
this.tableau = new Array2DRowRealMatrix(matrix);
|
||||
this.numArtificialVariables = 0;
|
||||
// need to update the basic variable mappings as row/columns have been dropped
|
||||
initializeBasicVariables(getNumObjectiveFunctions());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -406,8 +453,10 @@ class SimplexTableau implements Serializable {
|
|||
* @return whether the model has been solved
|
||||
*/
|
||||
boolean isOptimal() {
|
||||
for (int i = getNumObjectiveFunctions(); i < getWidth() - 1; i++) {
|
||||
final double entry = tableau.getEntry(0, i);
|
||||
final double[] objectiveFunctionRow = getRow(0);
|
||||
final int end = getRhsOffset();
|
||||
for (int i = getNumObjectiveFunctions(); i < end; i++) {
|
||||
final double entry = objectiveFunctionRow[i];
|
||||
if (Precision.compareTo(entry, 0d, epsilon) < 0) {
|
||||
return false;
|
||||
}
|
||||
|
@ -424,8 +473,8 @@ class SimplexTableau implements Serializable {
|
|||
Integer negativeVarBasicRow = negativeVarColumn > 0 ? getBasicRow(negativeVarColumn) : null;
|
||||
double mostNegative = negativeVarBasicRow == null ? 0 : getEntry(negativeVarBasicRow, getRhsOffset());
|
||||
|
||||
Set<Integer> basicRows = new HashSet<Integer>();
|
||||
double[] coefficients = new double[getOriginalNumDecisionVariables()];
|
||||
final Set<Integer> basicRows = new HashSet<Integer>();
|
||||
final double[] coefficients = new double[getOriginalNumDecisionVariables()];
|
||||
for (int i = 0; i < coefficients.length; i++) {
|
||||
int colIndex = columnLabels.indexOf("x" + i);
|
||||
if (colIndex < 0) {
|
||||
|
@ -452,6 +501,32 @@ class SimplexTableau implements Serializable {
|
|||
return new PointValuePair(coefficients, f.value(coefficients));
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform the row operations of the simplex algorithm with the selected
|
||||
* pivot column and row.
|
||||
* @param pivotCol the pivot column
|
||||
* @param pivotRow the pivot row
|
||||
*/
|
||||
protected void performRowOperations(int pivotCol, int pivotRow) {
|
||||
// set the pivot element to 1
|
||||
final double pivotVal = getEntry(pivotRow, pivotCol);
|
||||
divideRow(pivotRow, pivotVal);
|
||||
|
||||
// set the rest of the pivot column to 0
|
||||
for (int i = 0; i < getHeight(); i++) {
|
||||
if (i != pivotRow) {
|
||||
final double multiplier = getEntry(i, pivotCol);
|
||||
subtractRow(i, pivotRow, multiplier);
|
||||
}
|
||||
}
|
||||
|
||||
// update the basic variable mappings
|
||||
final int previousBasicVariable = getBasicVariable(pivotRow);
|
||||
basicVariables[previousBasicVariable] = -1;
|
||||
basicVariables[pivotCol] = pivotRow;
|
||||
basicRows[pivotRow] = pivotCol;
|
||||
}
|
||||
|
||||
/**
|
||||
* Divides one row by a given divisor.
|
||||
* <p>
|
||||
|
@ -461,9 +536,10 @@ class SimplexTableau implements Serializable {
|
|||
* @param dividendRow index of the row
|
||||
* @param divisor value of the divisor
|
||||
*/
|
||||
protected void divideRow(final int dividendRow, final double divisor) {
|
||||
protected void divideRow(final int dividendRowIndex, final double divisor) {
|
||||
final double[] dividendRow = getRow(dividendRowIndex);
|
||||
for (int j = 0; j < getWidth(); j++) {
|
||||
tableau.setEntry(dividendRow, j, tableau.getEntry(dividendRow, j) / divisor);
|
||||
dividendRow[j] /= divisor;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -477,15 +553,16 @@ class SimplexTableau implements Serializable {
|
|||
* @param subtrahendRow row index
|
||||
* @param multiple multiplication factor
|
||||
*/
|
||||
protected void subtractRow(final int minuendRow, final int subtrahendRow,
|
||||
final double multiple) {
|
||||
protected void subtractRow(final int minuendRowIndex, final int subtrahendRowIndex, final double multiplier) {
|
||||
final double[] minuendRow = getRow(minuendRowIndex);
|
||||
final double[] subtrahendRow = getRow(subtrahendRowIndex);
|
||||
for (int i = 0; i < getWidth(); i++) {
|
||||
double result = tableau.getEntry(minuendRow, i) - tableau.getEntry(subtrahendRow, i) * multiple;
|
||||
double result = minuendRow[i] - subtrahendRow[i] * multiplier;
|
||||
// cut-off values smaller than the cut-off threshold, otherwise may lead to numerical instabilities
|
||||
if (FastMath.abs(result) < cutOff) {
|
||||
if (result != 0.0 && FastMath.abs(result) < cutOff) {
|
||||
result = 0.0;
|
||||
}
|
||||
tableau.setEntry(minuendRow, i, result);
|
||||
minuendRow[i] = result;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -521,8 +598,7 @@ class SimplexTableau implements Serializable {
|
|||
* @param column column index
|
||||
* @param value for the entry
|
||||
*/
|
||||
protected final void setEntry(final int row, final int column,
|
||||
final double value) {
|
||||
protected final void setEntry(final int row, final int column, final double value) {
|
||||
tableau.setEntry(row, column, value);
|
||||
}
|
||||
|
||||
|
@ -588,6 +664,15 @@ class SimplexTableau implements Serializable {
|
|||
return numArtificialVariables;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the row from the tableau.
|
||||
* @param row the row index
|
||||
* @return the reference to the underlying row data
|
||||
*/
|
||||
protected final double[] getRow(int row) {
|
||||
return tableau.getDataRef()[row];
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the tableau data.
|
||||
* @return tableau data
|
||||
|
|
Loading…
Reference in New Issue