diff --git a/src/java/org/apache/commons/math/stat/regression/AbstractMultipleLinearRegression.java b/src/java/org/apache/commons/math/stat/regression/AbstractMultipleLinearRegression.java index 49057c13d..fb2825ffc 100644 --- a/src/java/org/apache/commons/math/stat/regression/AbstractMultipleLinearRegression.java +++ b/src/java/org/apache/commons/math/stat/regression/AbstractMultipleLinearRegression.java @@ -139,6 +139,20 @@ public abstract class AbstractMultipleLinearRegression implements public double[][] estimateRegressionParametersVariance() { return calculateBetaVariance().getData(); } + + /** + * {@inheritDoc} + */ + public double[] estimateRegressionParametersStandardErrors() { + double[][] betaVariance = estimateRegressionParametersVariance(); + double sigma = calculateYVariance(); + int length = betaVariance[0].length; + double[] result = new double[length]; + for (int i = 0; i < length; i++) { + result[i] = Math.sqrt(sigma * betaVariance[i][i]); + } + return result; + } /** * {@inheritDoc} diff --git a/src/java/org/apache/commons/math/stat/regression/MultipleLinearRegression.java b/src/java/org/apache/commons/math/stat/regression/MultipleLinearRegression.java index eb0795b50..34b4c6c46 100644 --- a/src/java/org/apache/commons/math/stat/regression/MultipleLinearRegression.java +++ b/src/java/org/apache/commons/math/stat/regression/MultipleLinearRegression.java @@ -59,5 +59,12 @@ public interface MultipleLinearRegression { * @return The double representing the variance of y */ double estimateRegressandVariance(); + + /** + * Returns the standard errors of the regression parameters. + * + * @return standard errors of estimated regression parameters + */ + double[] estimateRegressionParametersStandardErrors(); } diff --git a/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java b/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java index 15431b5b5..4cf271d4e 100644 --- a/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java +++ b/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java @@ -143,6 +143,7 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio * @return The beta variance */ protected RealMatrix calculateBetaVariance() { + //TODO: find a way to use QR decomp to avoid inverting XX' here RealMatrix XTX = X.transpose().multiply(X); return new LUDecompositionImpl(XTX).getSolver().getInverse(); } diff --git a/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java b/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java index 64664c266..97d87deee 100644 --- a/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java +++ b/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java @@ -149,8 +149,17 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs 1E-8); // Check standard errors from NIST - double[][] errors = model.estimateRegressionParametersVariance(); - //TODO: translate this into std error vector and check + double[] errors = model.estimateRegressionParametersStandardErrors(); + TestUtils.assertEquals(new double[] {890420.383607373, + 84.9149257747669, + 0.334910077722432E-01, + 0.488399681651699, + 0.214274163161675, + 0.226073200069370, + 455.478499142212}, errors, 1E-2); // Ugh.. + // Bad accuracy is in intercept std error estimate. Could be due to + // Current impl inverting XX' to get standard errors. + } /** @@ -245,7 +254,15 @@ public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs 5.4326230830188482,-7.2375578629692230,2.1671550814448222, 15.0147574652763112,4.8625103516321015,-7.1597256413907706, -0.4515205619767598,-10.2916870903837587,-15.7812984571900063}, - 1E-12); + 1E-12); + + // Check standard errors from R + double[] errors = model.estimateRegressionParametersStandardErrors(); + TestUtils.assertEquals(new double[] {6.94881329475087, + 0.07360008972340, + 0.27410957467466, + 0.19454551679325, + 0.03726654773803}, errors, 1E-10); } /**