[MATH-1079] Improve performance of SimplexSolver.

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1551735 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Thomas Neidhart 2013-12-17 22:11:45 +00:00
parent 68fa81123a
commit af858a6ca2
3 changed files with 125 additions and 50 deletions

View File

@ -51,6 +51,10 @@ If the output is not quite correct, check for invisible trailing spaces!
</properties>
<body>
<release version="3.3" date="TBD" description="TBD">
<action dev="tn" type="fix" issue="MATH-1079">
Improved performance of "SimplexSolver" in package o.a.c.math3.optim.linear by
directly performing row operations and keeping track of the current basic variables.
</action>
<action dev="tn" type="update" issue="MATH-1080">
The "LinearConstraintSet" will now return the enclosed collection of "LinearConstraint"
objects in the same order as they have been added.

View File

@ -258,7 +258,7 @@ public class SimplexSolver extends LinearOptimizer {
minRatioPositions.add(i);
} else if (cmp < 0) {
minRatio = ratio;
minRatioPositions = new ArrayList<Integer>();
minRatioPositions.clear();
minRatioPositions.add(i);
}
}
@ -290,17 +290,13 @@ public class SimplexSolver extends LinearOptimizer {
Integer minRow = null;
int minIndex = tableau.getWidth();
final int varStart = tableau.getNumObjectiveFunctions();
final int varEnd = tableau.getWidth() - 1;
for (Integer row : minRatioPositions) {
for (int i = varStart; i < varEnd && !row.equals(minRow); i++) {
final Integer basicRow = tableau.getBasicRow(i);
if (basicRow != null && basicRow.equals(row) && i < minIndex) {
minIndex = i;
final int basicVar = tableau.getBasicVariable(row);
if (basicVar < minIndex) {
minIndex = basicVar;
minRow = row;
}
}
}
return minRow;
}
return minRatioPositions.get(0);
@ -325,17 +321,7 @@ public class SimplexSolver extends LinearOptimizer {
throw new UnboundedSolutionException();
}
// set the pivot element to 1
double pivotVal = tableau.getEntry(pivotRow, pivotCol);
tableau.divideRow(pivotRow, pivotVal);
// set the rest of the pivot column to 0
for (int i = 0; i < tableau.getHeight(); i++) {
if (i != pivotRow) {
final double multiplier = tableau.getEntry(i, pivotCol);
tableau.subtractRow(i, pivotRow, multiplier);
}
}
tableau.performRowOperations(pivotCol, pivotRow);
}
/**

View File

@ -21,6 +21,7 @@ import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
@ -29,7 +30,6 @@ import java.util.TreeSet;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.optim.PointValuePair;
@ -82,7 +82,7 @@ class SimplexTableau implements Serializable {
private final List<String> columnLabels = new ArrayList<String>();
/** Simple tableau. */
private transient RealMatrix tableau;
private transient Array2DRowRealMatrix tableau;
/** Number of decision variables. */
private final int numDecisionVariables;
@ -102,6 +102,12 @@ class SimplexTableau implements Serializable {
/** Cut-off value for entries in the tableau. */
private final double cutOff;
/** Maps basic variables to row they are basic in. */
private int[] basicVariables;
/** Maps rows to their corresponding basic variables. */
private int[] basicRows;
/**
* Builds a tableau for a linear problem.
*
@ -162,13 +168,15 @@ class SimplexTableau implements Serializable {
this.epsilon = epsilon;
this.maxUlps = maxUlps;
this.cutOff = cutOff;
this.numDecisionVariables = f.getCoefficients().getDimension() +
(restrictToNonNegative ? 0 : 1);
this.numDecisionVariables = f.getCoefficients().getDimension() + (restrictToNonNegative ? 0 : 1);
this.numSlackVariables = getConstraintTypeCounts(Relationship.LEQ) +
getConstraintTypeCounts(Relationship.GEQ);
this.numArtificialVariables = getConstraintTypeCounts(Relationship.EQ) +
getConstraintTypeCounts(Relationship.GEQ);
this.tableau = createTableau(goalType == GoalType.MAXIMIZE);
// initialize the basic variables for phase 1:
// we know that only slack or artificial variables can be basic
initializeBasicVariables(getSlackVariableOffset());
initializeColumnLabels();
}
@ -200,7 +208,7 @@ class SimplexTableau implements Serializable {
* @param maximize if true, goal is to maximize the objective function
* @return created tableau
*/
protected RealMatrix createTableau(final boolean maximize) {
protected Array2DRowRealMatrix createTableau(final boolean maximize) {
// create a matrix of the correct size
int width = numDecisionVariables + numSlackVariables +
@ -212,13 +220,12 @@ class SimplexTableau implements Serializable {
if (getNumObjectiveFunctions() == 2) {
matrix.setEntry(0, 0, -1);
}
int zIndex = (getNumObjectiveFunctions() == 1) ? 0 : 1;
matrix.setEntry(zIndex, zIndex, maximize ? 1 : -1);
RealVector objectiveCoefficients =
maximize ? f.getCoefficients().mapMultiply(-1) : f.getCoefficients();
RealVector objectiveCoefficients = maximize ? f.getCoefficients().mapMultiply(-1) : f.getCoefficients();
copyArray(objectiveCoefficients.toArray(), matrix.getDataRef()[zIndex]);
matrix.setEntry(zIndex, width - 1,
maximize ? f.getConstantTerm() : -1 * f.getConstantTerm());
matrix.setEntry(zIndex, width - 1, maximize ? f.getConstantTerm() : -1 * f.getConstantTerm());
if (!restrictToNonNegative) {
matrix.setEntry(zIndex, getSlackVariableOffset() - 1,
@ -333,6 +340,44 @@ class SimplexTableau implements Serializable {
* @return the row that the variable is basic in. null if the column is not basic
*/
protected Integer getBasicRow(final int col) {
final int row = basicVariables[col];
return row == -1 ? null : row;
}
/**
* Returns the variable that is basic in this row.
* @param row the index of the row to check
* @return the variable that is basic for this row.
*/
protected int getBasicVariable(final int row) {
return basicRows[row];
}
/**
* Initializes the basic variable / row mapping.
* @param startColumn the column to start
*/
private void initializeBasicVariables(final int startColumn) {
basicVariables = new int[getWidth() - 1];
basicRows = new int[getHeight()];
Arrays.fill(basicVariables, -1);
for (int i = startColumn; i < getWidth() - 1; i++) {
Integer row = findBasicRow(i);
if (row != null) {
basicVariables[i] = row;
basicRows[row] = i;
}
}
}
/**
* Returns the row in which the given column is basic.
* @param col index of the column
* @return the row that the variable is basic in, or {@code null} if the variable is not basic.
*/
private Integer findBasicRow(final int col) {
Integer row = null;
for (int i = 0; i < getHeight(); i++) {
final double entry = getEntry(i, col);
@ -354,12 +399,12 @@ class SimplexTableau implements Serializable {
return;
}
Set<Integer> columnsToDrop = new TreeSet<Integer>();
final Set<Integer> columnsToDrop = new TreeSet<Integer>();
columnsToDrop.add(0);
// positive cost non-artificial variables
for (int i = getNumObjectiveFunctions(); i < getArtificialVariableOffset(); i++) {
final double entry = tableau.getEntry(0, i);
final double entry = getEntry(0, i);
if (Precision.compareTo(entry, 0d, epsilon) > 0) {
columnsToDrop.add(i);
}
@ -373,12 +418,12 @@ class SimplexTableau implements Serializable {
}
}
double[][] matrix = new double[getHeight() - 1][getWidth() - columnsToDrop.size()];
final double[][] matrix = new double[getHeight() - 1][getWidth() - columnsToDrop.size()];
for (int i = 1; i < getHeight(); i++) {
int col = 0;
for (int j = 0; j < getWidth(); j++) {
if (!columnsToDrop.contains(j)) {
matrix[i - 1][col++] = tableau.getEntry(i, j);
matrix[i - 1][col++] = getEntry(i, j);
}
}
}
@ -391,6 +436,8 @@ class SimplexTableau implements Serializable {
this.tableau = new Array2DRowRealMatrix(matrix);
this.numArtificialVariables = 0;
// need to update the basic variable mappings as row/columns have been dropped
initializeBasicVariables(getNumObjectiveFunctions());
}
/**
@ -406,8 +453,10 @@ class SimplexTableau implements Serializable {
* @return whether the model has been solved
*/
boolean isOptimal() {
for (int i = getNumObjectiveFunctions(); i < getWidth() - 1; i++) {
final double entry = tableau.getEntry(0, i);
final double[] objectiveFunctionRow = getRow(0);
final int end = getRhsOffset();
for (int i = getNumObjectiveFunctions(); i < end; i++) {
final double entry = objectiveFunctionRow[i];
if (Precision.compareTo(entry, 0d, epsilon) < 0) {
return false;
}
@ -424,8 +473,8 @@ class SimplexTableau implements Serializable {
Integer negativeVarBasicRow = negativeVarColumn > 0 ? getBasicRow(negativeVarColumn) : null;
double mostNegative = negativeVarBasicRow == null ? 0 : getEntry(negativeVarBasicRow, getRhsOffset());
Set<Integer> basicRows = new HashSet<Integer>();
double[] coefficients = new double[getOriginalNumDecisionVariables()];
final Set<Integer> basicRows = new HashSet<Integer>();
final double[] coefficients = new double[getOriginalNumDecisionVariables()];
for (int i = 0; i < coefficients.length; i++) {
int colIndex = columnLabels.indexOf("x" + i);
if (colIndex < 0) {
@ -452,6 +501,32 @@ class SimplexTableau implements Serializable {
return new PointValuePair(coefficients, f.value(coefficients));
}
/**
* Perform the row operations of the simplex algorithm with the selected
* pivot column and row.
* @param pivotCol the pivot column
* @param pivotRow the pivot row
*/
protected void performRowOperations(int pivotCol, int pivotRow) {
// set the pivot element to 1
final double pivotVal = getEntry(pivotRow, pivotCol);
divideRow(pivotRow, pivotVal);
// set the rest of the pivot column to 0
for (int i = 0; i < getHeight(); i++) {
if (i != pivotRow) {
final double multiplier = getEntry(i, pivotCol);
subtractRow(i, pivotRow, multiplier);
}
}
// update the basic variable mappings
final int previousBasicVariable = getBasicVariable(pivotRow);
basicVariables[previousBasicVariable] = -1;
basicVariables[pivotCol] = pivotRow;
basicRows[pivotRow] = pivotCol;
}
/**
* Divides one row by a given divisor.
* <p>
@ -461,9 +536,10 @@ class SimplexTableau implements Serializable {
* @param dividendRow index of the row
* @param divisor value of the divisor
*/
protected void divideRow(final int dividendRow, final double divisor) {
protected void divideRow(final int dividendRowIndex, final double divisor) {
final double[] dividendRow = getRow(dividendRowIndex);
for (int j = 0; j < getWidth(); j++) {
tableau.setEntry(dividendRow, j, tableau.getEntry(dividendRow, j) / divisor);
dividendRow[j] /= divisor;
}
}
@ -477,15 +553,16 @@ class SimplexTableau implements Serializable {
* @param subtrahendRow row index
* @param multiple multiplication factor
*/
protected void subtractRow(final int minuendRow, final int subtrahendRow,
final double multiple) {
protected void subtractRow(final int minuendRowIndex, final int subtrahendRowIndex, final double multiplier) {
final double[] minuendRow = getRow(minuendRowIndex);
final double[] subtrahendRow = getRow(subtrahendRowIndex);
for (int i = 0; i < getWidth(); i++) {
double result = tableau.getEntry(minuendRow, i) - tableau.getEntry(subtrahendRow, i) * multiple;
double result = minuendRow[i] - subtrahendRow[i] * multiplier;
// cut-off values smaller than the cut-off threshold, otherwise may lead to numerical instabilities
if (FastMath.abs(result) < cutOff) {
if (result != 0.0 && FastMath.abs(result) < cutOff) {
result = 0.0;
}
tableau.setEntry(minuendRow, i, result);
minuendRow[i] = result;
}
}
@ -521,8 +598,7 @@ class SimplexTableau implements Serializable {
* @param column column index
* @param value for the entry
*/
protected final void setEntry(final int row, final int column,
final double value) {
protected final void setEntry(final int row, final int column, final double value) {
tableau.setEntry(row, column, value);
}
@ -588,6 +664,15 @@ class SimplexTableau implements Serializable {
return numArtificialVariables;
}
/**
* Get the row from the tableau.
* @param row the row index
* @return the reference to the underlying row data
*/
protected final double[] getRow(int row) {
return tableau.getDataRef()[row];
}
/**
* Get the tableau data.
* @return tableau data