Separate equations from mapper.
This commit is contained in:
parent
6da8a0eba0
commit
fe8646e83e
|
@ -332,10 +332,10 @@ public abstract class AbstractFieldIntegrator<T extends RealFieldElement<T>> imp
|
|||
if (needReset) {
|
||||
// some event handler has triggered changes that
|
||||
// invalidate the derivatives, we need to recompute them
|
||||
final T[] y = equations.mapState(eventState);
|
||||
final T[] y = equations.getMapper().mapState(eventState);
|
||||
final T[] yDot = computeDerivatives(eventState.getTime(), y);
|
||||
resetOccurred = true;
|
||||
return equations.mapStateAndDerivative(eventState.getTime(), y, yDot);
|
||||
return equations.getMapper().mapStateAndDerivative(eventState.getTime(), y, yDot);
|
||||
}
|
||||
|
||||
// prepare handling of the remaining part of the step
|
||||
|
|
|
@ -17,68 +17,185 @@
|
|||
|
||||
package org.apache.commons.math4.ode;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.lang.reflect.Array;
|
||||
|
||||
import org.apache.commons.math4.RealFieldElement;
|
||||
import org.apache.commons.math4.exception.MathIllegalArgumentException;
|
||||
import org.apache.commons.math4.exception.util.LocalizedFormats;
|
||||
import org.apache.commons.math4.util.MathArrays;
|
||||
|
||||
/**
|
||||
* Class mapping the part of a complete state or derivative that pertains
|
||||
* to a specific differential equation.
|
||||
* to a set of differential equations.
|
||||
* <p>
|
||||
* Instances of this class are guaranteed to be immutable.
|
||||
* </p>
|
||||
* @see FieldSecondaryEquations
|
||||
* @see FieldExpandableODE
|
||||
* @param <T> the type of the field elements
|
||||
* @since 3.6
|
||||
*/
|
||||
public class FieldEquationsMapper<T extends RealFieldElement<T>> {
|
||||
class FieldEquationsMapper<T extends RealFieldElement<T>> implements Serializable {
|
||||
|
||||
/** Index of the first equation element in complete state arrays. */
|
||||
private final int firstIndex;
|
||||
/** Serializable UID. */
|
||||
private static final long serialVersionUID = 20151114L;
|
||||
|
||||
/** Dimension of the secondary state parameters. */
|
||||
private final int dimension;
|
||||
/** Start indices of the components. */
|
||||
private final int[] start;
|
||||
|
||||
/** simple constructor.
|
||||
* @param firstIndex index of the first equation element in complete state arrays
|
||||
* @param dimension dimension of the secondary state parameters
|
||||
/** Create a mapper by adding a new equation to another mapper.
|
||||
* <p>
|
||||
* The new equation will have index {@code mapper.}{@link #getNumberOfEquations()},
|
||||
* or 0 if {@code mapper} is null.
|
||||
* </p>
|
||||
* @param mapper former mapper, with one equation less (null for first equation)
|
||||
* @param dimension dimension of the equation state vector
|
||||
*/
|
||||
FieldEquationsMapper(final int firstIndex, final int dimension) {
|
||||
this.firstIndex = firstIndex;
|
||||
this.dimension = dimension;
|
||||
FieldEquationsMapper(final FieldEquationsMapper<T> mapper, final int dimension) {
|
||||
final int index = (mapper == null) ? 0 : mapper.getNumberOfEquations();
|
||||
this.start = new int[index + 2];
|
||||
if (mapper == null) {
|
||||
start[0] = 0;
|
||||
} else {
|
||||
System.arraycopy(mapper.start, 0, start, 0, index);
|
||||
}
|
||||
start[index + 1] = start[index] + dimension;
|
||||
}
|
||||
|
||||
/** Get the index of the first equation element in complete state arrays.
|
||||
* @return index of the first equation element in complete state arrays
|
||||
/** Get the number of equations mapped.
|
||||
* @return number of equations mapped
|
||||
*/
|
||||
public int getFirstIndex() {
|
||||
return firstIndex;
|
||||
public int getNumberOfEquations() {
|
||||
return start.length - 1;
|
||||
}
|
||||
|
||||
/** Get the dimension of the secondary state parameters.
|
||||
* @return dimension of the secondary state parameters
|
||||
/** Return the dimension of the complete set of equations.
|
||||
* <p>
|
||||
* The complete set of equations correspond to the primary set plus all secondary sets.
|
||||
* </p>
|
||||
* @return dimension of the complete set of equations
|
||||
*/
|
||||
public int getDimension() {
|
||||
return dimension;
|
||||
public int getTotalDimension() {
|
||||
return start[start.length - 1];
|
||||
}
|
||||
|
||||
/** Map a state to a complete flat array.
|
||||
* @param state state to map
|
||||
* @return flat array containing the mapped state, including primary and secondary components
|
||||
*/
|
||||
public T[] mapState(final FieldODEState<T> state) {
|
||||
final T[] y = MathArrays.buildArray(state.getTime().getField(), getTotalDimension());
|
||||
int index = 0;
|
||||
insertEquationData(index, state.getState(), y);
|
||||
while (++index < getNumberOfEquations()) {
|
||||
insertEquationData(index, state.getSecondaryState(index - 1), y);
|
||||
}
|
||||
return y;
|
||||
}
|
||||
|
||||
/** Map a state derivative to a complete flat array.
|
||||
* @param state state to map
|
||||
* @return flat array containing the mapped state derivative, including primary and secondary components
|
||||
*/
|
||||
public T[] mapDerivative(final FieldODEStateAndDerivative<T> state) {
|
||||
final T[] yDot = MathArrays.buildArray(state.getTime().getField(), getTotalDimension());
|
||||
int index = 0;
|
||||
insertEquationData(index, state.getDerivative(), yDot);
|
||||
while (++index < getNumberOfEquations()) {
|
||||
insertEquationData(index, state.getSecondaryDerivative(index - 1), yDot);
|
||||
}
|
||||
return yDot;
|
||||
}
|
||||
|
||||
/** Map a flat array to a state.
|
||||
* @param t time
|
||||
* @param y array to map, including primary and secondary components
|
||||
* @return mapped state
|
||||
*/
|
||||
public FieldODEState<T> mapState(final T t, final T[] y) {
|
||||
final int n = getNumberOfEquations();
|
||||
int index = 0;
|
||||
final T[] state = extractEquationData(index, y);
|
||||
if (n < 2) {
|
||||
return new FieldODEState<T>(t, state);
|
||||
} else {
|
||||
@SuppressWarnings("unchecked")
|
||||
final T[][] secondaryState = (T[][]) Array.newInstance(t.getField().getRuntimeClass(), n - 1);
|
||||
while (++index < n) {
|
||||
secondaryState[index - 1] = extractEquationData(index, y);
|
||||
}
|
||||
return new FieldODEState<T>(t, state, secondaryState);
|
||||
}
|
||||
}
|
||||
|
||||
/** Map flat arrays to a state and derivative.
|
||||
* @param t time
|
||||
* @param y state array to map, including primary and secondary components
|
||||
* @param yDot state derivative array to map, including primary and secondary components
|
||||
* @return mapped state
|
||||
*/
|
||||
public FieldODEStateAndDerivative<T> mapStateAndDerivative(final T t, final T[] y, final T[] yDot) {
|
||||
final int n = getNumberOfEquations();
|
||||
int index = 0;
|
||||
final T[] state = extractEquationData(index, y);
|
||||
final T[] derivative = extractEquationData(index, yDot);
|
||||
if (n < 2) {
|
||||
return new FieldODEStateAndDerivative<T>(t, state, derivative);
|
||||
} else {
|
||||
@SuppressWarnings("unchecked")
|
||||
final T[][] secondaryState = (T[][]) Array.newInstance(t.getField().getRuntimeClass(), n - 1);
|
||||
@SuppressWarnings("unchecked")
|
||||
final T[][] secondaryDerivative = (T[][]) Array.newInstance(t.getField().getRuntimeClass(), n - 1);
|
||||
while (++index < getNumberOfEquations()) {
|
||||
secondaryState[index - 1] = extractEquationData(index, y);
|
||||
secondaryDerivative[index - 1] = extractEquationData(index, yDot);
|
||||
}
|
||||
return new FieldODEStateAndDerivative<T>(t, state, derivative, secondaryState, secondaryDerivative);
|
||||
}
|
||||
}
|
||||
|
||||
/** Extract equation data from a complete state or derivative array.
|
||||
* @param index index of the equation, must be between 0 included and
|
||||
* {@link #getNumberOfEquations()} (excluded)
|
||||
* @param complete complete state or derivative array from which
|
||||
* equation data should be retrieved
|
||||
* @return equation data
|
||||
* @exception MathIllegalArgumentException if index is out of range
|
||||
*/
|
||||
public T[] extractEquationData(T[] complete) {
|
||||
public T[] extractEquationData(final int index, final T[] complete)
|
||||
throws MathIllegalArgumentException {
|
||||
checkIndex(index);
|
||||
final int begin = start[index];
|
||||
final int dimension = start[index + 1] - begin;
|
||||
final T[] equationData = MathArrays.buildArray(complete[0].getField(), dimension);
|
||||
System.arraycopy(complete, firstIndex, equationData, 0, dimension);
|
||||
System.arraycopy(complete, begin, equationData, 0, dimension);
|
||||
return equationData;
|
||||
}
|
||||
|
||||
/** Insert equation data into a complete state or derivative array.
|
||||
* @param index index of the equation, must be between 0 included and
|
||||
* {@link #getNumberOfEquations()} (excluded)
|
||||
* @param equationData equation data to be inserted into the complete array
|
||||
* @param complete placeholder where to put equation data (only the
|
||||
* part corresponding to the equation will be overwritten)
|
||||
*/
|
||||
public void insertEquationData(T[] equationData, T[] complete) {
|
||||
System.arraycopy(equationData, 0, complete, firstIndex, dimension);
|
||||
public void insertEquationData(final int index, T[] equationData, T[] complete) {
|
||||
checkIndex(index);
|
||||
final int begin = start[index];
|
||||
final int dimension = start[index + 1] - begin;
|
||||
System.arraycopy(equationData, 0, complete, begin, dimension);
|
||||
}
|
||||
|
||||
/** Check equation index.
|
||||
* @param index index of the equation, must be between 0 included and
|
||||
* {@link #getNumberOfEquations()} (excluded)
|
||||
* @exception MathIllegalArgumentException if index is out of range
|
||||
*/
|
||||
private void checkIndex(final int index) throws MathIllegalArgumentException {
|
||||
if (index < 0 || index > start.length - 2) {
|
||||
throw new MathIllegalArgumentException(LocalizedFormats.ARGUMENT_OUTSIDE_DOMAIN,
|
||||
index, 0, start.length - 2);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
*/
|
||||
package org.apache.commons.math4.ode;
|
||||
|
||||
import java.lang.reflect.Array;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
|
@ -52,26 +51,22 @@ import org.apache.commons.math4.util.MathArrays;
|
|||
|
||||
public class FieldExpandableODE<T extends RealFieldElement<T>> {
|
||||
|
||||
/** Total dimension. */
|
||||
private int dimension;
|
||||
|
||||
/** Primary differential equation. */
|
||||
private final FieldFirstOrderDifferentialEquations<T> primary;
|
||||
|
||||
/** Mapper for primary equation. */
|
||||
private final FieldEquationsMapper<T> primaryMapper;
|
||||
|
||||
/** Components of the expandable ODE. */
|
||||
private List<FieldSecondaryComponent<T>> components;
|
||||
private List<FieldSecondaryEquations<T>> components;
|
||||
|
||||
/** Mapper for all equations. */
|
||||
private FieldEquationsMapper<T> mapper;
|
||||
|
||||
/** Build an expandable set from its primary ODE set.
|
||||
* @param primary the primary set of differential equations to be integrated.
|
||||
*/
|
||||
public FieldExpandableODE(final FieldFirstOrderDifferentialEquations<T> primary) {
|
||||
this.dimension = primary.getDimension();
|
||||
this.primary = primary;
|
||||
this.primaryMapper = new FieldEquationsMapper<T>(0, primary.getDimension());
|
||||
this.components = new ArrayList<FieldExpandableODE.FieldSecondaryComponent<T>>();
|
||||
this.primary = primary;
|
||||
this.components = new ArrayList<FieldSecondaryEquations<T>>();
|
||||
this.mapper = new FieldEquationsMapper<T>(null, primary.getDimension());
|
||||
}
|
||||
|
||||
/** Get the primary set of differential equations.
|
||||
|
@ -81,14 +76,11 @@ public class FieldExpandableODE<T extends RealFieldElement<T>> {
|
|||
return primary;
|
||||
}
|
||||
|
||||
/** Return the dimension of the complete set of equations.
|
||||
* <p>
|
||||
* The complete set of equations correspond to the primary set plus all secondary sets.
|
||||
* </p>
|
||||
* @return dimension of the complete set of equations
|
||||
/** Get the mapper for the set of equations.
|
||||
* @return mapper for the set of equations
|
||||
*/
|
||||
public int getTotalDimension() {
|
||||
return dimension;
|
||||
FieldEquationsMapper<T> getMapper() {
|
||||
return mapper;
|
||||
}
|
||||
|
||||
/** Add a set of secondary equations to be integrated along with the primary set.
|
||||
|
@ -99,95 +91,13 @@ public class FieldExpandableODE<T extends RealFieldElement<T>> {
|
|||
*/
|
||||
public int addSecondaryEquations(final FieldSecondaryEquations<T> secondary) {
|
||||
|
||||
final int firstIndex;
|
||||
if (components.isEmpty()) {
|
||||
// lazy creation of the components list
|
||||
components = new ArrayList<FieldExpandableODE.FieldSecondaryComponent<T>>();
|
||||
firstIndex = primary.getDimension();
|
||||
} else {
|
||||
final FieldSecondaryComponent<T> last = components.get(components.size() - 1);
|
||||
firstIndex = last.mapper.getFirstIndex() + last.mapper.getDimension();
|
||||
}
|
||||
|
||||
final FieldSecondaryComponent<T> component = new FieldSecondaryComponent<T>(secondary, firstIndex);
|
||||
components.add(component);
|
||||
|
||||
// update total dimension
|
||||
dimension = component.mapper.getFirstIndex() + component.mapper.getDimension();
|
||||
components.add(secondary);
|
||||
mapper = new FieldEquationsMapper<>(mapper, secondary.getDimension());
|
||||
|
||||
return components.size() - 1;
|
||||
|
||||
}
|
||||
|
||||
/** Map a state to a complete flat array.
|
||||
* @param state state to map
|
||||
* @return flat array containing the mapped state, including primary and secondary components
|
||||
*/
|
||||
public T[] mapState(final FieldODEState<T> state) {
|
||||
final T[] y = MathArrays.buildArray(state.getTime().getField(), getTotalDimension());
|
||||
primaryMapper.insertEquationData(state.getState(), y);
|
||||
for (int i = 0; i < components.size(); ++i) {
|
||||
components.get(i).mapper.insertEquationData(state.getSecondaryState(i), y);
|
||||
}
|
||||
return y;
|
||||
}
|
||||
|
||||
/** Map a state derivative to a complete flat array.
|
||||
* @param state state to map
|
||||
* @return flat array containing the mapped state derivative, including primary and secondary components
|
||||
*/
|
||||
public T[] mapDerivative(final FieldODEStateAndDerivative<T> state) {
|
||||
final T[] yDot = MathArrays.buildArray(state.getTime().getField(), getTotalDimension());
|
||||
primaryMapper.insertEquationData(state.getDerivative(), yDot);
|
||||
for (int i = 0; i < components.size(); ++i) {
|
||||
components.get(i).mapper.insertEquationData(state.getSecondaryDerivative(i), yDot);
|
||||
}
|
||||
return yDot;
|
||||
}
|
||||
|
||||
/** Map a flat array to a state.
|
||||
* @param t time
|
||||
* @param y array to map, including primary and secondary components
|
||||
* @return mapped state
|
||||
*/
|
||||
public FieldODEState<T> mapState(final T t, final T[] y) {
|
||||
final T[] state = primaryMapper.extractEquationData(y);
|
||||
if (components.isEmpty()) {
|
||||
return new FieldODEState<T>(t, state);
|
||||
} else {
|
||||
@SuppressWarnings("unchecked")
|
||||
final T[][] secondaryState = (T[][]) Array.newInstance(t.getField().getRuntimeClass(), components.size());
|
||||
for (int i = 0; i < components.size(); ++i) {
|
||||
secondaryState[i] = components.get(i).mapper.extractEquationData(y);
|
||||
}
|
||||
return new FieldODEState<T>(t, state, secondaryState);
|
||||
}
|
||||
}
|
||||
|
||||
/** Map flat arrays to a state and derivative.
|
||||
* @param t time
|
||||
* @param y state array to map, including primary and secondary components
|
||||
* @param yDot state derivative array to map, including primary and secondary components
|
||||
* @return mapped state
|
||||
*/
|
||||
public FieldODEStateAndDerivative<T> mapStateAndDerivative(final T t, final T[] y, final T[] yDot) {
|
||||
final T[] state = primaryMapper.extractEquationData(y);
|
||||
final T[] derivative = primaryMapper.extractEquationData(yDot);
|
||||
if (components.isEmpty()) {
|
||||
return new FieldODEStateAndDerivative<T>(t, state, derivative);
|
||||
} else {
|
||||
@SuppressWarnings("unchecked")
|
||||
final T[][] secondaryState = (T[][]) Array.newInstance(t.getField().getRuntimeClass(), components.size());
|
||||
@SuppressWarnings("unchecked")
|
||||
final T[][] secondaryDerivative = (T[][]) Array.newInstance(t.getField().getRuntimeClass(), components.size());
|
||||
for (int i = 0; i < components.size(); ++i) {
|
||||
secondaryState[i] = components.get(i).mapper.extractEquationData(y);
|
||||
secondaryDerivative[i] = components.get(i).mapper.extractEquationData(yDot);
|
||||
}
|
||||
return new FieldODEStateAndDerivative<T>(t, state, derivative, secondaryState, secondaryDerivative);
|
||||
}
|
||||
}
|
||||
|
||||
/** Get the current time derivative of the complete state vector.
|
||||
* @param t current value of the independent <I>time</I> variable
|
||||
* @param y array containing the current value of the complete state vector
|
||||
|
@ -198,44 +108,24 @@ public class FieldExpandableODE<T extends RealFieldElement<T>> {
|
|||
public T[] computeDerivatives(final T t, final T[] y)
|
||||
throws MaxCountExceededException, DimensionMismatchException {
|
||||
|
||||
final T[] yDot = MathArrays.buildArray(t.getField(), getTotalDimension());
|
||||
final T[] yDot = MathArrays.buildArray(t.getField(), mapper.getTotalDimension());
|
||||
|
||||
// compute derivatives of the primary equations
|
||||
final T[] primaryState = primaryMapper.extractEquationData(y);
|
||||
int index = 0;
|
||||
final T[] primaryState = mapper.extractEquationData(index, y);
|
||||
final T[] primaryStateDot = primary.computeDerivatives(t, primaryState);
|
||||
primaryMapper.insertEquationData(primaryStateDot, yDot);
|
||||
mapper.insertEquationData(index, primaryStateDot, yDot);
|
||||
|
||||
// Add contribution for secondary equations
|
||||
for (final FieldSecondaryComponent<T> component : components) {
|
||||
final T[] componentState = component.mapper.extractEquationData(y);
|
||||
final T[] componentStateDot = component.equation.computeDerivatives(t, primaryState, primaryStateDot, componentState);
|
||||
component.mapper.insertEquationData(componentStateDot, yDot);
|
||||
while (++index < mapper.getNumberOfEquations()) {
|
||||
final T[] componentState = mapper.extractEquationData(index, y);
|
||||
final T[] componentStateDot = components.get(index - 1).computeDerivatives(t, primaryState, primaryStateDot,
|
||||
componentState);
|
||||
mapper.insertEquationData(index, componentStateDot, yDot);
|
||||
}
|
||||
|
||||
return yDot;
|
||||
|
||||
}
|
||||
|
||||
/** Components of the compound ODE.
|
||||
* @param <S> the type of the field elements
|
||||
*/
|
||||
private static class FieldSecondaryComponent<S extends RealFieldElement<S>> {
|
||||
|
||||
/** Secondary differential equation. */
|
||||
private final FieldSecondaryEquations<S> equation;
|
||||
|
||||
/** Mapper between local and complete arrays. */
|
||||
private final FieldEquationsMapper<S> mapper;
|
||||
|
||||
/** Simple constructor.
|
||||
* @param equation secondary differential equation
|
||||
* @param firstIndex index to use for the first element in the complete arrays
|
||||
*/
|
||||
FieldSecondaryComponent(final FieldSecondaryEquations<S> equation, final int firstIndex) {
|
||||
this.equation = equation;
|
||||
this.mapper = new FieldEquationsMapper<S>(firstIndex, equation.getDimension());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue