draft 1 (#15222)
This commit is contained in:
parent
336ee922eb
commit
0b229a9af4
|
@ -0,0 +1,47 @@
|
|||
package com.baeldung.deserializationfilters;
|
||||
|
||||
import java.io.ObjectInputFilter;
|
||||
import java.util.function.BinaryOperator;
|
||||
|
||||
import com.baeldung.deserializationfilters.service.DeserializationService;
|
||||
import com.baeldung.deserializationfilters.service.LimitedArrayService;
|
||||
import com.baeldung.deserializationfilters.service.LowDepthService;
|
||||
import com.baeldung.deserializationfilters.service.SmallObjectService;
|
||||
import com.baeldung.deserializationfilters.utils.FilterUtils;
|
||||
|
||||
public class ContextSpecificDeserializationFilterFactory implements BinaryOperator<ObjectInputFilter> {
|
||||
|
||||
@Override
|
||||
public ObjectInputFilter apply(ObjectInputFilter current, ObjectInputFilter next) {
|
||||
if (current == null) {
|
||||
Class<?> caller = findInStack(DeserializationService.class);
|
||||
|
||||
if (caller == null) {
|
||||
current = FilterUtils.fallbackFilter();
|
||||
} else if (caller.equals(SmallObjectService.class)) {
|
||||
current = FilterUtils.safeSizeFilter(190);
|
||||
} else if (caller.equals(LowDepthService.class)) {
|
||||
current = FilterUtils.safeDepthFilter(2);
|
||||
} else if (caller.equals(LimitedArrayService.class)) {
|
||||
current = FilterUtils.safeArrayFilter(3);
|
||||
}
|
||||
}
|
||||
|
||||
return ObjectInputFilter.merge(current, next);
|
||||
}
|
||||
|
||||
private static Class<?> findInStack(Class<?> superType) {
|
||||
for (StackTraceElement element : Thread.currentThread()
|
||||
.getStackTrace()) {
|
||||
try {
|
||||
Class<?> subType = Class.forName(element.getClassName());
|
||||
if (superType.isAssignableFrom(subType)) {
|
||||
return subType;
|
||||
}
|
||||
} catch (ClassNotFoundException e) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
package com.baeldung.deserializationfilters.pojo;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
public interface ContextSpecific extends Serializable {
|
||||
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
package com.baeldung.deserializationfilters.pojo;
|
||||
|
||||
public class NestedSample implements ContextSpecific {
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private Sample optional;
|
||||
|
||||
public NestedSample(Sample optional) {
|
||||
this.optional = optional;
|
||||
}
|
||||
|
||||
public Sample getOptional() {
|
||||
return optional;
|
||||
}
|
||||
|
||||
public void setOptional(Sample optional) {
|
||||
this.optional = optional;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
package com.baeldung.deserializationfilters.pojo;
|
||||
|
||||
public class Sample implements ContextSpecific, Comparable<Sample> {
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private int[] array;
|
||||
private String name;
|
||||
private NestedSample nested;
|
||||
|
||||
public Sample(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
public Sample(int[] array) {
|
||||
this.array = array;
|
||||
}
|
||||
|
||||
public Sample(NestedSample nested) {
|
||||
this.nested = nested;
|
||||
}
|
||||
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
public void setName(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
public int[] getArray() {
|
||||
return array;
|
||||
}
|
||||
|
||||
public void setArray(int[] array) {
|
||||
this.array = array;
|
||||
}
|
||||
|
||||
public NestedSample getNested() {
|
||||
return nested;
|
||||
}
|
||||
|
||||
public void setNested(NestedSample nested) {
|
||||
this.nested = nested;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return name;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int compareTo(Sample o) {
|
||||
if (name == null)
|
||||
return -1;
|
||||
|
||||
if (o == null || o.getName() == null)
|
||||
return 1;
|
||||
|
||||
return getName().compareTo(o.getName());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
package com.baeldung.deserializationfilters.pojo;
|
||||
|
||||
public class SampleExploit extends Sample {
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
public SampleExploit() {
|
||||
super("exploit");
|
||||
}
|
||||
|
||||
public static void maliciousCode() {
|
||||
System.out.println("exploit executed");
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
maliciousCode();
|
||||
return "exploit";
|
||||
}
|
||||
|
||||
@Override
|
||||
public int compareTo(Sample o) {
|
||||
maliciousCode();
|
||||
return super.compareTo(o);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
package com.baeldung.deserializationfilters.service;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.util.Set;
|
||||
|
||||
import com.baeldung.deserializationfilters.pojo.ContextSpecific;
|
||||
|
||||
public interface DeserializationService {
|
||||
|
||||
Set<ContextSpecific> process(InputStream... inputStreams);
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
package com.baeldung.deserializationfilters.service;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.util.Set;
|
||||
|
||||
import com.baeldung.deserializationfilters.pojo.ContextSpecific;
|
||||
import com.baeldung.deserializationfilters.utils.DeserializationUtils;
|
||||
|
||||
public class LimitedArrayService implements DeserializationService {
|
||||
|
||||
@Override
|
||||
public Set<ContextSpecific> process(InputStream... inputStreams) {
|
||||
return DeserializationUtils.deserializeIntoSet(inputStreams);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
package com.baeldung.deserializationfilters.service;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.io.ObjectInputFilter;
|
||||
import java.util.Set;
|
||||
|
||||
import com.baeldung.deserializationfilters.pojo.ContextSpecific;
|
||||
import com.baeldung.deserializationfilters.utils.DeserializationUtils;
|
||||
|
||||
public class LowDepthService implements DeserializationService {
|
||||
|
||||
public Set<ContextSpecific> process(ObjectInputFilter filter, InputStream... inputStreams) {
|
||||
return DeserializationUtils.deserializeIntoSet(filter, inputStreams);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<ContextSpecific> process(InputStream... inputStreams) {
|
||||
return process(null, inputStreams);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
package com.baeldung.deserializationfilters.service;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.util.Set;
|
||||
|
||||
import com.baeldung.deserializationfilters.pojo.ContextSpecific;
|
||||
import com.baeldung.deserializationfilters.utils.DeserializationUtils;
|
||||
|
||||
public class SmallObjectService implements DeserializationService {
|
||||
|
||||
@Override
|
||||
public Set<ContextSpecific> process(InputStream... inputStreams) {
|
||||
return DeserializationUtils.deserializeIntoSet(inputStreams);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,50 @@
|
|||
package com.baeldung.deserializationfilters.utils;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.io.InvalidClassException;
|
||||
import java.io.ObjectInputFilter;
|
||||
import java.io.ObjectInputStream;
|
||||
import java.util.Set;
|
||||
import java.util.TreeSet;
|
||||
|
||||
import com.baeldung.deserializationfilters.pojo.ContextSpecific;
|
||||
|
||||
public class DeserializationUtils {
|
||||
private DeserializationUtils() {
|
||||
}
|
||||
|
||||
public static Object deserialize(InputStream inStream) {
|
||||
return deserialize(inStream, null);
|
||||
}
|
||||
|
||||
public static Object deserialize(InputStream inStream, ObjectInputFilter filter) {
|
||||
try (ObjectInputStream in = new ObjectInputStream(inStream)) {
|
||||
if (filter != null) {
|
||||
in.setObjectInputFilter(filter);
|
||||
}
|
||||
return in.readObject();
|
||||
} catch (InvalidClassException e) {
|
||||
return null;
|
||||
} catch (Throwable e) {
|
||||
e.printStackTrace();
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public static Set<ContextSpecific> deserializeIntoSet(InputStream... inputStreams) {
|
||||
return deserializeIntoSet(null, inputStreams);
|
||||
}
|
||||
|
||||
public static Set<ContextSpecific> deserializeIntoSet(ObjectInputFilter filter, InputStream... inputStreams) {
|
||||
Set<ContextSpecific> set = new TreeSet<>();
|
||||
|
||||
for (InputStream inputStream : inputStreams) {
|
||||
Object object = deserialize(inputStream, filter);
|
||||
if (object != null) {
|
||||
set.add((ContextSpecific) object);
|
||||
}
|
||||
}
|
||||
|
||||
return set;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
package com.baeldung.deserializationfilters.utils;
|
||||
|
||||
import java.io.ObjectInputFilter;
|
||||
|
||||
public class FilterUtils {
|
||||
|
||||
private static final String DEFAULT_PACKAGE_PATTERN = "java.base/*;!*";
|
||||
private static final String POJO_PACKAGE = "com.baeldung.deserializationfilters.pojo";
|
||||
|
||||
private FilterUtils() {
|
||||
}
|
||||
|
||||
private static ObjectInputFilter baseFilter(String parameter, int max) {
|
||||
return ObjectInputFilter.Config.createFilter(String.format("%s=%d;%s.**;%s", parameter, max, POJO_PACKAGE, DEFAULT_PACKAGE_PATTERN));
|
||||
}
|
||||
|
||||
public static ObjectInputFilter fallbackFilter() {
|
||||
return ObjectInputFilter.Config.createFilter(String.format("%s", DEFAULT_PACKAGE_PATTERN));
|
||||
}
|
||||
|
||||
public static ObjectInputFilter safeSizeFilter(int max) {
|
||||
return baseFilter("maxbytes", max);
|
||||
}
|
||||
|
||||
public static ObjectInputFilter safeArrayFilter(int max) {
|
||||
return baseFilter("maxarray", max);
|
||||
}
|
||||
|
||||
public static ObjectInputFilter safeDepthFilter(int max) {
|
||||
return baseFilter("maxdepth", max);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
package com.baeldung.deserializationfilters.utils;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectOutputStream;
|
||||
import java.io.OutputStream;
|
||||
|
||||
public class SerializationUtils {
|
||||
|
||||
private SerializationUtils() {
|
||||
}
|
||||
|
||||
public static void serialize(Object object, OutputStream outStream) throws IOException {
|
||||
try (ObjectOutputStream objStream = new ObjectOutputStream(outStream)) {
|
||||
objStream.writeObject(object);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,119 @@
|
|||
package com.baeldung.deserializationfilters;
|
||||
|
||||
import static org.junit.Assert.assertNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectInputFilter;
|
||||
import java.util.Set;
|
||||
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import com.baeldung.deserializationfilters.pojo.ContextSpecific;
|
||||
import com.baeldung.deserializationfilters.pojo.NestedSample;
|
||||
import com.baeldung.deserializationfilters.pojo.Sample;
|
||||
import com.baeldung.deserializationfilters.pojo.SampleExploit;
|
||||
import com.baeldung.deserializationfilters.service.LimitedArrayService;
|
||||
import com.baeldung.deserializationfilters.service.LowDepthService;
|
||||
import com.baeldung.deserializationfilters.service.SmallObjectService;
|
||||
import com.baeldung.deserializationfilters.utils.DeserializationUtils;
|
||||
import com.baeldung.deserializationfilters.utils.FilterUtils;
|
||||
import com.baeldung.deserializationfilters.utils.SerializationUtils;
|
||||
|
||||
public class ContextSpecificDeserializationFilterIntegrationTest {
|
||||
|
||||
private static ByteArrayOutputStream serialSampleA = new ByteArrayOutputStream();
|
||||
private static ByteArrayOutputStream serialBigSampleA = new ByteArrayOutputStream();
|
||||
|
||||
private static ByteArrayOutputStream serialSampleB = new ByteArrayOutputStream();
|
||||
private static ByteArrayOutputStream serialBigSampleB = new ByteArrayOutputStream();
|
||||
|
||||
private static ByteArrayOutputStream serialSampleC = new ByteArrayOutputStream();
|
||||
private static ByteArrayOutputStream serialBigSampleC = new ByteArrayOutputStream();
|
||||
|
||||
private static ByteArrayInputStream bytes(ByteArrayOutputStream stream) {
|
||||
return new ByteArrayInputStream(stream.toByteArray());
|
||||
}
|
||||
|
||||
@BeforeAll
|
||||
static void setup() throws IOException {
|
||||
ObjectInputFilter.Config.setSerialFilterFactory(new ContextSpecificDeserializationFilterFactory());
|
||||
|
||||
SerializationUtils.serialize(new Sample("simple"), serialSampleA);
|
||||
SerializationUtils.serialize(new SampleExploit(), serialBigSampleA);
|
||||
|
||||
SerializationUtils.serialize(new Sample(new int[] { 1, 2, 3 }), serialSampleB);
|
||||
SerializationUtils.serialize(new Sample(new int[] { 1, 2, 3, 4, 5, 6 }), serialBigSampleB);
|
||||
|
||||
SerializationUtils.serialize(new Sample(new NestedSample(null)), serialSampleC);
|
||||
SerializationUtils.serialize(new Sample(new NestedSample(new Sample("deep"))), serialBigSampleC);
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenSmallObjectContext_thenCorrectFilterApplied() {
|
||||
Set<ContextSpecific> result = new SmallObjectService().process( //
|
||||
bytes(serialSampleA), //
|
||||
bytes(serialBigSampleA));
|
||||
|
||||
assertEquals(1, result.size());
|
||||
assertEquals("simple", ((Sample) result.iterator()
|
||||
.next()).getName());
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenLimitedArrayContext_thenCorrectFilterApplied() {
|
||||
Set<ContextSpecific> result = new LimitedArrayService().process( //
|
||||
bytes(serialSampleB), //
|
||||
bytes(serialBigSampleB));
|
||||
|
||||
assertEquals(1, result.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenLowDepthContext_thenCorrectFilterApplied() {
|
||||
Set<ContextSpecific> result = new LowDepthService().process( //
|
||||
bytes(serialSampleC), //
|
||||
bytes(serialBigSampleC));
|
||||
|
||||
assertEquals(1, result.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
void givenExtraFilter_whenCombinedContext_thenMergedFiltersApplied() {
|
||||
Set<ContextSpecific> result = new LowDepthService().process( //
|
||||
FilterUtils.safeSizeFilter(190), //
|
||||
bytes(serialSampleA), //
|
||||
bytes(serialBigSampleA), //
|
||||
bytes(serialSampleC), //
|
||||
bytes(serialBigSampleC));
|
||||
|
||||
assertEquals(1, result.size());
|
||||
assertEquals("simple", ((Sample) result.iterator()
|
||||
.next()).getName());
|
||||
}
|
||||
|
||||
@Test
|
||||
void givenFallbackContext_whenUsingBaseClasses_thenRestrictiveFilterApplied() throws IOException {
|
||||
String a = new String("a");
|
||||
ByteArrayOutputStream outStream = new ByteArrayOutputStream();
|
||||
SerializationUtils.serialize(a, outStream);
|
||||
|
||||
String deserializedA = (String) DeserializationUtils.deserialize(bytes(outStream));
|
||||
|
||||
assertEquals(a, deserializedA);
|
||||
}
|
||||
|
||||
@Test
|
||||
void givenFallbackContext_whenUsingAppClasses_thenRejected() throws IOException {
|
||||
Sample a = new Sample("a");
|
||||
ByteArrayOutputStream outStream = new ByteArrayOutputStream();
|
||||
SerializationUtils.serialize(a, outStream);
|
||||
|
||||
Sample deserializedA = (Sample) DeserializationUtils.deserialize(bytes(outStream));
|
||||
|
||||
assertNull(deserializedA);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue