diff --git a/spring-data-mongodb/src/main/java/org/baeldung/event/CascadeCallback.java b/spring-data-mongodb/src/main/java/org/baeldung/event/CascadeCallback.java new file mode 100644 index 0000000000..2f433b6b44 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/baeldung/event/CascadeCallback.java @@ -0,0 +1,53 @@ +package org.baeldung.event; + +import java.lang.reflect.Field; + +import org.baeldung.annotation.CascadeSave; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.mapping.DBRef; +import org.springframework.util.ReflectionUtils; + +public class CascadeCallback implements ReflectionUtils.FieldCallback { + + private Object source; + private MongoOperations mongoOperations; + + public CascadeCallback(final Object source, final MongoOperations mongoOperations) { + this.source = source; + this.setMongoOperations(mongoOperations); + } + + @Override + public void doWith(Field field) throws IllegalArgumentException, IllegalAccessException { + ReflectionUtils.makeAccessible(field); + + if (field.isAnnotationPresent(DBRef.class) && field.isAnnotationPresent(CascadeSave.class)) { + final Object fieldValue = field.get(getSource()); + + if (fieldValue != null) { + FieldCallback callback = new FieldCallback(); + + ReflectionUtils.doWithFields(fieldValue.getClass(), callback); + + getMongoOperations().save(fieldValue); + } + } + + } + + public Object getSource() { + return source; + } + + public void setSource(Object source) { + this.source = source; + } + + public MongoOperations getMongoOperations() { + return mongoOperations; + } + + public void setMongoOperations(MongoOperations mongoOperations) { + this.mongoOperations = mongoOperations; + } +} diff --git a/spring-data-mongodb/src/main/java/org/baeldung/event/CascadeSaveMongoEventListener.java b/spring-data-mongodb/src/main/java/org/baeldung/event/CascadeSaveMongoEventListener.java index ad09f1ac04..ae79c1d92e 100644 --- a/spring-data-mongodb/src/main/java/org/baeldung/event/CascadeSaveMongoEventListener.java +++ b/spring-data-mongodb/src/main/java/org/baeldung/event/CascadeSaveMongoEventListener.java @@ -1,37 +1,17 @@ package org.baeldung.event; -import org.baeldung.annotation.CascadeSave; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.mongodb.core.MongoOperations; -import org.springframework.data.mongodb.core.mapping.DBRef; import org.springframework.data.mongodb.core.mapping.event.AbstractMongoEventListener; import org.springframework.util.ReflectionUtils; -import java.lang.reflect.Field; - public class CascadeSaveMongoEventListener extends AbstractMongoEventListener { + @Autowired private MongoOperations mongoOperations; - + @Override public void onBeforeConvert(final Object source) { - ReflectionUtils.doWithFields(source.getClass(), new ReflectionUtils.FieldCallback() { - - public void doWith(Field field) throws IllegalArgumentException, IllegalAccessException { - ReflectionUtils.makeAccessible(field); - - if (field.isAnnotationPresent(DBRef.class) && field.isAnnotationPresent(CascadeSave.class)) { - final Object fieldValue = field.get(source); - - if (fieldValue != null) { - FieldCallback callback = new FieldCallback(); - - ReflectionUtils.doWithFields(fieldValue.getClass(), callback); - - mongoOperations.save(fieldValue); - } - } - } - }); + ReflectionUtils.doWithFields(source.getClass(), new CascadeCallback(source, mongoOperations)); } -} +} \ No newline at end of file diff --git a/spring-data-mongodb/src/test/java/org/baeldung/mongotemplate/MongoTemplateQueryIntegrationTest.java b/spring-data-mongodb/src/test/java/org/baeldung/mongotemplate/MongoTemplateQueryIntegrationTest.java index 6082743bda..e1dc426cda 100644 --- a/spring-data-mongodb/src/test/java/org/baeldung/mongotemplate/MongoTemplateQueryIntegrationTest.java +++ b/spring-data-mongodb/src/test/java/org/baeldung/mongotemplate/MongoTemplateQueryIntegrationTest.java @@ -2,6 +2,7 @@ package org.baeldung.mongotemplate; import static org.hamcrest.CoreMatchers.is; import static org.junit.Assert.assertThat; +import static org.hamcrest.Matchers.nullValue; import java.util.List; @@ -25,6 +26,7 @@ import org.springframework.data.mongodb.core.query.Query; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; + @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration(classes = MongoConfig.class) public class MongoTemplateQueryIntegrationTest { @@ -189,6 +191,6 @@ public class MongoTemplateQueryIntegrationTest { user.setYearOfBirth(1985); mongoTemplate.insert(user); - assertThat(user.getYearOfBirth(), is(1985)); + assertThat(mongoTemplate.findOne(Query.query(Criteria.where("name").is("Alex")), User.class).getYearOfBirth(), is(nullValue())); } }