diff --git a/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/ContextSpecificDeserializationFilterFactory.java b/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/ContextSpecificDeserializationFilterFactory.java new file mode 100644 index 0000000000..25a855487e --- /dev/null +++ b/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/ContextSpecificDeserializationFilterFactory.java @@ -0,0 +1,47 @@ +package com.baeldung.deserializationfilters; + +import java.io.ObjectInputFilter; +import java.util.function.BinaryOperator; + +import com.baeldung.deserializationfilters.service.DeserializationService; +import com.baeldung.deserializationfilters.service.LimitedArrayService; +import com.baeldung.deserializationfilters.service.LowDepthService; +import com.baeldung.deserializationfilters.service.SmallObjectService; +import com.baeldung.deserializationfilters.utils.FilterUtils; + +public class ContextSpecificDeserializationFilterFactory implements BinaryOperator { + + @Override + public ObjectInputFilter apply(ObjectInputFilter current, ObjectInputFilter next) { + if (current == null) { + Class caller = findInStack(DeserializationService.class); + + if (caller == null) { + current = FilterUtils.fallbackFilter(); + } else if (caller.equals(SmallObjectService.class)) { + current = FilterUtils.safeSizeFilter(190); + } else if (caller.equals(LowDepthService.class)) { + current = FilterUtils.safeDepthFilter(2); + } else if (caller.equals(LimitedArrayService.class)) { + current = FilterUtils.safeArrayFilter(3); + } + } + + return ObjectInputFilter.merge(current, next); + } + + private static Class findInStack(Class superType) { + for (StackTraceElement element : Thread.currentThread() + .getStackTrace()) { + try { + Class subType = Class.forName(element.getClassName()); + if (superType.isAssignableFrom(subType)) { + return subType; + } + } catch (ClassNotFoundException e) { + return null; + } + } + return null; + } +} diff --git a/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/pojo/ContextSpecific.java b/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/pojo/ContextSpecific.java new file mode 100644 index 0000000000..add827d280 --- /dev/null +++ b/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/pojo/ContextSpecific.java @@ -0,0 +1,7 @@ +package com.baeldung.deserializationfilters.pojo; + +import java.io.Serializable; + +public interface ContextSpecific extends Serializable { + +} diff --git a/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/pojo/NestedSample.java b/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/pojo/NestedSample.java new file mode 100644 index 0000000000..a1d41744e6 --- /dev/null +++ b/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/pojo/NestedSample.java @@ -0,0 +1,19 @@ +package com.baeldung.deserializationfilters.pojo; + +public class NestedSample implements ContextSpecific { + private static final long serialVersionUID = 1L; + + private Sample optional; + + public NestedSample(Sample optional) { + this.optional = optional; + } + + public Sample getOptional() { + return optional; + } + + public void setOptional(Sample optional) { + this.optional = optional; + } +} diff --git a/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/pojo/Sample.java b/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/pojo/Sample.java new file mode 100644 index 0000000000..fed3639c64 --- /dev/null +++ b/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/pojo/Sample.java @@ -0,0 +1,61 @@ +package com.baeldung.deserializationfilters.pojo; + +public class Sample implements ContextSpecific, Comparable { + private static final long serialVersionUID = 1L; + + private int[] array; + private String name; + private NestedSample nested; + + public Sample(String name) { + this.name = name; + } + + public Sample(int[] array) { + this.array = array; + } + + public Sample(NestedSample nested) { + this.nested = nested; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public int[] getArray() { + return array; + } + + public void setArray(int[] array) { + this.array = array; + } + + public NestedSample getNested() { + return nested; + } + + public void setNested(NestedSample nested) { + this.nested = nested; + } + + @Override + public String toString() { + return name; + } + + @Override + public int compareTo(Sample o) { + if (name == null) + return -1; + + if (o == null || o.getName() == null) + return 1; + + return getName().compareTo(o.getName()); + } +} diff --git a/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/pojo/SampleExploit.java b/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/pojo/SampleExploit.java new file mode 100644 index 0000000000..24dce289c6 --- /dev/null +++ b/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/pojo/SampleExploit.java @@ -0,0 +1,25 @@ +package com.baeldung.deserializationfilters.pojo; + +public class SampleExploit extends Sample { + private static final long serialVersionUID = 1L; + + public SampleExploit() { + super("exploit"); + } + + public static void maliciousCode() { + System.out.println("exploit executed"); + } + + @Override + public String toString() { + maliciousCode(); + return "exploit"; + } + + @Override + public int compareTo(Sample o) { + maliciousCode(); + return super.compareTo(o); + } +} diff --git a/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/service/DeserializationService.java b/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/service/DeserializationService.java new file mode 100644 index 0000000000..9a66cb0e91 --- /dev/null +++ b/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/service/DeserializationService.java @@ -0,0 +1,11 @@ +package com.baeldung.deserializationfilters.service; + +import java.io.InputStream; +import java.util.Set; + +import com.baeldung.deserializationfilters.pojo.ContextSpecific; + +public interface DeserializationService { + + Set process(InputStream... inputStreams); +} diff --git a/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/service/LimitedArrayService.java b/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/service/LimitedArrayService.java new file mode 100644 index 0000000000..3aadbe7111 --- /dev/null +++ b/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/service/LimitedArrayService.java @@ -0,0 +1,15 @@ +package com.baeldung.deserializationfilters.service; + +import java.io.InputStream; +import java.util.Set; + +import com.baeldung.deserializationfilters.pojo.ContextSpecific; +import com.baeldung.deserializationfilters.utils.DeserializationUtils; + +public class LimitedArrayService implements DeserializationService { + + @Override + public Set process(InputStream... inputStreams) { + return DeserializationUtils.deserializeIntoSet(inputStreams); + } +} diff --git a/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/service/LowDepthService.java b/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/service/LowDepthService.java new file mode 100644 index 0000000000..69350c1399 --- /dev/null +++ b/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/service/LowDepthService.java @@ -0,0 +1,20 @@ +package com.baeldung.deserializationfilters.service; + +import java.io.InputStream; +import java.io.ObjectInputFilter; +import java.util.Set; + +import com.baeldung.deserializationfilters.pojo.ContextSpecific; +import com.baeldung.deserializationfilters.utils.DeserializationUtils; + +public class LowDepthService implements DeserializationService { + + public Set process(ObjectInputFilter filter, InputStream... inputStreams) { + return DeserializationUtils.deserializeIntoSet(filter, inputStreams); + } + + @Override + public Set process(InputStream... inputStreams) { + return process(null, inputStreams); + } +} diff --git a/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/service/SmallObjectService.java b/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/service/SmallObjectService.java new file mode 100644 index 0000000000..a0690276b7 --- /dev/null +++ b/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/service/SmallObjectService.java @@ -0,0 +1,15 @@ +package com.baeldung.deserializationfilters.service; + +import java.io.InputStream; +import java.util.Set; + +import com.baeldung.deserializationfilters.pojo.ContextSpecific; +import com.baeldung.deserializationfilters.utils.DeserializationUtils; + +public class SmallObjectService implements DeserializationService { + + @Override + public Set process(InputStream... inputStreams) { + return DeserializationUtils.deserializeIntoSet(inputStreams); + } +} diff --git a/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/utils/DeserializationUtils.java b/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/utils/DeserializationUtils.java new file mode 100644 index 0000000000..54db823102 --- /dev/null +++ b/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/utils/DeserializationUtils.java @@ -0,0 +1,50 @@ +package com.baeldung.deserializationfilters.utils; + +import java.io.InputStream; +import java.io.InvalidClassException; +import java.io.ObjectInputFilter; +import java.io.ObjectInputStream; +import java.util.Set; +import java.util.TreeSet; + +import com.baeldung.deserializationfilters.pojo.ContextSpecific; + +public class DeserializationUtils { + private DeserializationUtils() { + } + + public static Object deserialize(InputStream inStream) { + return deserialize(inStream, null); + } + + public static Object deserialize(InputStream inStream, ObjectInputFilter filter) { + try (ObjectInputStream in = new ObjectInputStream(inStream)) { + if (filter != null) { + in.setObjectInputFilter(filter); + } + return in.readObject(); + } catch (InvalidClassException e) { + return null; + } catch (Throwable e) { + e.printStackTrace(); + return null; + } + } + + public static Set deserializeIntoSet(InputStream... inputStreams) { + return deserializeIntoSet(null, inputStreams); + } + + public static Set deserializeIntoSet(ObjectInputFilter filter, InputStream... inputStreams) { + Set set = new TreeSet<>(); + + for (InputStream inputStream : inputStreams) { + Object object = deserialize(inputStream, filter); + if (object != null) { + set.add((ContextSpecific) object); + } + } + + return set; + } +} diff --git a/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/utils/FilterUtils.java b/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/utils/FilterUtils.java new file mode 100644 index 0000000000..fac69a94b9 --- /dev/null +++ b/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/utils/FilterUtils.java @@ -0,0 +1,32 @@ +package com.baeldung.deserializationfilters.utils; + +import java.io.ObjectInputFilter; + +public class FilterUtils { + + private static final String DEFAULT_PACKAGE_PATTERN = "java.base/*;!*"; + private static final String POJO_PACKAGE = "com.baeldung.deserializationfilters.pojo"; + + private FilterUtils() { + } + + private static ObjectInputFilter baseFilter(String parameter, int max) { + return ObjectInputFilter.Config.createFilter(String.format("%s=%d;%s.**;%s", parameter, max, POJO_PACKAGE, DEFAULT_PACKAGE_PATTERN)); + } + + public static ObjectInputFilter fallbackFilter() { + return ObjectInputFilter.Config.createFilter(String.format("%s", DEFAULT_PACKAGE_PATTERN)); + } + + public static ObjectInputFilter safeSizeFilter(int max) { + return baseFilter("maxbytes", max); + } + + public static ObjectInputFilter safeArrayFilter(int max) { + return baseFilter("maxarray", max); + } + + public static ObjectInputFilter safeDepthFilter(int max) { + return baseFilter("maxdepth", max); + } +} diff --git a/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/utils/SerializationUtils.java b/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/utils/SerializationUtils.java new file mode 100644 index 0000000000..4f62e5d46b --- /dev/null +++ b/core-java-modules/core-java-17/src/main/java/com/baeldung/deserializationfilters/utils/SerializationUtils.java @@ -0,0 +1,17 @@ +package com.baeldung.deserializationfilters.utils; + +import java.io.IOException; +import java.io.ObjectOutputStream; +import java.io.OutputStream; + +public class SerializationUtils { + + private SerializationUtils() { + } + + public static void serialize(Object object, OutputStream outStream) throws IOException { + try (ObjectOutputStream objStream = new ObjectOutputStream(outStream)) { + objStream.writeObject(object); + } + } +} diff --git a/core-java-modules/core-java-17/src/test/java/com/baeldung/deserializationfilters/ContextSpecificDeserializationFilterIntegrationTest.java b/core-java-modules/core-java-17/src/test/java/com/baeldung/deserializationfilters/ContextSpecificDeserializationFilterIntegrationTest.java new file mode 100644 index 0000000000..3e7de20070 --- /dev/null +++ b/core-java-modules/core-java-17/src/test/java/com/baeldung/deserializationfilters/ContextSpecificDeserializationFilterIntegrationTest.java @@ -0,0 +1,119 @@ +package com.baeldung.deserializationfilters; + +import static org.junit.Assert.assertNull; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputFilter; +import java.util.Set; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import com.baeldung.deserializationfilters.pojo.ContextSpecific; +import com.baeldung.deserializationfilters.pojo.NestedSample; +import com.baeldung.deserializationfilters.pojo.Sample; +import com.baeldung.deserializationfilters.pojo.SampleExploit; +import com.baeldung.deserializationfilters.service.LimitedArrayService; +import com.baeldung.deserializationfilters.service.LowDepthService; +import com.baeldung.deserializationfilters.service.SmallObjectService; +import com.baeldung.deserializationfilters.utils.DeserializationUtils; +import com.baeldung.deserializationfilters.utils.FilterUtils; +import com.baeldung.deserializationfilters.utils.SerializationUtils; + +public class ContextSpecificDeserializationFilterIntegrationTest { + + private static ByteArrayOutputStream serialSampleA = new ByteArrayOutputStream(); + private static ByteArrayOutputStream serialBigSampleA = new ByteArrayOutputStream(); + + private static ByteArrayOutputStream serialSampleB = new ByteArrayOutputStream(); + private static ByteArrayOutputStream serialBigSampleB = new ByteArrayOutputStream(); + + private static ByteArrayOutputStream serialSampleC = new ByteArrayOutputStream(); + private static ByteArrayOutputStream serialBigSampleC = new ByteArrayOutputStream(); + + private static ByteArrayInputStream bytes(ByteArrayOutputStream stream) { + return new ByteArrayInputStream(stream.toByteArray()); + } + + @BeforeAll + static void setup() throws IOException { + ObjectInputFilter.Config.setSerialFilterFactory(new ContextSpecificDeserializationFilterFactory()); + + SerializationUtils.serialize(new Sample("simple"), serialSampleA); + SerializationUtils.serialize(new SampleExploit(), serialBigSampleA); + + SerializationUtils.serialize(new Sample(new int[] { 1, 2, 3 }), serialSampleB); + SerializationUtils.serialize(new Sample(new int[] { 1, 2, 3, 4, 5, 6 }), serialBigSampleB); + + SerializationUtils.serialize(new Sample(new NestedSample(null)), serialSampleC); + SerializationUtils.serialize(new Sample(new NestedSample(new Sample("deep"))), serialBigSampleC); + } + + @Test + void whenSmallObjectContext_thenCorrectFilterApplied() { + Set result = new SmallObjectService().process( // + bytes(serialSampleA), // + bytes(serialBigSampleA)); + + assertEquals(1, result.size()); + assertEquals("simple", ((Sample) result.iterator() + .next()).getName()); + } + + @Test + void whenLimitedArrayContext_thenCorrectFilterApplied() { + Set result = new LimitedArrayService().process( // + bytes(serialSampleB), // + bytes(serialBigSampleB)); + + assertEquals(1, result.size()); + } + + @Test + void whenLowDepthContext_thenCorrectFilterApplied() { + Set result = new LowDepthService().process( // + bytes(serialSampleC), // + bytes(serialBigSampleC)); + + assertEquals(1, result.size()); + } + + @Test + void givenExtraFilter_whenCombinedContext_thenMergedFiltersApplied() { + Set result = new LowDepthService().process( // + FilterUtils.safeSizeFilter(190), // + bytes(serialSampleA), // + bytes(serialBigSampleA), // + bytes(serialSampleC), // + bytes(serialBigSampleC)); + + assertEquals(1, result.size()); + assertEquals("simple", ((Sample) result.iterator() + .next()).getName()); + } + + @Test + void givenFallbackContext_whenUsingBaseClasses_thenRestrictiveFilterApplied() throws IOException { + String a = new String("a"); + ByteArrayOutputStream outStream = new ByteArrayOutputStream(); + SerializationUtils.serialize(a, outStream); + + String deserializedA = (String) DeserializationUtils.deserialize(bytes(outStream)); + + assertEquals(a, deserializedA); + } + + @Test + void givenFallbackContext_whenUsingAppClasses_thenRejected() throws IOException { + Sample a = new Sample("a"); + ByteArrayOutputStream outStream = new ByteArrayOutputStream(); + SerializationUtils.serialize(a, outStream); + + Sample deserializedA = (Sample) DeserializationUtils.deserialize(bytes(outStream)); + + assertNull(deserializedA); + } +}