draft 1 (#15222)
This commit is contained in:
parent
336ee922eb
commit
0b229a9af4
|
@ -0,0 +1,47 @@
|
||||||
|
package com.baeldung.deserializationfilters;
|
||||||
|
|
||||||
|
import java.io.ObjectInputFilter;
|
||||||
|
import java.util.function.BinaryOperator;
|
||||||
|
|
||||||
|
import com.baeldung.deserializationfilters.service.DeserializationService;
|
||||||
|
import com.baeldung.deserializationfilters.service.LimitedArrayService;
|
||||||
|
import com.baeldung.deserializationfilters.service.LowDepthService;
|
||||||
|
import com.baeldung.deserializationfilters.service.SmallObjectService;
|
||||||
|
import com.baeldung.deserializationfilters.utils.FilterUtils;
|
||||||
|
|
||||||
|
public class ContextSpecificDeserializationFilterFactory implements BinaryOperator<ObjectInputFilter> {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ObjectInputFilter apply(ObjectInputFilter current, ObjectInputFilter next) {
|
||||||
|
if (current == null) {
|
||||||
|
Class<?> caller = findInStack(DeserializationService.class);
|
||||||
|
|
||||||
|
if (caller == null) {
|
||||||
|
current = FilterUtils.fallbackFilter();
|
||||||
|
} else if (caller.equals(SmallObjectService.class)) {
|
||||||
|
current = FilterUtils.safeSizeFilter(190);
|
||||||
|
} else if (caller.equals(LowDepthService.class)) {
|
||||||
|
current = FilterUtils.safeDepthFilter(2);
|
||||||
|
} else if (caller.equals(LimitedArrayService.class)) {
|
||||||
|
current = FilterUtils.safeArrayFilter(3);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ObjectInputFilter.merge(current, next);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Class<?> findInStack(Class<?> superType) {
|
||||||
|
for (StackTraceElement element : Thread.currentThread()
|
||||||
|
.getStackTrace()) {
|
||||||
|
try {
|
||||||
|
Class<?> subType = Class.forName(element.getClassName());
|
||||||
|
if (superType.isAssignableFrom(subType)) {
|
||||||
|
return subType;
|
||||||
|
}
|
||||||
|
} catch (ClassNotFoundException e) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,7 @@
|
||||||
|
package com.baeldung.deserializationfilters.pojo;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
|
public interface ContextSpecific extends Serializable {
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,19 @@
|
||||||
|
package com.baeldung.deserializationfilters.pojo;
|
||||||
|
|
||||||
|
public class NestedSample implements ContextSpecific {
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
private Sample optional;
|
||||||
|
|
||||||
|
public NestedSample(Sample optional) {
|
||||||
|
this.optional = optional;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Sample getOptional() {
|
||||||
|
return optional;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setOptional(Sample optional) {
|
||||||
|
this.optional = optional;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,61 @@
|
||||||
|
package com.baeldung.deserializationfilters.pojo;
|
||||||
|
|
||||||
|
public class Sample implements ContextSpecific, Comparable<Sample> {
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
private int[] array;
|
||||||
|
private String name;
|
||||||
|
private NestedSample nested;
|
||||||
|
|
||||||
|
public Sample(String name) {
|
||||||
|
this.name = name;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Sample(int[] array) {
|
||||||
|
this.array = array;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Sample(NestedSample nested) {
|
||||||
|
this.nested = nested;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getName() {
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setName(String name) {
|
||||||
|
this.name = name;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int[] getArray() {
|
||||||
|
return array;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setArray(int[] array) {
|
||||||
|
this.array = array;
|
||||||
|
}
|
||||||
|
|
||||||
|
public NestedSample getNested() {
|
||||||
|
return nested;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setNested(NestedSample nested) {
|
||||||
|
this.nested = nested;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int compareTo(Sample o) {
|
||||||
|
if (name == null)
|
||||||
|
return -1;
|
||||||
|
|
||||||
|
if (o == null || o.getName() == null)
|
||||||
|
return 1;
|
||||||
|
|
||||||
|
return getName().compareTo(o.getName());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
package com.baeldung.deserializationfilters.pojo;
|
||||||
|
|
||||||
|
public class SampleExploit extends Sample {
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
public SampleExploit() {
|
||||||
|
super("exploit");
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void maliciousCode() {
|
||||||
|
System.out.println("exploit executed");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
maliciousCode();
|
||||||
|
return "exploit";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int compareTo(Sample o) {
|
||||||
|
maliciousCode();
|
||||||
|
return super.compareTo(o);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,11 @@
|
||||||
|
package com.baeldung.deserializationfilters.service;
|
||||||
|
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
import com.baeldung.deserializationfilters.pojo.ContextSpecific;
|
||||||
|
|
||||||
|
public interface DeserializationService {
|
||||||
|
|
||||||
|
Set<ContextSpecific> process(InputStream... inputStreams);
|
||||||
|
}
|
|
@ -0,0 +1,15 @@
|
||||||
|
package com.baeldung.deserializationfilters.service;
|
||||||
|
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
import com.baeldung.deserializationfilters.pojo.ContextSpecific;
|
||||||
|
import com.baeldung.deserializationfilters.utils.DeserializationUtils;
|
||||||
|
|
||||||
|
public class LimitedArrayService implements DeserializationService {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Set<ContextSpecific> process(InputStream... inputStreams) {
|
||||||
|
return DeserializationUtils.deserializeIntoSet(inputStreams);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,20 @@
|
||||||
|
package com.baeldung.deserializationfilters.service;
|
||||||
|
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.io.ObjectInputFilter;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
import com.baeldung.deserializationfilters.pojo.ContextSpecific;
|
||||||
|
import com.baeldung.deserializationfilters.utils.DeserializationUtils;
|
||||||
|
|
||||||
|
public class LowDepthService implements DeserializationService {
|
||||||
|
|
||||||
|
public Set<ContextSpecific> process(ObjectInputFilter filter, InputStream... inputStreams) {
|
||||||
|
return DeserializationUtils.deserializeIntoSet(filter, inputStreams);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Set<ContextSpecific> process(InputStream... inputStreams) {
|
||||||
|
return process(null, inputStreams);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,15 @@
|
||||||
|
package com.baeldung.deserializationfilters.service;
|
||||||
|
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
import com.baeldung.deserializationfilters.pojo.ContextSpecific;
|
||||||
|
import com.baeldung.deserializationfilters.utils.DeserializationUtils;
|
||||||
|
|
||||||
|
public class SmallObjectService implements DeserializationService {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Set<ContextSpecific> process(InputStream... inputStreams) {
|
||||||
|
return DeserializationUtils.deserializeIntoSet(inputStreams);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,50 @@
|
||||||
|
package com.baeldung.deserializationfilters.utils;
|
||||||
|
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.io.InvalidClassException;
|
||||||
|
import java.io.ObjectInputFilter;
|
||||||
|
import java.io.ObjectInputStream;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.TreeSet;
|
||||||
|
|
||||||
|
import com.baeldung.deserializationfilters.pojo.ContextSpecific;
|
||||||
|
|
||||||
|
public class DeserializationUtils {
|
||||||
|
private DeserializationUtils() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Object deserialize(InputStream inStream) {
|
||||||
|
return deserialize(inStream, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Object deserialize(InputStream inStream, ObjectInputFilter filter) {
|
||||||
|
try (ObjectInputStream in = new ObjectInputStream(inStream)) {
|
||||||
|
if (filter != null) {
|
||||||
|
in.setObjectInputFilter(filter);
|
||||||
|
}
|
||||||
|
return in.readObject();
|
||||||
|
} catch (InvalidClassException e) {
|
||||||
|
return null;
|
||||||
|
} catch (Throwable e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Set<ContextSpecific> deserializeIntoSet(InputStream... inputStreams) {
|
||||||
|
return deserializeIntoSet(null, inputStreams);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Set<ContextSpecific> deserializeIntoSet(ObjectInputFilter filter, InputStream... inputStreams) {
|
||||||
|
Set<ContextSpecific> set = new TreeSet<>();
|
||||||
|
|
||||||
|
for (InputStream inputStream : inputStreams) {
|
||||||
|
Object object = deserialize(inputStream, filter);
|
||||||
|
if (object != null) {
|
||||||
|
set.add((ContextSpecific) object);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return set;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,32 @@
|
||||||
|
package com.baeldung.deserializationfilters.utils;
|
||||||
|
|
||||||
|
import java.io.ObjectInputFilter;
|
||||||
|
|
||||||
|
public class FilterUtils {
|
||||||
|
|
||||||
|
private static final String DEFAULT_PACKAGE_PATTERN = "java.base/*;!*";
|
||||||
|
private static final String POJO_PACKAGE = "com.baeldung.deserializationfilters.pojo";
|
||||||
|
|
||||||
|
private FilterUtils() {
|
||||||
|
}
|
||||||
|
|
||||||
|
private static ObjectInputFilter baseFilter(String parameter, int max) {
|
||||||
|
return ObjectInputFilter.Config.createFilter(String.format("%s=%d;%s.**;%s", parameter, max, POJO_PACKAGE, DEFAULT_PACKAGE_PATTERN));
|
||||||
|
}
|
||||||
|
|
||||||
|
public static ObjectInputFilter fallbackFilter() {
|
||||||
|
return ObjectInputFilter.Config.createFilter(String.format("%s", DEFAULT_PACKAGE_PATTERN));
|
||||||
|
}
|
||||||
|
|
||||||
|
public static ObjectInputFilter safeSizeFilter(int max) {
|
||||||
|
return baseFilter("maxbytes", max);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static ObjectInputFilter safeArrayFilter(int max) {
|
||||||
|
return baseFilter("maxarray", max);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static ObjectInputFilter safeDepthFilter(int max) {
|
||||||
|
return baseFilter("maxdepth", max);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,17 @@
|
||||||
|
package com.baeldung.deserializationfilters.utils;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.ObjectOutputStream;
|
||||||
|
import java.io.OutputStream;
|
||||||
|
|
||||||
|
public class SerializationUtils {
|
||||||
|
|
||||||
|
private SerializationUtils() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void serialize(Object object, OutputStream outStream) throws IOException {
|
||||||
|
try (ObjectOutputStream objStream = new ObjectOutputStream(outStream)) {
|
||||||
|
objStream.writeObject(object);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,119 @@
|
||||||
|
package com.baeldung.deserializationfilters;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertNull;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
import java.io.ByteArrayInputStream;
|
||||||
|
import java.io.ByteArrayOutputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.ObjectInputFilter;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import com.baeldung.deserializationfilters.pojo.ContextSpecific;
|
||||||
|
import com.baeldung.deserializationfilters.pojo.NestedSample;
|
||||||
|
import com.baeldung.deserializationfilters.pojo.Sample;
|
||||||
|
import com.baeldung.deserializationfilters.pojo.SampleExploit;
|
||||||
|
import com.baeldung.deserializationfilters.service.LimitedArrayService;
|
||||||
|
import com.baeldung.deserializationfilters.service.LowDepthService;
|
||||||
|
import com.baeldung.deserializationfilters.service.SmallObjectService;
|
||||||
|
import com.baeldung.deserializationfilters.utils.DeserializationUtils;
|
||||||
|
import com.baeldung.deserializationfilters.utils.FilterUtils;
|
||||||
|
import com.baeldung.deserializationfilters.utils.SerializationUtils;
|
||||||
|
|
||||||
|
public class ContextSpecificDeserializationFilterIntegrationTest {
|
||||||
|
|
||||||
|
private static ByteArrayOutputStream serialSampleA = new ByteArrayOutputStream();
|
||||||
|
private static ByteArrayOutputStream serialBigSampleA = new ByteArrayOutputStream();
|
||||||
|
|
||||||
|
private static ByteArrayOutputStream serialSampleB = new ByteArrayOutputStream();
|
||||||
|
private static ByteArrayOutputStream serialBigSampleB = new ByteArrayOutputStream();
|
||||||
|
|
||||||
|
private static ByteArrayOutputStream serialSampleC = new ByteArrayOutputStream();
|
||||||
|
private static ByteArrayOutputStream serialBigSampleC = new ByteArrayOutputStream();
|
||||||
|
|
||||||
|
private static ByteArrayInputStream bytes(ByteArrayOutputStream stream) {
|
||||||
|
return new ByteArrayInputStream(stream.toByteArray());
|
||||||
|
}
|
||||||
|
|
||||||
|
@BeforeAll
|
||||||
|
static void setup() throws IOException {
|
||||||
|
ObjectInputFilter.Config.setSerialFilterFactory(new ContextSpecificDeserializationFilterFactory());
|
||||||
|
|
||||||
|
SerializationUtils.serialize(new Sample("simple"), serialSampleA);
|
||||||
|
SerializationUtils.serialize(new SampleExploit(), serialBigSampleA);
|
||||||
|
|
||||||
|
SerializationUtils.serialize(new Sample(new int[] { 1, 2, 3 }), serialSampleB);
|
||||||
|
SerializationUtils.serialize(new Sample(new int[] { 1, 2, 3, 4, 5, 6 }), serialBigSampleB);
|
||||||
|
|
||||||
|
SerializationUtils.serialize(new Sample(new NestedSample(null)), serialSampleC);
|
||||||
|
SerializationUtils.serialize(new Sample(new NestedSample(new Sample("deep"))), serialBigSampleC);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void whenSmallObjectContext_thenCorrectFilterApplied() {
|
||||||
|
Set<ContextSpecific> result = new SmallObjectService().process( //
|
||||||
|
bytes(serialSampleA), //
|
||||||
|
bytes(serialBigSampleA));
|
||||||
|
|
||||||
|
assertEquals(1, result.size());
|
||||||
|
assertEquals("simple", ((Sample) result.iterator()
|
||||||
|
.next()).getName());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void whenLimitedArrayContext_thenCorrectFilterApplied() {
|
||||||
|
Set<ContextSpecific> result = new LimitedArrayService().process( //
|
||||||
|
bytes(serialSampleB), //
|
||||||
|
bytes(serialBigSampleB));
|
||||||
|
|
||||||
|
assertEquals(1, result.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void whenLowDepthContext_thenCorrectFilterApplied() {
|
||||||
|
Set<ContextSpecific> result = new LowDepthService().process( //
|
||||||
|
bytes(serialSampleC), //
|
||||||
|
bytes(serialBigSampleC));
|
||||||
|
|
||||||
|
assertEquals(1, result.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void givenExtraFilter_whenCombinedContext_thenMergedFiltersApplied() {
|
||||||
|
Set<ContextSpecific> result = new LowDepthService().process( //
|
||||||
|
FilterUtils.safeSizeFilter(190), //
|
||||||
|
bytes(serialSampleA), //
|
||||||
|
bytes(serialBigSampleA), //
|
||||||
|
bytes(serialSampleC), //
|
||||||
|
bytes(serialBigSampleC));
|
||||||
|
|
||||||
|
assertEquals(1, result.size());
|
||||||
|
assertEquals("simple", ((Sample) result.iterator()
|
||||||
|
.next()).getName());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void givenFallbackContext_whenUsingBaseClasses_thenRestrictiveFilterApplied() throws IOException {
|
||||||
|
String a = new String("a");
|
||||||
|
ByteArrayOutputStream outStream = new ByteArrayOutputStream();
|
||||||
|
SerializationUtils.serialize(a, outStream);
|
||||||
|
|
||||||
|
String deserializedA = (String) DeserializationUtils.deserialize(bytes(outStream));
|
||||||
|
|
||||||
|
assertEquals(a, deserializedA);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void givenFallbackContext_whenUsingAppClasses_thenRejected() throws IOException {
|
||||||
|
Sample a = new Sample("a");
|
||||||
|
ByteArrayOutputStream outStream = new ByteArrayOutputStream();
|
||||||
|
SerializationUtils.serialize(a, outStream);
|
||||||
|
|
||||||
|
Sample deserializedA = (Sample) DeserializationUtils.deserialize(bytes(outStream));
|
||||||
|
|
||||||
|
assertNull(deserializedA);
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue