This commit is contained in:
Ulisses Lima 2023-11-16 18:35:11 -03:00 committed by GitHub
parent 336ee922eb
commit 0b229a9af4
13 changed files with 438 additions and 0 deletions

View File

@ -0,0 +1,47 @@
package com.baeldung.deserializationfilters;
import java.io.ObjectInputFilter;
import java.util.function.BinaryOperator;
import com.baeldung.deserializationfilters.service.DeserializationService;
import com.baeldung.deserializationfilters.service.LimitedArrayService;
import com.baeldung.deserializationfilters.service.LowDepthService;
import com.baeldung.deserializationfilters.service.SmallObjectService;
import com.baeldung.deserializationfilters.utils.FilterUtils;
public class ContextSpecificDeserializationFilterFactory implements BinaryOperator<ObjectInputFilter> {
@Override
public ObjectInputFilter apply(ObjectInputFilter current, ObjectInputFilter next) {
if (current == null) {
Class<?> caller = findInStack(DeserializationService.class);
if (caller == null) {
current = FilterUtils.fallbackFilter();
} else if (caller.equals(SmallObjectService.class)) {
current = FilterUtils.safeSizeFilter(190);
} else if (caller.equals(LowDepthService.class)) {
current = FilterUtils.safeDepthFilter(2);
} else if (caller.equals(LimitedArrayService.class)) {
current = FilterUtils.safeArrayFilter(3);
}
}
return ObjectInputFilter.merge(current, next);
}
private static Class<?> findInStack(Class<?> superType) {
for (StackTraceElement element : Thread.currentThread()
.getStackTrace()) {
try {
Class<?> subType = Class.forName(element.getClassName());
if (superType.isAssignableFrom(subType)) {
return subType;
}
} catch (ClassNotFoundException e) {
return null;
}
}
return null;
}
}

View File

@ -0,0 +1,7 @@
package com.baeldung.deserializationfilters.pojo;
import java.io.Serializable;
public interface ContextSpecific extends Serializable {
}

View File

@ -0,0 +1,19 @@
package com.baeldung.deserializationfilters.pojo;
public class NestedSample implements ContextSpecific {
private static final long serialVersionUID = 1L;
private Sample optional;
public NestedSample(Sample optional) {
this.optional = optional;
}
public Sample getOptional() {
return optional;
}
public void setOptional(Sample optional) {
this.optional = optional;
}
}

View File

@ -0,0 +1,61 @@
package com.baeldung.deserializationfilters.pojo;
public class Sample implements ContextSpecific, Comparable<Sample> {
private static final long serialVersionUID = 1L;
private int[] array;
private String name;
private NestedSample nested;
public Sample(String name) {
this.name = name;
}
public Sample(int[] array) {
this.array = array;
}
public Sample(NestedSample nested) {
this.nested = nested;
}
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public int[] getArray() {
return array;
}
public void setArray(int[] array) {
this.array = array;
}
public NestedSample getNested() {
return nested;
}
public void setNested(NestedSample nested) {
this.nested = nested;
}
@Override
public String toString() {
return name;
}
@Override
public int compareTo(Sample o) {
if (name == null)
return -1;
if (o == null || o.getName() == null)
return 1;
return getName().compareTo(o.getName());
}
}

View File

@ -0,0 +1,25 @@
package com.baeldung.deserializationfilters.pojo;
public class SampleExploit extends Sample {
private static final long serialVersionUID = 1L;
public SampleExploit() {
super("exploit");
}
public static void maliciousCode() {
System.out.println("exploit executed");
}
@Override
public String toString() {
maliciousCode();
return "exploit";
}
@Override
public int compareTo(Sample o) {
maliciousCode();
return super.compareTo(o);
}
}

View File

@ -0,0 +1,11 @@
package com.baeldung.deserializationfilters.service;
import java.io.InputStream;
import java.util.Set;
import com.baeldung.deserializationfilters.pojo.ContextSpecific;
public interface DeserializationService {
Set<ContextSpecific> process(InputStream... inputStreams);
}

View File

@ -0,0 +1,15 @@
package com.baeldung.deserializationfilters.service;
import java.io.InputStream;
import java.util.Set;
import com.baeldung.deserializationfilters.pojo.ContextSpecific;
import com.baeldung.deserializationfilters.utils.DeserializationUtils;
public class LimitedArrayService implements DeserializationService {
@Override
public Set<ContextSpecific> process(InputStream... inputStreams) {
return DeserializationUtils.deserializeIntoSet(inputStreams);
}
}

View File

@ -0,0 +1,20 @@
package com.baeldung.deserializationfilters.service;
import java.io.InputStream;
import java.io.ObjectInputFilter;
import java.util.Set;
import com.baeldung.deserializationfilters.pojo.ContextSpecific;
import com.baeldung.deserializationfilters.utils.DeserializationUtils;
public class LowDepthService implements DeserializationService {
public Set<ContextSpecific> process(ObjectInputFilter filter, InputStream... inputStreams) {
return DeserializationUtils.deserializeIntoSet(filter, inputStreams);
}
@Override
public Set<ContextSpecific> process(InputStream... inputStreams) {
return process(null, inputStreams);
}
}

View File

@ -0,0 +1,15 @@
package com.baeldung.deserializationfilters.service;
import java.io.InputStream;
import java.util.Set;
import com.baeldung.deserializationfilters.pojo.ContextSpecific;
import com.baeldung.deserializationfilters.utils.DeserializationUtils;
public class SmallObjectService implements DeserializationService {
@Override
public Set<ContextSpecific> process(InputStream... inputStreams) {
return DeserializationUtils.deserializeIntoSet(inputStreams);
}
}

View File

@ -0,0 +1,50 @@
package com.baeldung.deserializationfilters.utils;
import java.io.InputStream;
import java.io.InvalidClassException;
import java.io.ObjectInputFilter;
import java.io.ObjectInputStream;
import java.util.Set;
import java.util.TreeSet;
import com.baeldung.deserializationfilters.pojo.ContextSpecific;
public class DeserializationUtils {
private DeserializationUtils() {
}
public static Object deserialize(InputStream inStream) {
return deserialize(inStream, null);
}
public static Object deserialize(InputStream inStream, ObjectInputFilter filter) {
try (ObjectInputStream in = new ObjectInputStream(inStream)) {
if (filter != null) {
in.setObjectInputFilter(filter);
}
return in.readObject();
} catch (InvalidClassException e) {
return null;
} catch (Throwable e) {
e.printStackTrace();
return null;
}
}
public static Set<ContextSpecific> deserializeIntoSet(InputStream... inputStreams) {
return deserializeIntoSet(null, inputStreams);
}
public static Set<ContextSpecific> deserializeIntoSet(ObjectInputFilter filter, InputStream... inputStreams) {
Set<ContextSpecific> set = new TreeSet<>();
for (InputStream inputStream : inputStreams) {
Object object = deserialize(inputStream, filter);
if (object != null) {
set.add((ContextSpecific) object);
}
}
return set;
}
}

View File

@ -0,0 +1,32 @@
package com.baeldung.deserializationfilters.utils;
import java.io.ObjectInputFilter;
public class FilterUtils {
private static final String DEFAULT_PACKAGE_PATTERN = "java.base/*;!*";
private static final String POJO_PACKAGE = "com.baeldung.deserializationfilters.pojo";
private FilterUtils() {
}
private static ObjectInputFilter baseFilter(String parameter, int max) {
return ObjectInputFilter.Config.createFilter(String.format("%s=%d;%s.**;%s", parameter, max, POJO_PACKAGE, DEFAULT_PACKAGE_PATTERN));
}
public static ObjectInputFilter fallbackFilter() {
return ObjectInputFilter.Config.createFilter(String.format("%s", DEFAULT_PACKAGE_PATTERN));
}
public static ObjectInputFilter safeSizeFilter(int max) {
return baseFilter("maxbytes", max);
}
public static ObjectInputFilter safeArrayFilter(int max) {
return baseFilter("maxarray", max);
}
public static ObjectInputFilter safeDepthFilter(int max) {
return baseFilter("maxdepth", max);
}
}

View File

@ -0,0 +1,17 @@
package com.baeldung.deserializationfilters.utils;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
public class SerializationUtils {
private SerializationUtils() {
}
public static void serialize(Object object, OutputStream outStream) throws IOException {
try (ObjectOutputStream objStream = new ObjectOutputStream(outStream)) {
objStream.writeObject(object);
}
}
}

View File

@ -0,0 +1,119 @@
package com.baeldung.deserializationfilters;
import static org.junit.Assert.assertNull;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputFilter;
import java.util.Set;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import com.baeldung.deserializationfilters.pojo.ContextSpecific;
import com.baeldung.deserializationfilters.pojo.NestedSample;
import com.baeldung.deserializationfilters.pojo.Sample;
import com.baeldung.deserializationfilters.pojo.SampleExploit;
import com.baeldung.deserializationfilters.service.LimitedArrayService;
import com.baeldung.deserializationfilters.service.LowDepthService;
import com.baeldung.deserializationfilters.service.SmallObjectService;
import com.baeldung.deserializationfilters.utils.DeserializationUtils;
import com.baeldung.deserializationfilters.utils.FilterUtils;
import com.baeldung.deserializationfilters.utils.SerializationUtils;
public class ContextSpecificDeserializationFilterIntegrationTest {
private static ByteArrayOutputStream serialSampleA = new ByteArrayOutputStream();
private static ByteArrayOutputStream serialBigSampleA = new ByteArrayOutputStream();
private static ByteArrayOutputStream serialSampleB = new ByteArrayOutputStream();
private static ByteArrayOutputStream serialBigSampleB = new ByteArrayOutputStream();
private static ByteArrayOutputStream serialSampleC = new ByteArrayOutputStream();
private static ByteArrayOutputStream serialBigSampleC = new ByteArrayOutputStream();
private static ByteArrayInputStream bytes(ByteArrayOutputStream stream) {
return new ByteArrayInputStream(stream.toByteArray());
}
@BeforeAll
static void setup() throws IOException {
ObjectInputFilter.Config.setSerialFilterFactory(new ContextSpecificDeserializationFilterFactory());
SerializationUtils.serialize(new Sample("simple"), serialSampleA);
SerializationUtils.serialize(new SampleExploit(), serialBigSampleA);
SerializationUtils.serialize(new Sample(new int[] { 1, 2, 3 }), serialSampleB);
SerializationUtils.serialize(new Sample(new int[] { 1, 2, 3, 4, 5, 6 }), serialBigSampleB);
SerializationUtils.serialize(new Sample(new NestedSample(null)), serialSampleC);
SerializationUtils.serialize(new Sample(new NestedSample(new Sample("deep"))), serialBigSampleC);
}
@Test
void whenSmallObjectContext_thenCorrectFilterApplied() {
Set<ContextSpecific> result = new SmallObjectService().process( //
bytes(serialSampleA), //
bytes(serialBigSampleA));
assertEquals(1, result.size());
assertEquals("simple", ((Sample) result.iterator()
.next()).getName());
}
@Test
void whenLimitedArrayContext_thenCorrectFilterApplied() {
Set<ContextSpecific> result = new LimitedArrayService().process( //
bytes(serialSampleB), //
bytes(serialBigSampleB));
assertEquals(1, result.size());
}
@Test
void whenLowDepthContext_thenCorrectFilterApplied() {
Set<ContextSpecific> result = new LowDepthService().process( //
bytes(serialSampleC), //
bytes(serialBigSampleC));
assertEquals(1, result.size());
}
@Test
void givenExtraFilter_whenCombinedContext_thenMergedFiltersApplied() {
Set<ContextSpecific> result = new LowDepthService().process( //
FilterUtils.safeSizeFilter(190), //
bytes(serialSampleA), //
bytes(serialBigSampleA), //
bytes(serialSampleC), //
bytes(serialBigSampleC));
assertEquals(1, result.size());
assertEquals("simple", ((Sample) result.iterator()
.next()).getName());
}
@Test
void givenFallbackContext_whenUsingBaseClasses_thenRestrictiveFilterApplied() throws IOException {
String a = new String("a");
ByteArrayOutputStream outStream = new ByteArrayOutputStream();
SerializationUtils.serialize(a, outStream);
String deserializedA = (String) DeserializationUtils.deserialize(bytes(outStream));
assertEquals(a, deserializedA);
}
@Test
void givenFallbackContext_whenUsingAppClasses_thenRejected() throws IOException {
Sample a = new Sample("a");
ByteArrayOutputStream outStream = new ByteArrayOutputStream();
SerializationUtils.serialize(a, outStream);
Sample deserializedA = (Sample) DeserializationUtils.deserialize(bytes(outStream));
assertNull(deserializedA);
}
}