From 7f56cc8556f04afd90c3b31b9207e6e0a85eff45 Mon Sep 17 00:00:00 2001 From: Luc Maisonobe Date: Thu, 10 Dec 2015 19:42:36 +0100 Subject: [PATCH] Fixed additional equations mapping. --- .../math3/ode/FieldEquationsMapper.java | 33 +- .../commons/math3/ode/FieldExpandableODE.java | 12 +- .../commons/math3/ode/FieldODEState.java | 16 +- .../math3/ode/FieldODEStateAndDerivative.java | 3 +- .../ode/nonstiff/FieldExpandableODETest.java | 344 ++++++++++++++++++ 5 files changed, 361 insertions(+), 47 deletions(-) create mode 100644 src/test/java/org/apache/commons/math3/ode/nonstiff/FieldExpandableODETest.java diff --git a/src/main/java/org/apache/commons/math3/ode/FieldEquationsMapper.java b/src/main/java/org/apache/commons/math3/ode/FieldEquationsMapper.java index 0dd445a90..fd5f0870c 100644 --- a/src/main/java/org/apache/commons/math3/ode/FieldEquationsMapper.java +++ b/src/main/java/org/apache/commons/math3/ode/FieldEquationsMapper.java @@ -57,7 +57,7 @@ public class FieldEquationsMapper> implements Seri if (mapper == null) { start[0] = 0; } else { - System.arraycopy(mapper.start, 0, start, 0, index); + System.arraycopy(mapper.start, 0, start, 0, index + 1); } start[index + 1] = start[index] + dimension; } @@ -88,7 +88,7 @@ public class FieldEquationsMapper> implements Seri int index = 0; insertEquationData(index, state.getState(), y); while (++index < getNumberOfEquations()) { - insertEquationData(index, state.getSecondaryState(index - 1), y); + insertEquationData(index, state.getSecondaryState(index), y); } return y; } @@ -102,38 +102,11 @@ public class FieldEquationsMapper> implements Seri int index = 0; insertEquationData(index, state.getDerivative(), yDot); while (++index < getNumberOfEquations()) { - insertEquationData(index, state.getSecondaryDerivative(index - 1), yDot); + insertEquationData(index, state.getSecondaryDerivative(index), yDot); } return yDot; } - /** Map a flat array to a state. - * @param t time - * @param y array to map, including primary and secondary components - * @return mapped state - * @exception DimensionMismatchException if array does not match total dimension - */ - public FieldODEState mapState(final T t, final T[] y) - throws DimensionMismatchException { - - if (y.length != getTotalDimension()) { - throw new DimensionMismatchException(y.length, getTotalDimension()); - } - - final int n = getNumberOfEquations(); - int index = 0; - final T[] state = extractEquationData(index, y); - if (n < 2) { - return new FieldODEState(t, state); - } else { - final T[][] secondaryState = MathArrays.buildArray(t.getField(), n - 1, -1); - while (++index < n) { - secondaryState[index - 1] = extractEquationData(index, y); - } - return new FieldODEState(t, state, secondaryState); - } - } - /** Map flat arrays to a state and derivative. * @param t time * @param y state array to map, including primary and secondary components diff --git a/src/main/java/org/apache/commons/math3/ode/FieldExpandableODE.java b/src/main/java/org/apache/commons/math3/ode/FieldExpandableODE.java index b4f4c5bf1..3972367aa 100644 --- a/src/main/java/org/apache/commons/math3/ode/FieldExpandableODE.java +++ b/src/main/java/org/apache/commons/math3/ode/FieldExpandableODE.java @@ -69,13 +69,6 @@ public class FieldExpandableODE> { this.mapper = new FieldEquationsMapper(null, primary.getDimension()); } - /** Get the primary set of differential equations. - * @return primary set of differential equations - */ - public FieldFirstOrderDifferentialEquations getPrimary() { - return primary; - } - /** Get the mapper for the set of equations. * @return mapper for the set of equations */ @@ -87,14 +80,15 @@ public class FieldExpandableODE> { * @param secondary secondary equations set * @return index of the secondary equation in the expanded state, to be used * as the parameter to {@link FieldODEState#getSecondaryState(int)} and - * {@link FieldODEStateAndDerivative#getSecondaryDerivative(int)} + * {@link FieldODEStateAndDerivative#getSecondaryDerivative(int)} (beware index + * 0 corresponds to main state, additional states start at 1) */ public int addSecondaryEquations(final FieldSecondaryEquations secondary) { components.add(secondary); mapper = new FieldEquationsMapper(mapper, secondary.getDimension()); - return components.size() - 1; + return components.size(); } diff --git a/src/main/java/org/apache/commons/math3/ode/FieldODEState.java b/src/main/java/org/apache/commons/math3/ode/FieldODEState.java index ff801f24a..ab3836ae5 100644 --- a/src/main/java/org/apache/commons/math3/ode/FieldODEState.java +++ b/src/main/java/org/apache/commons/math3/ode/FieldODEState.java @@ -59,9 +59,9 @@ public class FieldODEState> { * @param secondaryState state at time (may be null) */ public FieldODEState(T time, T[] state, T[][] secondaryState) { - this.time = time; - this.state = state.clone(); - this.secondaryState = copy(time.getField(), secondaryState); + this.time = time; + this.state = state.clone(); + this.secondaryState = copy(time.getField(), secondaryState); } /** Copy a two-dimensions array. @@ -77,11 +77,11 @@ public class FieldODEState> { } // allocate the array - final T[][] copied = MathArrays.buildArray(field, original.length, original[0].length); + final T[][] copied = MathArrays.buildArray(field, original.length, -1); // copy content for (int i = 0; i < original.length; ++i) { - System.arraycopy(original[i], 0, copied[i], 0, original[i].length); + copied[i] = original[i].clone(); } return copied; @@ -119,19 +119,21 @@ public class FieldODEState> { /** Get secondary state dimension. * @param index index of the secondary set as returned * by {@link FieldExpandableODE#addSecondaryEquations(FieldSecondaryEquations)} + * (beware index 0 corresponds to main state, additional states start at 1) * @return secondary state dimension */ public int getSecondaryStateDimension(final int index) { - return secondaryState[index].length; + return index == 0 ? state.length : secondaryState[index - 1].length; } /** Get secondary state at time. * @param index index of the secondary set as returned * by {@link FieldExpandableODE#addSecondaryEquations(FieldSecondaryEquations)} + * (beware index 0 corresponds to main state, additional states start at 1) * @return secondary state at time */ public T[] getSecondaryState(final int index) { - return secondaryState[index].clone(); + return index == 0 ? state.clone() : secondaryState[index - 1].clone(); } } diff --git a/src/main/java/org/apache/commons/math3/ode/FieldODEStateAndDerivative.java b/src/main/java/org/apache/commons/math3/ode/FieldODEStateAndDerivative.java index 79f192ba1..e47518540 100644 --- a/src/main/java/org/apache/commons/math3/ode/FieldODEStateAndDerivative.java +++ b/src/main/java/org/apache/commons/math3/ode/FieldODEStateAndDerivative.java @@ -72,10 +72,11 @@ public class FieldODEStateAndDerivative> extends F /** Get derivative of the secondary state at time. * @param index index of the secondary set as returned * by {@link FieldExpandableODE#addSecondaryEquations(FieldSecondaryEquations)} + * (beware index 0 corresponds to main state, additional states start at 1) * @return derivative of the secondary state at time */ public T[] getSecondaryDerivative(final int index) { - return secondaryDerivative[index].clone(); + return index == 0 ? derivative.clone() : secondaryDerivative[index - 1].clone(); } } diff --git a/src/test/java/org/apache/commons/math3/ode/nonstiff/FieldExpandableODETest.java b/src/test/java/org/apache/commons/math3/ode/nonstiff/FieldExpandableODETest.java new file mode 100644 index 000000000..b8fc65d05 --- /dev/null +++ b/src/test/java/org/apache/commons/math3/ode/nonstiff/FieldExpandableODETest.java @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ode.nonstiff; + + +import org.apache.commons.math3.Field; +import org.apache.commons.math3.RealFieldElement; +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.ode.FieldExpandableODE; +import org.apache.commons.math3.ode.FieldFirstOrderDifferentialEquations; +import org.apache.commons.math3.ode.FieldODEStateAndDerivative; +import org.apache.commons.math3.ode.FieldSecondaryEquations; +import org.apache.commons.math3.util.Decimal64Field; +import org.apache.commons.math3.util.MathArrays; +import org.junit.Assert; +import org.junit.Test; + +public class FieldExpandableODETest { + + @Test + public void testOnlyMainEquation() { + doTestOnlyMainEquation(Decimal64Field.getInstance()); + } + + private > void doTestOnlyMainEquation(final Field field) { + FieldFirstOrderDifferentialEquations main = new Linear<>(field, 3, 0); + FieldExpandableODE equation = new FieldExpandableODE<>(main); + Assert.assertEquals(main.getDimension(), equation.getMapper().getTotalDimension()); + Assert.assertEquals(1, equation.getMapper().getNumberOfEquations()); + T t0 = field.getZero().add(10); + T t = field.getZero().add(100); + T[] complete = MathArrays.buildArray(field, equation.getMapper().getTotalDimension()); + for (int i = 0; i < complete.length; ++i) { + complete[i] = field.getZero().add(i); + } + equation.init(t0, complete, t); + T[] completeDot = equation.computeDerivatives(t0, complete); + FieldODEStateAndDerivative state = equation.getMapper().mapStateAndDerivative(t0, complete, completeDot); + Assert.assertEquals(0, state.getNumberOfSecondaryStates()); + T[] mainState = state.getState(); + T[] mainStateDot = state.getDerivative(); + Assert.assertEquals(main.getDimension(), mainState.length); + for (int i = 0; i < main.getDimension(); ++i) { + Assert.assertEquals(i, mainState[i].getReal(), 1.0e-15); + Assert.assertEquals(i, mainStateDot[i].getReal(), 1.0e-15); + Assert.assertEquals(i, completeDot[i].getReal(), 1.0e-15); + } + } + + @Test + public void testMainAndSecondary() { + doTestMainAndSecondary(Decimal64Field.getInstance()); + } + + private > void doTestMainAndSecondary(final Field field) { + + FieldFirstOrderDifferentialEquations main = new Linear<>(field, 3, 0); + FieldExpandableODE equation = new FieldExpandableODE(main); + FieldSecondaryEquations secondary1 = new Linear(field, 3, main.getDimension()); + int i1 = equation.addSecondaryEquations(secondary1); + FieldSecondaryEquations secondary2 = new Linear(field, 5, main.getDimension() + secondary1.getDimension()); + int i2 = equation.addSecondaryEquations(secondary2); + Assert.assertEquals(main.getDimension() + secondary1.getDimension() + secondary2.getDimension(), + equation.getMapper().getTotalDimension()); + Assert.assertEquals(3, equation.getMapper().getNumberOfEquations()); + Assert.assertEquals(1, i1); + Assert.assertEquals(2, i2); + + T t0 = field.getZero().add(10); + T t = field.getZero().add(100); + T[] complete = MathArrays.buildArray(field, equation.getMapper().getTotalDimension()); + for (int i = 0; i < complete.length; ++i) { + complete[i] = field.getZero().add(i); + } + equation.init(t0, complete, t); + T[] completeDot = equation.computeDerivatives(t0, complete); + + T[] mainState = equation.getMapper().extractEquationData(0, complete); + T[] mainStateDot = equation.getMapper().extractEquationData(0, completeDot); + Assert.assertEquals(main.getDimension(), mainState.length); + for (int i = 0; i < main.getDimension(); ++i) { + Assert.assertEquals(i, mainState[i].getReal(), 1.0e-15); + Assert.assertEquals(i, mainStateDot[i].getReal(), 1.0e-15); + Assert.assertEquals(i, completeDot[i].getReal(), 1.0e-15); + } + + T[] secondaryState1 = equation.getMapper().extractEquationData(i1, complete); + T[] secondaryState1Dot = equation.getMapper().extractEquationData(i1, completeDot); + Assert.assertEquals(secondary1.getDimension(), secondaryState1.length); + for (int i = 0; i < secondary1.getDimension(); ++i) { + Assert.assertEquals(i + main.getDimension(), secondaryState1[i].getReal(), 1.0e-15); + Assert.assertEquals(-i, secondaryState1Dot[i].getReal(), 1.0e-15); + Assert.assertEquals(-i, completeDot[i + main.getDimension()].getReal(), 1.0e-15); + } + + T[] secondaryState2 = equation.getMapper().extractEquationData(i2, complete); + T[] secondaryState2Dot = equation.getMapper().extractEquationData(i2, completeDot); + Assert.assertEquals(secondary2.getDimension(), secondaryState2.length); + for (int i = 0; i < secondary2.getDimension(); ++i) { + Assert.assertEquals(i + main.getDimension() + secondary1.getDimension(), secondaryState2[i].getReal(), 1.0e-15); + Assert.assertEquals(-i, secondaryState2Dot[i].getReal(), 1.0e-15); + Assert.assertEquals(-i, completeDot[i + main.getDimension() + secondary1.getDimension()].getReal(), 1.0e-15); + } + + } + + @Test + public void testMap() { + doTestMap(Decimal64Field.getInstance()); + } + + private > void doTestMap(final Field field) { + + FieldFirstOrderDifferentialEquations main = new Linear<>(field, 3, 0); + FieldExpandableODE equation = new FieldExpandableODE(main); + FieldSecondaryEquations secondary1 = new Linear(field, 3, main.getDimension()); + int i1 = equation.addSecondaryEquations(secondary1); + FieldSecondaryEquations secondary2 = new Linear(field, 5, main.getDimension() + secondary1.getDimension()); + int i2 = equation.addSecondaryEquations(secondary2); + Assert.assertEquals(main.getDimension() + secondary1.getDimension() + secondary2.getDimension(), + equation.getMapper().getTotalDimension()); + Assert.assertEquals(3, equation.getMapper().getNumberOfEquations()); + Assert.assertEquals(1, i1); + Assert.assertEquals(2, i2); + + T t0 = field.getZero().add(10); + T t = field.getZero().add(100); + T[] complete = MathArrays.buildArray(field, equation.getMapper().getTotalDimension()); + for (int i = 0; i < complete.length; ++i) { + complete[i] = field.getZero().add(i); + } + equation.init(t0, complete, t); + T[] completeDot = equation.computeDerivatives(t0, complete); + + try { + equation.getMapper().mapStateAndDerivative(t0, MathArrays.buildArray(field, complete.length + 1), completeDot); + Assert.fail("an exception should have been thrown"); + } catch (DimensionMismatchException dme) { + // expected + } + try { + equation.getMapper().mapStateAndDerivative(t0, complete, MathArrays.buildArray(field, completeDot.length + 1)); + Assert.fail("an exception should have been thrown"); + } catch (DimensionMismatchException dme) { + // expected + } + FieldODEStateAndDerivative state = equation.getMapper().mapStateAndDerivative(t0, complete, completeDot); + Assert.assertEquals(2, state.getNumberOfSecondaryStates()); + Assert.assertEquals(main.getDimension(), state.getSecondaryStateDimension(0)); + Assert.assertEquals(secondary1.getDimension(), state.getSecondaryStateDimension(i1)); + Assert.assertEquals(secondary2.getDimension(), state.getSecondaryStateDimension(i2)); + + T[] mainState = state.getState(); + T[] mainStateDot = state.getDerivative(); + T[] mainStateAlternate = state.getSecondaryState(0); + T[] mainStateDotAlternate = state.getSecondaryDerivative(0); + Assert.assertEquals(main.getDimension(), mainState.length); + for (int i = 0; i < main.getDimension(); ++i) { + Assert.assertEquals(i, mainState[i].getReal(), 1.0e-15); + Assert.assertEquals(i, mainStateDot[i].getReal(), 1.0e-15); + Assert.assertEquals(i, mainStateAlternate[i].getReal(), 1.0e-15); + Assert.assertEquals(i, mainStateDotAlternate[i].getReal(), 1.0e-15); + Assert.assertEquals(i, completeDot[i].getReal(), 1.0e-15); + } + + T[] secondaryState1 = state.getSecondaryState(i1); + T[] secondaryState1Dot = state.getSecondaryDerivative(i1); + Assert.assertEquals(secondary1.getDimension(), secondaryState1.length); + for (int i = 0; i < secondary1.getDimension(); ++i) { + Assert.assertEquals(i + main.getDimension(), secondaryState1[i].getReal(), 1.0e-15); + Assert.assertEquals(-i, secondaryState1Dot[i].getReal(), 1.0e-15); + Assert.assertEquals(-i, completeDot[i + main.getDimension()].getReal(), 1.0e-15); + } + + T[] secondaryState2 = state.getSecondaryState(i2); + T[] secondaryState2Dot = state.getSecondaryDerivative(i2); + Assert.assertEquals(secondary2.getDimension(), secondaryState2.length); + for (int i = 0; i < secondary2.getDimension(); ++i) { + Assert.assertEquals(i + main.getDimension() + secondary1.getDimension(), secondaryState2[i].getReal(), 1.0e-15); + Assert.assertEquals(-i, secondaryState2Dot[i].getReal(), 1.0e-15); + Assert.assertEquals(-i, completeDot[i + main.getDimension() + secondary1.getDimension()].getReal(), 1.0e-15); + } + + T[] remappedState = equation.getMapper().mapState(state); + T[] remappedDerivative = equation.getMapper().mapDerivative(state); + Assert.assertEquals(equation.getMapper().getTotalDimension(), remappedState.length); + Assert.assertEquals(equation.getMapper().getTotalDimension(), remappedDerivative.length); + for (int i = 0; i < remappedState.length; ++i) { + Assert.assertEquals(complete[i].getReal(), remappedState[i].getReal(), 1.0e-15); + Assert.assertEquals(completeDot[i].getReal(), remappedDerivative[i].getReal(), 1.0e-15); + } + } + + @Test(expected=DimensionMismatchException.class) + public void testExtractDimensionMismatch() { + doTestExtractDimensionMismatch(Decimal64Field.getInstance()); + } + + private > void doTestExtractDimensionMismatch(final Field field) + throws DimensionMismatchException { + + FieldFirstOrderDifferentialEquations main = new Linear<>(field, 3, 0); + FieldExpandableODE equation = new FieldExpandableODE(main); + FieldSecondaryEquations secondary1 = new Linear(field, 3, main.getDimension()); + int i1 = equation.addSecondaryEquations(secondary1); + T[] tooShort = MathArrays.buildArray(field, main.getDimension()); + equation.getMapper().extractEquationData(i1, tooShort); + } + + @Test(expected=DimensionMismatchException.class) + public void testInsertTooShortComplete() { + doTestInsertTooShortComplete(Decimal64Field.getInstance()); + } + + private > void doTestInsertTooShortComplete(final Field field) + throws DimensionMismatchException { + + FieldFirstOrderDifferentialEquations main = new Linear<>(field, 3, 0); + FieldExpandableODE equation = new FieldExpandableODE(main); + FieldSecondaryEquations secondary1 = new Linear(field, 3, main.getDimension()); + int i1 = equation.addSecondaryEquations(secondary1); + T[] equationData = MathArrays.buildArray(field, secondary1.getDimension()); + T[] tooShort = MathArrays.buildArray(field, main.getDimension()); + equation.getMapper().insertEquationData(i1, equationData, tooShort); + } + + @Test(expected=DimensionMismatchException.class) + public void testInsertWrongEquationData() { + doTestInsertWrongEquationData(Decimal64Field.getInstance()); + } + + private > void doTestInsertWrongEquationData(final Field field) + throws DimensionMismatchException { + + FieldFirstOrderDifferentialEquations main = new Linear<>(field, 3, 0); + FieldExpandableODE equation = new FieldExpandableODE(main); + FieldSecondaryEquations secondary1 = new Linear(field, 3, main.getDimension()); + int i1 = equation.addSecondaryEquations(secondary1); + T[] wrongEquationData = MathArrays.buildArray(field, secondary1.getDimension() + 1); + T[] complete = MathArrays.buildArray(field, equation.getMapper().getTotalDimension()); + equation.getMapper().insertEquationData(i1, wrongEquationData, complete); + } + + @Test(expected=MathIllegalArgumentException.class) + public void testNegativeIndex() { + doTestNegativeIndex(Decimal64Field.getInstance()); + } + + private > void doTestNegativeIndex(final Field field) + throws MathIllegalArgumentException { + + FieldFirstOrderDifferentialEquations main = new Linear<>(field, 3, 0); + FieldExpandableODE equation = new FieldExpandableODE(main); + T[] complete = MathArrays.buildArray(field, equation.getMapper().getTotalDimension()); + equation.getMapper().extractEquationData(-1, complete); + } + + @Test(expected=MathIllegalArgumentException.class) + public void testTooLargeIndex() { + doTestTooLargeIndex(Decimal64Field.getInstance()); + } + + private > void doTestTooLargeIndex(final Field field) + throws MathIllegalArgumentException { + + FieldFirstOrderDifferentialEquations main = new Linear<>(field, 3, 0); + FieldExpandableODE equation = new FieldExpandableODE(main); + T[] complete = MathArrays.buildArray(field, equation.getMapper().getTotalDimension()); + equation.getMapper().extractEquationData(+1, complete); + } + + private static class Linear> + implements FieldFirstOrderDifferentialEquations, FieldSecondaryEquations { + + private final Field field; + private final int dimension; + private final int start; + + private Linear(final Field field, final int dimension, final int start) { + this.field = field; + this.dimension = dimension; + this.start = start; + } + + public int getDimension() { + return dimension; + } + + public void init(final T t0, final T[] y0, final T finalTime) { + Assert.assertEquals(dimension, y0.length); + Assert.assertEquals(10.0, t0.getReal(), 1.0e-15); + Assert.assertEquals(100.0, finalTime.getReal(), 1.0e-15); + for (int i = 0; i < y0.length; ++i) { + Assert.assertEquals(i, y0[i].getReal(), 1.0e-15); + } + } + + public T[] computeDerivatives(final T t, final T[] y) { + final T[] yDot = MathArrays.buildArray(field, dimension); + for (int i = 0; i < dimension; ++i) { + yDot[i] = field.getZero().add(i); + } + return yDot; + } + + public void init(final T t0, final T[] primary0, final T[] secondary0, final T finalTime) { + Assert.assertEquals(dimension, secondary0.length); + Assert.assertEquals(10.0, t0.getReal(), 1.0e-15); + Assert.assertEquals(100.0, finalTime.getReal(), 1.0e-15); + for (int i = 0; i < primary0.length; ++i) { + Assert.assertEquals(i, primary0[i].getReal(), 1.0e-15); + } + for (int i = 0; i < secondary0.length; ++i) { + Assert.assertEquals(start + i, secondary0[i].getReal(), 1.0e-15); + } + } + + public T[] computeDerivatives(final T t, final T[] primary, final T[] primaryDot, final T[] secondary) { + final T[] secondaryDot = MathArrays.buildArray(field, dimension); + for (int i = 0; i < dimension; ++i) { + secondaryDot[i] = field.getZero().subtract(i); + } + return secondaryDot; + } + + } + +}