Bug 62836: Implementation of Excel TREND function

git-svn-id: https://svn.apache.org/repos/asf/poi/trunk@1845586 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Yegor Kozlov 2018-11-02 13:34:28 +00:00
parent 341f456ef7
commit 9aabade3f0
5 changed files with 410 additions and 1 deletions

View File

@ -115,7 +115,7 @@ public final class FunctionEval {
// 47: DVAR
retval[48] = TextFunction.TEXT;
// 49: LINEST
// 50: TREND
retval[50] = new Trend();
// 51: LOGEST
// 52: GROWTH

View File

@ -0,0 +1,377 @@
/* ====================================================================
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==================================================================== */
/*
* Notes:
* Duplicate x values don't work most of the time because of the way the
* math library handles multiple regression.
* The math library currently fails when the number of x variables is >=
* the sample size (see https://github.com/Hipparchus-Math/hipparchus/issues/13).
*/
package org.apache.poi.ss.formula.functions;
import org.apache.poi.ss.formula.CacheAreaEval;
import org.apache.poi.ss.formula.eval.AreaEval;
import org.apache.poi.ss.formula.eval.BoolEval;
import org.apache.poi.ss.formula.eval.ErrorEval;
import org.apache.poi.ss.formula.eval.EvaluationException;
import org.apache.poi.ss.formula.eval.MissingArgEval;
import org.apache.poi.ss.formula.eval.NotImplementedException;
import org.apache.poi.ss.formula.eval.NumberEval;
import org.apache.poi.ss.formula.eval.NumericValueEval;
import org.apache.poi.ss.formula.eval.RefEval;
import org.apache.poi.ss.formula.eval.ValueEval;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
import java.util.Arrays;
/**
* Implementation for the Excel function TREND<p>
*
* Syntax:<br>
* TREND(known_y's, known_x's, new_x's, constant)
* <table border="0" cellpadding="1" cellspacing="0" summary="Parameter descriptions">
* <tr><th>known_y's, known_x's, new_x's</th><td>typically area references, possibly cell references or scalar values</td></tr>
* <tr><th>constant</th><td><b>TRUE</b> or <b>FALSE</b>:
* determines whether the regression line should include an intercept term</td></tr>
* </table><br>
* If <b>known_x's</b> is not given, it is assumed to be the default array {1, 2, 3, ...}
* of the same size as <b>known_y's</b>.<br>
* If <b>new_x's</b> is not given, it is assumed to be the same as <b>known_x's</b><br>
* If <b>constant</b> is omitted, it is assumed to be <b>TRUE</b>
* </p>
*/
public final class Trend implements Function {
MatrixFunction.MutableValueCollector collector = new MatrixFunction.MutableValueCollector(false, false);
private static final class TrendResults {
public double[] vals;
public int resultWidth;
public int resultHeight;
public TrendResults(double[] vals, int resultWidth, int resultHeight) {
this.vals = vals;
this.resultWidth = resultWidth;
this.resultHeight = resultHeight;
}
}
public ValueEval evaluate(ValueEval[] args, int srcRowIndex, int srcColumnIndex) {
if (args.length < 1 || args.length > 4) {
return ErrorEval.VALUE_INVALID;
}
try {
TrendResults tr = getNewY(args);
ValueEval[] vals = new ValueEval[tr.vals.length];
for (int i = 0; i < tr.vals.length; i++) {
vals[i] = new NumberEval(tr.vals[i]);
}
if (tr.vals.length == 1) {
return vals[0];
}
return new CacheAreaEval(srcRowIndex, srcColumnIndex, srcRowIndex + tr.resultHeight - 1, srcColumnIndex + tr.resultWidth - 1, vals);
} catch (EvaluationException e) {
return e.getErrorEval();
}
}
private static double[][] evalToArray(ValueEval arg) throws EvaluationException {
double[][] ar;
ValueEval eval;
if (arg instanceof MissingArgEval) {
return new double[0][0];
}
if (arg instanceof RefEval) {
RefEval re = (RefEval) arg;
if (re.getNumberOfSheets() > 1) {
throw new EvaluationException(ErrorEval.VALUE_INVALID);
}
eval = re.getInnerValueEval(re.getFirstSheetIndex());
} else {
eval = arg;
}
if (eval == null) {
throw new RuntimeException("Parameter may not be null.");
}
if (eval instanceof AreaEval) {
AreaEval ae = (AreaEval) eval;
int w = ae.getWidth();
int h = ae.getHeight();
ar = new double[h][w];
for (int i = 0; i < h; i++) {
for (int j = 0; j < w; j++) {
ValueEval ve = ae.getRelativeValue(i, j);
if (!(ve instanceof NumericValueEval)) {
throw new EvaluationException(ErrorEval.VALUE_INVALID);
}
ar[i][j] = ((NumericValueEval)ve).getNumberValue();
}
}
} else if (eval instanceof NumericValueEval) {
ar = new double[1][1];
ar[0][0] = ((NumericValueEval)eval).getNumberValue();
} else {
throw new EvaluationException(ErrorEval.VALUE_INVALID);
}
return ar;
}
private static double[][] getDefaultArrayOneD(int w) {
double[][] array = new double[w][1];
for (int i = 0; i < w; i++) {
array[i][0] = i + 1;
}
return array;
}
private static double[] flattenArray(double[][] twoD) {
if (twoD.length < 1) {
return new double[0];
}
double[] oneD = new double[twoD.length * twoD[0].length];
for (int i = 0; i < twoD.length; i++) {
for (int j = 0; j < twoD[0].length; j++) {
oneD[i * twoD[0].length + j] = twoD[i][j];
}
}
return oneD;
}
private static double[][] flattenArrayToRow(double[][] twoD) {
if (twoD.length < 1) {
return new double[0][0];
}
double[][] oneD = new double[twoD.length * twoD[0].length][1];
for (int i = 0; i < twoD.length; i++) {
for (int j = 0; j < twoD[0].length; j++) {
oneD[i * twoD[0].length + j][0] = twoD[i][j];
}
}
return oneD;
}
private static double[][] switchRowsColumns(double[][] array) {
double[][] newArray = new double[array[0].length][array.length];
for (int i = 0; i < array.length; i++) {
for (int j = 0; j < array[0].length; j++) {
newArray[j][i] = array[i][j];
}
}
return newArray;
}
/**
* Check if all columns in a matrix contain the same values.
* Return true if the number of distinct values in each column is 1.
*
* @param matrix column-oriented matrix. A Row matrix should be transposed to column .
* @return true if all columns contain the same value
*/
private static boolean isAllColumnsSame(double[][] matrix){
if(matrix.length == 0) return false;
boolean[] cols = new boolean[matrix[0].length];
for (int j = 0; j < matrix[0].length; j++) {
double prev = Double.NaN;
for (int i = 0; i < matrix.length; i++) {
double v = matrix[i][j];
if(i > 0 && v != prev) {
cols[j] = true;
break;
}
prev = v;
}
}
boolean allEquals = true;
for (boolean x : cols) {
if(x) {
allEquals = false;
break;
}
};
return allEquals;
}
private static TrendResults getNewY(ValueEval[] args) throws EvaluationException {
double[][] xOrig;
double[][] x;
double[][] yOrig;
double[] y;
double[][] newXOrig;
double[][] newX;
double[][] resultSize;
boolean passThroughOrigin = false;
switch (args.length) {
case 1:
yOrig = evalToArray(args[0]);
xOrig = new double[0][0];
newXOrig = new double[0][0];
break;
case 2:
yOrig = evalToArray(args[0]);
xOrig = evalToArray(args[1]);
newXOrig = new double[0][0];
break;
case 3:
yOrig = evalToArray(args[0]);
xOrig = evalToArray(args[1]);
newXOrig = evalToArray(args[2]);
break;
case 4:
yOrig = evalToArray(args[0]);
xOrig = evalToArray(args[1]);
newXOrig = evalToArray(args[2]);
if (!(args[3] instanceof BoolEval)) {
throw new EvaluationException(ErrorEval.VALUE_INVALID);
}
// The argument in Excel is false when it *should* pass through the origin.
passThroughOrigin = !((BoolEval)args[3]).getBooleanValue();
break;
default:
throw new EvaluationException(ErrorEval.VALUE_INVALID);
}
if (yOrig.length < 1) {
throw new EvaluationException(ErrorEval.VALUE_INVALID);
}
y = flattenArray(yOrig);
newX = newXOrig;
if (newXOrig.length > 0) {
resultSize = newXOrig;
} else {
resultSize = new double[1][1];
}
if (y.length == 1) {
/* See comment at top of file
if (xOrig.length > 0 && !(xOrig.length == 1 || xOrig[0].length == 1)) {
throw new EvaluationException(ErrorEval.REF_INVALID);
} else if (xOrig.length < 1) {
x = new double[1][1];
x[0][0] = 1;
} else {
x = new double[1][];
x[0] = flattenArray(xOrig);
if (newXOrig.length < 1) {
resultSize = xOrig;
}
}*/
throw new NotImplementedException("Sample size too small");
} else if (yOrig.length == 1 || yOrig[0].length == 1) {
if (xOrig.length < 1) {
x = getDefaultArrayOneD(y.length);
if (newXOrig.length < 1) {
resultSize = yOrig;
}
} else {
x = xOrig;
if (xOrig[0].length > 1 && yOrig.length == 1) {
x = switchRowsColumns(x);
}
if (newXOrig.length < 1) {
resultSize = xOrig;
}
}
if (newXOrig.length > 0 && (x.length == 1 || x[0].length == 1)) {
newX = flattenArrayToRow(newXOrig);
}
} else {
if (xOrig.length < 1) {
x = getDefaultArrayOneD(y.length);
if (newXOrig.length < 1) {
resultSize = yOrig;
}
} else {
x = flattenArrayToRow(xOrig);
if (newXOrig.length < 1) {
resultSize = xOrig;
}
}
if (newXOrig.length > 0) {
newX = flattenArrayToRow(newXOrig);
}
if (y.length != x.length || yOrig.length != xOrig.length) {
throw new EvaluationException(ErrorEval.REF_INVALID);
}
}
if (newXOrig.length < 1) {
newX = x;
} else if (newXOrig.length == 1 && newXOrig[0].length > 1 && xOrig.length > 1 && xOrig[0].length == 1) {
newX = switchRowsColumns(newXOrig);
}
if (newX[0].length != x[0].length) {
throw new EvaluationException(ErrorEval.REF_INVALID);
}
if (x[0].length >= x.length) {
/* See comment at top of file */
throw new NotImplementedException("Sample size too small");
}
int resultHeight = resultSize.length;
int resultWidth = resultSize[0].length;
if(isAllColumnsSame(x)){
double[] result = new double[newX.length];
double avg = Arrays.stream(y).average().orElse(0);
for(int i = 0; i < result.length; i++) result[i] = avg;
return new TrendResults(result, resultWidth, resultHeight);
}
OLSMultipleLinearRegression reg = new OLSMultipleLinearRegression();
if (passThroughOrigin) {
reg.setNoIntercept(true);
}
try {
reg.newSampleData(y, x);
} catch (IllegalArgumentException e) {
throw new EvaluationException(ErrorEval.REF_INVALID);
}
double[] par;
try {
par = reg.estimateRegressionParameters();
} catch (SingularMatrixException e) {
throw new NotImplementedException("Singular matrix in input");
}
double[] result = new double[newX.length];
for (int i = 0; i < newX.length; i++) {
result[i] = 0;
if (passThroughOrigin) {
for (int j = 0; j < par.length; j++) {
result[i] += par[j] * newX[i][j];
}
} else {
result[i] = par[0];
for (int j = 1; j < par.length; j++) {
result[i] += par[j] * newX[i][j - 1];
}
}
}
return new TrendResults(result, resultWidth, resultHeight);
}
}

View File

@ -41,6 +41,7 @@ import org.junit.runners.Suite;
TestQuotientFunctionsFromSpreadsheet.class,
TestReptFunctionsFromSpreadsheet.class,
TestRomanFunctionsFromSpreadsheet.class,
TestTrendFunctionsFromSpreadsheet.class,
TestWeekNumFunctionsFromSpreadsheet.class,
TestWeekNumFunctionsFromSpreadsheet2013.class
})

View File

@ -0,0 +1,31 @@
/* ====================================================================
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==================================================================== */
package org.apache.poi.ss.formula.functions;
import java.util.Collection;
import org.junit.runners.Parameterized.Parameters;
/**
* Tests TREND() as loaded from a test data spreadsheet.
*/
public class TestTrendFunctionsFromSpreadsheet extends BaseTestFunctionsFromSpreadsheet {
@Parameters(name="{0}")
public static Collection<Object[]> data() throws Exception {
return data(TestTrendFunctionsFromSpreadsheet.class, "Trend.xls");
}
}

Binary file not shown.