diff --git a/src/main/java/org/springframework/data/elasticsearch/repository/ElasticsearchRepository.java b/src/main/java/org/springframework/data/elasticsearch/repository/ElasticsearchRepository.java index 41b233c36..232e3355b 100644 --- a/src/main/java/org/springframework/data/elasticsearch/repository/ElasticsearchRepository.java +++ b/src/main/java/org/springframework/data/elasticsearch/repository/ElasticsearchRepository.java @@ -32,10 +32,15 @@ import org.springframework.data.repository.NoRepositoryBean; @NoRepositoryBean public interface ElasticsearchRepository extends ElasticsearchCrudRepository { - S indexWithoutRefresh(S entity); - S index(S entity); + /** + * This method is intended to be used when many single inserts must be made that cannot be aggregated to be inserted + * with {@link #saveAll(Iterable)}. This might lead to a temporary inconsistent state until {@link #refresh()} is + * called. + */ + S indexWithoutRefresh(S entity); + Iterable search(QueryBuilder query); Page search(QueryBuilder query, Pageable pageable); diff --git a/src/main/java/org/springframework/data/elasticsearch/repository/support/AbstractElasticsearchRepository.java b/src/main/java/org/springframework/data/elasticsearch/repository/support/AbstractElasticsearchRepository.java index afc240162..ce8534e5a 100644 --- a/src/main/java/org/springframework/data/elasticsearch/repository/support/AbstractElasticsearchRepository.java +++ b/src/main/java/org/springframework/data/elasticsearch/repository/support/AbstractElasticsearchRepository.java @@ -61,6 +61,7 @@ import org.springframework.util.Assert; public abstract class AbstractElasticsearchRepository implements ElasticsearchRepository { static final Logger LOGGER = LoggerFactory.getLogger(AbstractElasticsearchRepository.class); + protected ElasticsearchOperations elasticsearchOperations; protected Class entityClass; protected ElasticsearchEntityInformation entityInformation; @@ -76,6 +77,7 @@ public abstract class AbstractElasticsearchRepository implements Elastics public AbstractElasticsearchRepository(ElasticsearchEntityInformation metadata, ElasticsearchOperations elasticsearchOperations) { + this(elasticsearchOperations); Assert.notNull(metadata, "ElasticsearchEntityInformation must not be null!"); @@ -93,19 +95,23 @@ public abstract class AbstractElasticsearchRepository implements Elastics } private void createIndex() { + elasticsearchOperations.createIndex(getEntityClass()); } private void putMapping() { + elasticsearchOperations.putMapping(getEntityClass()); } private boolean createIndexAndMapping() { + return elasticsearchOperations.getPersistentEntityFor(getEntityClass()).isCreateIndexAndMapping(); } @Override public Optional findById(ID id) { + GetQuery query = new GetQuery(); query.setId(stringIdRepresentation(id)); return Optional.ofNullable(elasticsearchOperations.queryForObject(query, getEntityClass())); @@ -113,134 +119,167 @@ public abstract class AbstractElasticsearchRepository implements Elastics @Override public Iterable findAll() { + int itemCount = (int) this.count(); if (itemCount == 0) { return new PageImpl<>(Collections. emptyList()); } + return this.findAll(PageRequest.of(0, Math.max(1, itemCount))); } @Override public Page findAll(Pageable pageable) { + SearchQuery query = new NativeSearchQueryBuilder().withQuery(matchAllQuery()).withPageable(pageable).build(); + return elasticsearchOperations.queryForPage(query, getEntityClass()); } @Override public Iterable findAll(Sort sort) { + int itemCount = (int) this.count(); if (itemCount == 0) { return new PageImpl<>(Collections. emptyList()); } + SearchQuery query = new NativeSearchQueryBuilder().withQuery(matchAllQuery()) .withPageable(PageRequest.of(0, itemCount, sort)).build(); + return elasticsearchOperations.queryForPage(query, getEntityClass()); } @Override public Iterable findAllById(Iterable ids) { + Assert.notNull(ids, "ids can't be null."); + SearchQuery query = new NativeSearchQueryBuilder().withIds(stringIdsRepresentation(ids)).build(); + return elasticsearchOperations.multiGet(query, getEntityClass()); } @Override public long count() { + SearchQuery query = new NativeSearchQueryBuilder().withQuery(matchAllQuery()).build(); + return elasticsearchOperations.count(query, getEntityClass()); } @Override public S save(S entity) { + Assert.notNull(entity, "Cannot save 'null' entity."); + elasticsearchOperations.index(createIndexQuery(entity)); elasticsearchOperations.refresh(entityInformation.getIndexName()); + return entity; } public List save(List entities) { + Assert.notNull(entities, "Cannot insert 'null' as a List."); Assert.notEmpty(entities, "Cannot insert empty List."); + List queries = new ArrayList<>(); for (S s : entities) { queries.add(createIndexQuery(s)); } elasticsearchOperations.bulkIndex(queries); elasticsearchOperations.refresh(entityInformation.getIndexName()); + return entities; } @Override public S index(S entity) { + return save(entity); } - /** - * This method might lead to a temporary inconsistent state until - * {@link org.springframework.data.elasticsearch.repository.ElasticsearchRepository#refresh() refresh} is called. - */ @Override public S indexWithoutRefresh(S entity) { + Assert.notNull(entity, "Cannot save 'null' entity."); + elasticsearchOperations.index(createIndexQuery(entity)); + return entity; } @Override public Iterable saveAll(Iterable entities) { + Assert.notNull(entities, "Cannot insert 'null' as a List."); + List queries = new ArrayList<>(); for (S s : entities) { queries.add(createIndexQuery(s)); } elasticsearchOperations.bulkIndex(queries); elasticsearchOperations.refresh(entityInformation.getIndexName()); + return entities; } @Override public boolean existsById(ID id) { + return findById(id).isPresent(); } @Override public Iterable search(QueryBuilder query) { + SearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(query).build(); int count = (int) elasticsearchOperations.count(searchQuery, getEntityClass()); if (count == 0) { return new PageImpl<>(Collections. emptyList()); } + searchQuery.setPageable(PageRequest.of(0, count)); + return elasticsearchOperations.queryForPage(searchQuery, getEntityClass()); } @Override public Page search(QueryBuilder query, Pageable pageable) { + SearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(query).withPageable(pageable).build(); + return elasticsearchOperations.queryForPage(searchQuery, getEntityClass()); } @Override public Page search(SearchQuery query) { + return elasticsearchOperations.queryForPage(query, getEntityClass()); } @Override public Page searchSimilar(T entity, String[] fields, Pageable pageable) { + Assert.notNull(entity, "Cannot search similar records for 'null'."); Assert.notNull(pageable, "'pageable' cannot be 'null'"); + MoreLikeThisQuery query = new MoreLikeThisQuery(); query.setId(stringIdRepresentation(extractIdFromBean(entity))); query.setPageable(pageable); if (fields != null) { query.addFields(fields); } + return elasticsearchOperations.moreLikeThis(query, getEntityClass()); } @Override public void deleteById(ID id) { + Assert.notNull(id, "Cannot delete entity with id 'null'."); + elasticsearchOperations.delete(entityInformation.getIndexName(), entityInformation.getType(), stringIdRepresentation(id)); elasticsearchOperations.refresh(entityInformation.getIndexName()); @@ -248,13 +287,16 @@ public abstract class AbstractElasticsearchRepository implements Elastics @Override public void delete(T entity) { + Assert.notNull(entity, "Cannot delete 'null' entity."); + deleteById(extractIdFromBean(entity)); elasticsearchOperations.refresh(entityInformation.getIndexName()); } @Override public void deleteAll(Iterable entities) { + Assert.notNull(entities, "Cannot delete 'null' list."); for (T entity : entities) { delete(entity); @@ -263,6 +305,7 @@ public abstract class AbstractElasticsearchRepository implements Elastics @Override public void deleteAll() { + DeleteQuery deleteQuery = new DeleteQuery(); deleteQuery.setQuery(matchAllQuery()); elasticsearchOperations.delete(deleteQuery, getEntityClass()); @@ -271,10 +314,12 @@ public abstract class AbstractElasticsearchRepository implements Elastics @Override public void refresh() { + elasticsearchOperations.refresh(getEntityClass()); } private IndexQuery createIndexQuery(T entity) { + IndexQuery query = new IndexQuery(); query.setObject(entity); query.setId(stringIdRepresentation(extractIdFromBean(entity))); @@ -285,11 +330,13 @@ public abstract class AbstractElasticsearchRepository implements Elastics @SuppressWarnings("unchecked") private Class resolveReturnedClassFromGenericType() { + ParameterizedType parameterizedType = resolveReturnedClassFromGenericType(getClass()); return (Class) parameterizedType.getActualTypeArguments()[0]; } private ParameterizedType resolveReturnedClassFromGenericType(Class clazz) { + Object genericSuperclass = clazz.getGenericSuperclass(); if (genericSuperclass instanceof ParameterizedType) { ParameterizedType parameterizedType = (ParameterizedType) genericSuperclass; @@ -298,11 +345,13 @@ public abstract class AbstractElasticsearchRepository implements Elastics return parameterizedType; } } + return resolveReturnedClassFromGenericType(clazz.getSuperclass()); } @Override public Class getEntityClass() { + if (!isEntityClassSet()) { try { this.entityClass = resolveReturnedClassFromGenericType(); @@ -310,20 +359,25 @@ public abstract class AbstractElasticsearchRepository implements Elastics throw new InvalidDataAccessApiUsageException("Unable to resolve EntityClass. Please use according setter!", e); } } + return entityClass; } private boolean isEntityClassSet() { + return entityClass != null; } public final void setEntityClass(Class entityClass) { + Assert.notNull(entityClass, "EntityClass must not be null."); this.entityClass = entityClass; } public final void setElasticsearchOperations(ElasticsearchOperations elasticsearchOperations) { + Assert.notNull(elasticsearchOperations, "ElasticsearchOperations must not be null."); + this.elasticsearchOperations = elasticsearchOperations; } @@ -332,11 +386,14 @@ public abstract class AbstractElasticsearchRepository implements Elastics } private List stringIdsRepresentation(Iterable ids) { + Assert.notNull(ids, "ids can't be null."); + List stringIds = new ArrayList<>(); for (ID id : ids) { stringIds.add(stringIdRepresentation(id)); } + return stringIds; } diff --git a/src/test/java/org/springframework/data/elasticsearch/repository/support/simple/SimpleElasticsearchRepositoryTests.java b/src/test/java/org/springframework/data/elasticsearch/repository/support/simple/SimpleElasticsearchRepositoryTests.java index 58368ba9c..5e228972b 100644 --- a/src/test/java/org/springframework/data/elasticsearch/repository/support/simple/SimpleElasticsearchRepositoryTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/repository/support/simple/SimpleElasticsearchRepositoryTests.java @@ -35,7 +35,6 @@ import org.elasticsearch.action.ActionRequestValidationException; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; - import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.annotation.Id; import org.springframework.data.annotation.Version; @@ -570,26 +569,28 @@ public class SimpleElasticsearchRepositoryTests { assertThat(entities.getTotalElements()).isEqualTo(1L); } - @Test - public void shouldIndexWithoutRefreshEntity() { + @Test + public void shouldIndexWithoutRefreshEntity() { - // given - String documentId = randomNumeric(5); - SampleEntity sampleEntity = new SampleEntity(); - sampleEntity.setId(documentId); - sampleEntity.setVersion(System.currentTimeMillis()); - sampleEntity.setMessage("some message"); + // given + String documentId = randomNumeric(5); + SampleEntity sampleEntity = new SampleEntity(); + sampleEntity.setId(documentId); + sampleEntity.setVersion(System.currentTimeMillis()); + sampleEntity.setMessage("some message"); - // when - repository.indexWithoutRefresh(sampleEntity); + // when + repository.indexWithoutRefresh(sampleEntity); - // then - Page entities = repository.search(termQuery("id", documentId), PageRequest.of(0, 50)); - assertThat(entities.getTotalElements()).isEqualTo(0L); - repository.refresh(); - entities = repository.search(termQuery("id", documentId), PageRequest.of(0, 50)); - assertThat(entities.getTotalElements()).isEqualTo(1L); - } + // then + Page entities = repository.search(termQuery("id", documentId), PageRequest.of(0, 50)); + assertThat(entities.getTotalElements()).isEqualTo(0L); + + repository.refresh(); + + entities = repository.search(termQuery("id", documentId), PageRequest.of(0, 50)); + assertThat(entities.getTotalElements()).isEqualTo(1L); + } @Test public void shouldSortByGivenField() {