Add some tests for ClassLoadingAwareObjectInputStream

This commit is contained in:
Timothy Bish 2016-06-24 12:09:58 -04:00
parent e02c1a17f8
commit 9e856290c4
4 changed files with 715 additions and 0 deletions

View File

@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.util;
import java.io.Serializable;
public class AnonymousSimplePojoParent implements Serializable {
private static final long serialVersionUID = 1L;
private SimplePojo payload;
public AnonymousSimplePojoParent(Object simplePojoPayload) {
// Create an ANONYMOUS simple payload, itself serializable, like we
// have to be since the object references us and is used
// during the serialization.
payload = new SimplePojo(simplePojoPayload) {
private static final long serialVersionUID = 1L;
};
}
public SimplePojo getPayload() {
return payload;
}
}

View File

@ -0,0 +1,557 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.util;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.Arrays;
import java.util.UUID;
import java.util.Vector;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TestName;
public class ClassLoadingAwareObjectInputStreamTest {
private final String ACCEPTS_ALL_FILTER = "*";
private final String ACCEPTS_NONE_FILTER = "";
@Rule
public TestName name = new TestName();
//----- Test for serialized objects --------------------------------------//
@Test
public void testReadObject() throws Exception {
// Expect to succeed
doTestReadObject(new SimplePojo(name.getMethodName()), ACCEPTS_ALL_FILTER);
// Expect to fail
try {
doTestReadObject(new SimplePojo(name.getMethodName()), ACCEPTS_NONE_FILTER);
fail("Should have failed to read");
} catch (ClassNotFoundException cnfe) {
// Expected
}
}
@Test
public void testReadObjectWithAnonymousClass() throws Exception {
AnonymousSimplePojoParent pojoParent = new AnonymousSimplePojoParent(name.getMethodName());
byte[] serialized = serializeObject(pojoParent);
try (ByteArrayInputStream input = new ByteArrayInputStream(serialized);
ClassLoadingAwareObjectInputStream reader = new ClassLoadingAwareObjectInputStream(input)) {
reader.setTrustAllPackages(false);
reader.setTrustedPackages(Arrays.asList(new String[] { "org.apache.activemq.util" }));
Object obj = reader.readObject();
assertTrue(obj instanceof AnonymousSimplePojoParent);
assertEquals("Unexpected payload", pojoParent.getPayload(), ((AnonymousSimplePojoParent)obj).getPayload());
}
}
@Test
public void testReadObjectWitLocalClass() throws Exception {
LocalSimplePojoParent pojoParent = new LocalSimplePojoParent(name.getMethodName());
byte[] serialized = serializeObject(pojoParent);
try (ByteArrayInputStream input = new ByteArrayInputStream(serialized);
ClassLoadingAwareObjectInputStream reader = new ClassLoadingAwareObjectInputStream(input)) {
reader.setTrustAllPackages(false);
reader.setTrustedPackages(Arrays.asList(new String[] { "org.apache.activemq.util" }));
Object obj = reader.readObject();
assertTrue(obj instanceof LocalSimplePojoParent);
assertEquals("Unexpected payload", pojoParent.getPayload(), ((LocalSimplePojoParent)obj).getPayload());
}
}
@Test
public void testReadObjectByte() throws Exception {
doTestReadObject(Byte.valueOf((byte) 255), ACCEPTS_ALL_FILTER);
}
@Test
public void testReadObjectShort() throws Exception {
doTestReadObject(Short.valueOf((short) 255), ACCEPTS_ALL_FILTER);
}
@Test
public void testReadObjectInteger() throws Exception {
doTestReadObject(Integer.valueOf(255), ACCEPTS_ALL_FILTER);
}
@Test
public void testReadObjectLong() throws Exception {
doTestReadObject(Long.valueOf(255l), ACCEPTS_ALL_FILTER);
}
@Test
public void testReadObjectFloat() throws Exception {
doTestReadObject(Float.valueOf(255.0f), ACCEPTS_ALL_FILTER);
}
@Test
public void testReadObjectDouble() throws Exception {
doTestReadObject(Double.valueOf(255.0), ACCEPTS_ALL_FILTER);
}
@Test
public void testReadObjectBoolean() throws Exception {
doTestReadObject(Boolean.FALSE, ACCEPTS_ALL_FILTER);
}
@Test
public void testReadObjectString() throws Exception {
doTestReadObject(new String(name.getMethodName()), ACCEPTS_ALL_FILTER);
}
//----- Test that arrays of objects can be read --------------------------//
@Test
public void testReadObjectStringArray() throws Exception {
String[] value = new String[2];
value[0] = name.getMethodName() + "-1";
value[1] = name.getMethodName() + "-2";
doTestReadObject(value, ACCEPTS_ALL_FILTER);
}
@Test
public void testReadObjectMultiDimensionalArray() throws Exception {
String[][][] value = new String[2][2][1];
value[0][0][0] = "0-0-0";
value[0][1][0] = "0-1-0";
value[1][0][0] = "1-0-0";
value[1][1][0] = "1-1-0";
doTestReadObject(value, ACCEPTS_ALL_FILTER);
}
//----- Test that primitive types are not filtered -----------------------//
@Test
public void testPrimitiveByteNotFiltered() throws Exception {
doTestReadPrimitive((byte) 255, ACCEPTS_NONE_FILTER);
}
@Test
public void testPrimitiveShortNotFiltered() throws Exception {
doTestReadPrimitive((short) 255, ACCEPTS_NONE_FILTER);
}
@Test
public void testPrimitiveIntegerNotFiltered() throws Exception {
doTestReadPrimitive(255, ACCEPTS_NONE_FILTER);
}
@Test
public void testPrimitiveLongNotFiltered() throws Exception {
doTestReadPrimitive((long) 255, ACCEPTS_NONE_FILTER);
}
@Test
public void testPrimitiveFloatNotFiltered() throws Exception {
doTestReadPrimitive((float) 255.0, ACCEPTS_NONE_FILTER);
}
@Test
public void testPrimitiveDoubleNotFiltered() throws Exception {
doTestReadPrimitive(255.0, ACCEPTS_NONE_FILTER);
}
@Test
public void testPrimitiveBooleanNotFiltered() throws Exception {
doTestReadPrimitive(false, ACCEPTS_NONE_FILTER);
}
@Test
public void testPrimitveCharNotFiltered() throws Exception {
doTestReadPrimitive('c', ACCEPTS_NONE_FILTER);
}
@Test
public void testReadObjectStringNotFiltered() throws Exception {
doTestReadObject(new String(name.getMethodName()), ACCEPTS_NONE_FILTER);
}
//----- Test that primitive arrays get past filters ----------------------//
@Test
public void testPrimitiveByteArrayNotFiltered() throws Exception {
byte[] value = new byte[2];
value[0] = 1;
value[1] = 2;
doTestReadPrimitiveArray(value, ACCEPTS_NONE_FILTER);
}
@Test
public void testPrimitiveShortArrayNotFiltered() throws Exception {
short[] value = new short[2];
value[0] = 1;
value[1] = 2;
doTestReadPrimitiveArray(value, ACCEPTS_NONE_FILTER);
}
@Test
public void testPrimitiveIntegerArrayNotFiltered() throws Exception {
int[] value = new int[2];
value[0] = 1;
value[1] = 2;
doTestReadPrimitiveArray(value, ACCEPTS_NONE_FILTER);
}
@Test
public void testPrimitiveLongArrayNotFiltered() throws Exception {
long[] value = new long[2];
value[0] = 1;
value[1] = 2;
doTestReadPrimitiveArray(value, ACCEPTS_NONE_FILTER);
}
@Test
public void testPrimitiveFloatArrayNotFiltered() throws Exception {
float[] value = new float[2];
value[0] = 1.1f;
value[1] = 2.1f;
doTestReadPrimitiveArray(value, ACCEPTS_NONE_FILTER);
}
@Test
public void testPrimitiveDoubleArrayNotFiltered() throws Exception {
double[] value = new double[2];
value[0] = 1.1;
value[1] = 2.1;
doTestReadPrimitiveArray(value, ACCEPTS_NONE_FILTER);
}
//----- Tests for types that should be filtered --------------------------//
@Test
public void testReadObjectArrayFiltered() throws Exception {
UUID[] value = new UUID[2];
value[0] = UUID.randomUUID();
value[1] = UUID.randomUUID();
byte[] serialized = serializeObject(value);
try (ByteArrayInputStream input = new ByteArrayInputStream(serialized);
ClassLoadingAwareObjectInputStream reader = new ClassLoadingAwareObjectInputStream(input)) {
reader.setTrustAllPackages(false);
reader.setTrustedPackages(Arrays.asList(ACCEPTS_NONE_FILTER.split(",")));
try {
reader.readObject();
fail("Should not be able to read the payload.");
} catch (ClassNotFoundException ex) {}
}
}
@Test
public void testReadObjectMixedTypeArrayGetsFiltered() throws Exception {
Object[] value = new Object[4];
value[0] = name.getMethodName();
value[1] = UUID.randomUUID();
value[2] = new Vector<Object>();
value[3] = new SimplePojo(name.getMethodName());
byte[] serialized = serializeObject(value);
try (ByteArrayInputStream input = new ByteArrayInputStream(serialized);
ClassLoadingAwareObjectInputStream reader = new ClassLoadingAwareObjectInputStream(input)) {
reader.setTrustAllPackages(false);
reader.setTrustedPackages(Arrays.asList(new String[] { "java" }));
try {
reader.readObject();
fail("Should not be able to read the payload.");
} catch (ClassNotFoundException ex) {
}
}
// Replace the filtered type and try again
value[3] = new Integer(20);
serialized = serializeObject(value);
try (ByteArrayInputStream input = new ByteArrayInputStream(serialized);
ClassLoadingAwareObjectInputStream reader = new ClassLoadingAwareObjectInputStream(input)) {
reader.setTrustAllPackages(false);
reader.setTrustedPackages(Arrays.asList(new String[] { "java" }));
try {
Object result = reader.readObject();
assertNotNull(result);
assertTrue(result.getClass().isArray());
} catch (ClassNotFoundException ex) {
fail("Should be able to read the payload.");
}
}
}
@Test
public void testReadObjectMultiDimensionalArrayFiltered() throws Exception {
UUID[][] value = new UUID[2][2];
value[0][0] = UUID.randomUUID();
value[0][1] = UUID.randomUUID();
value[1][0] = UUID.randomUUID();
value[1][1] = UUID.randomUUID();
byte[] serialized = serializeObject(value);
try (ByteArrayInputStream input = new ByteArrayInputStream(serialized);
ClassLoadingAwareObjectInputStream reader = new ClassLoadingAwareObjectInputStream(input)) {
reader.setTrustAllPackages(false);
reader.setTrustedPackages(Arrays.asList(ACCEPTS_NONE_FILTER.split(",")));
try {
reader.readObject();
fail("Should not be able to read the payload.");
} catch (ClassNotFoundException ex) {}
}
}
@Test
public void testReadObjectFailsWithUntrustedType() throws Exception {
byte[] serialized = serializeObject(new SimplePojo(name.getMethodName()));
try (ByteArrayInputStream input = new ByteArrayInputStream(serialized);
ClassLoadingAwareObjectInputStream reader = new ClassLoadingAwareObjectInputStream(input)) {
reader.setTrustAllPackages(false);
reader.setTrustedPackages(Arrays.asList(new String[] { "java" }));
try {
reader.readObject();
fail("Should not be able to read the payload.");
} catch (ClassNotFoundException ex) {}
}
serialized = serializeObject(UUID.randomUUID());
try (ByteArrayInputStream input = new ByteArrayInputStream(serialized);
ClassLoadingAwareObjectInputStream reader = new ClassLoadingAwareObjectInputStream(input)) {
try {
reader.readObject();
} catch (ClassNotFoundException ex) {
fail("Should be able to read the payload.");
}
}
}
@Test
public void testReadObjectFailsWithUnstrustedContentInTrustedType() throws Exception {
byte[] serialized = serializeObject(new SimplePojo(UUID.randomUUID()));
ByteArrayInputStream input = new ByteArrayInputStream(serialized);
try (ClassLoadingAwareObjectInputStream reader = new ClassLoadingAwareObjectInputStream(input)) {
reader.setTrustAllPackages(false);
reader.setTrustedPackages(Arrays.asList(new String[] { "org.apache.activemq" }));
try {
reader.readObject();
fail("Should not be able to read the payload.");
} catch (ClassNotFoundException ex) {}
}
serialized = serializeObject(UUID.randomUUID());
input = new ByteArrayInputStream(serialized);
try (ClassLoadingAwareObjectInputStream reader = new ClassLoadingAwareObjectInputStream(input)) {
reader.setTrustAllPackages(false);
reader.setTrustedPackages(Arrays.asList(new String[] { "org.apache.activemq" }));
try {
reader.readObject();
fail("Should not be able to read the payload.");
} catch (ClassNotFoundException ex) {
}
}
}
//----- Internal methods -------------------------------------------------//
private void doTestReadObject(Object value, String filter) throws Exception {
byte[] serialized = serializeObject(value);
try (ByteArrayInputStream input = new ByteArrayInputStream(serialized);
ClassLoadingAwareObjectInputStream reader =
new ClassLoadingAwareObjectInputStream(input)) {
reader.setTrustAllPackages(false);
reader.setTrustedPackages(Arrays.asList(filter.split(",")));
Object result = reader.readObject();
assertNotNull(result);
assertEquals(value.getClass(), result.getClass());
if (result.getClass().isArray()) {
assertTrue(Arrays.deepEquals((Object[]) value, (Object[]) result));
} else {
assertEquals(value, result);
}
}
}
private byte[] serializeObject(Object value) throws IOException {
byte[] result = new byte[0];
if (value != null) {
try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos)) {
oos.writeObject(value);
oos.flush();
oos.close();
result = baos.toByteArray();
}
}
return result;
}
private void doTestReadPrimitive(Object value, String filter) throws Exception {
byte[] serialized = serializePrimitive(value);
try (ByteArrayInputStream input = new ByteArrayInputStream(serialized);
ClassLoadingAwareObjectInputStream reader = new ClassLoadingAwareObjectInputStream(input)) {
reader.setTrustAllPackages(false);
reader.setTrustedPackages(Arrays.asList(filter.split(",")));
Object result = null;
if (value instanceof Byte) {
result = reader.readByte();
} else if (value instanceof Short) {
result = reader.readShort();
} else if (value instanceof Integer) {
result = reader.readInt();
} else if (value instanceof Long) {
result = reader.readLong();
} else if (value instanceof Float) {
result = reader.readFloat();
} else if (value instanceof Double) {
result = reader.readDouble();
} else if (value instanceof Boolean) {
result = reader.readBoolean();
} else if (value instanceof Character) {
result = reader.readChar();
} else {
throw new IllegalArgumentException("unsuitable type for primitive deserialization");
}
assertNotNull(result);
assertEquals(value.getClass(), result.getClass());
assertEquals(value, result);
}
}
private void doTestReadPrimitiveArray(Object value, String filter) throws Exception {
byte[] serialized = serializeObject(value);
try (ByteArrayInputStream input = new ByteArrayInputStream(serialized);
ClassLoadingAwareObjectInputStream reader = new ClassLoadingAwareObjectInputStream(input)) {
reader.setTrustAllPackages(false);
reader.setTrustedPackages(Arrays.asList(filter.split(",")));
Object result = reader.readObject();
assertNotNull(result);
assertEquals(value.getClass(), result.getClass());
assertTrue(result.getClass().isArray());
assertEquals(value.getClass().getComponentType(), result.getClass().getComponentType());
assertTrue(result.getClass().getComponentType().isPrimitive());
}
}
private byte[] serializePrimitive(Object value) throws IOException {
byte[] result = new byte[0];
if (value != null) {
try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos)) {
if (value instanceof Byte) {
oos.writeByte((byte) value);
} else if (value instanceof Short) {
oos.writeShort((short) value);
} else if (value instanceof Integer) {
oos.writeInt((int) value);
} else if (value instanceof Long) {
oos.writeLong((long) value);
} else if (value instanceof Float) {
oos.writeFloat((float) value);
} else if (value instanceof Double) {
oos.writeDouble((double) value);
} else if (value instanceof Boolean) {
oos.writeBoolean((boolean) value);
} else if (value instanceof Character) {
oos.writeChar((char) value);
} else {
throw new IllegalArgumentException("unsuitable type for primitive serialization");
}
oos.flush();
oos.close();
result = baos.toByteArray();
}
}
return result;
}
}

View File

@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.util;
import java.io.Serializable;
public class LocalSimplePojoParent implements Serializable {
private static final long serialVersionUID = 1L;
private SimplePojo payload;
public LocalSimplePojoParent(Object simplePojoPayload) {
// Create an LOCAL simple payload, itself serializable, like we
// have to be since the object references us and is used
// during the serialization.
class LocalSimplPojo extends SimplePojo {
private static final long serialVersionUID = 1L;
LocalSimplPojo(Object simplePojoPayload) {
super(simplePojoPayload);
}
}
payload = new LocalSimplPojo(simplePojoPayload);
}
public SimplePojo getPayload() {
return payload;
}
}

View File

@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.util;
import java.io.Serializable;
public class SimplePojo implements Serializable {
private static final long serialVersionUID = 3258560248864895099L;
private Object payload;
public SimplePojo() {
}
public SimplePojo(Object payload) {
this.payload = payload;
}
public Object getPayload() {
return payload;
}
public void setPayload(Object payload) {
this.payload = payload;
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + ((payload == null) ? 0 : payload.hashCode());
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
SimplePojo other = (SimplePojo) obj;
if (payload == null) {
if (other.payload != null) {
return false;
}
} else if (!payload.equals(other.payload)) {
return false;
}
return true;
}
}