diff --git a/src/main/java/org/springframework/data/elasticsearch/core/DefaultResultMapper.java b/src/main/java/org/springframework/data/elasticsearch/core/DefaultResultMapper.java index 0add0f2c9..99d6fbb38 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/DefaultResultMapper.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/DefaultResultMapper.java @@ -25,22 +25,22 @@ import java.util.Collection; import java.util.LinkedList; import java.util.List; +import com.fasterxml.jackson.core.JsonEncoding; +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonGenerator; import org.apache.commons.lang.StringUtils; import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.get.MultiGetItemResponse; import org.elasticsearch.action.get.MultiGetResponse; import org.elasticsearch.action.search.SearchResponse; -import com.fasterxml.jackson.core.JsonEncoding; -import com.fasterxml.jackson.core.JsonFactory; -import com.fasterxml.jackson.core.JsonGenerator; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHitField; import org.springframework.data.domain.Page; -import org.springframework.data.domain.PageImpl; import org.springframework.data.domain.Pageable; import org.springframework.data.elasticsearch.ElasticsearchException; import org.springframework.data.elasticsearch.annotations.Document; import org.springframework.data.elasticsearch.annotations.ScriptedField; +import org.springframework.data.elasticsearch.core.domain.impl.AggregatedPageImpl; import org.springframework.data.elasticsearch.core.mapping.ElasticsearchPersistentEntity; import org.springframework.data.elasticsearch.core.mapping.ElasticsearchPersistentProperty; import org.springframework.data.mapping.PersistentProperty; @@ -48,6 +48,7 @@ import org.springframework.data.mapping.context.MappingContext; /** * @author Artur Konczak + * @author Petar Tahchiev */ public class DefaultResultMapper extends AbstractResultMapper { @@ -86,38 +87,39 @@ public class DefaultResultMapper extends AbstractResultMapper { result = mapEntity(hit.getFields().values(), clazz); } setPersistentEntityId(result, hit.getId(), clazz); - populateScriptFields(result, hit); + populateScriptFields(result, hit); results.add(result); } } - return new PageImpl(results, pageable, totalHits); + + return new AggregatedPageImpl(results, pageable, totalHits, response.getAggregations()); } - private void populateScriptFields(T result, SearchHit hit) { - if (hit.getFields() != null && !hit.getFields().isEmpty() && result != null) { - for (java.lang.reflect.Field field : result.getClass().getDeclaredFields()) { - ScriptedField scriptedField = field.getAnnotation(ScriptedField.class); - if (scriptedField != null) { - String name = scriptedField.name().isEmpty() ? field.getName() : scriptedField.name(); - SearchHitField searchHitField = hit.getFields().get(name); - if (searchHitField != null) { - field.setAccessible(true); - try { - field.set(result, searchHitField.getValue()); - } catch (IllegalArgumentException e) { - throw new ElasticsearchException("failed to set scripted field: " + name + " with value: " - + searchHitField.getValue(), e); - } catch (IllegalAccessException e) { - throw new ElasticsearchException("failed to access scripted field: " + name, e); - } - } - } - } - } - } + private void populateScriptFields(T result, SearchHit hit) { + if (hit.getFields() != null && !hit.getFields().isEmpty() && result != null) { + for (java.lang.reflect.Field field : result.getClass().getDeclaredFields()) { + ScriptedField scriptedField = field.getAnnotation(ScriptedField.class); + if (scriptedField != null) { + String name = scriptedField.name().isEmpty() ? field.getName() : scriptedField.name(); + SearchHitField searchHitField = hit.getFields().get(name); + if (searchHitField != null) { + field.setAccessible(true); + try { + field.set(result, searchHitField.getValue()); + } catch (IllegalArgumentException e) { + throw new ElasticsearchException("failed to set scripted field: " + name + " with value: " + + searchHitField.getValue(), e); + } catch (IllegalAccessException e) { + throw new ElasticsearchException("failed to access scripted field: " + name, e); + } + } + } + } + } + } - private T mapEntity(Collection values, Class clazz) { + private T mapEntity(Collection values, Class clazz) { return mapEntity(buildJSONFromFields(values), clazz); } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/SearchResultMapper.java b/src/main/java/org/springframework/data/elasticsearch/core/SearchResultMapper.java index e35e0c99f..ed315284e 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/SearchResultMapper.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/SearchResultMapper.java @@ -18,9 +18,11 @@ package org.springframework.data.elasticsearch.core; import org.elasticsearch.action.search.SearchResponse; import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; +import org.springframework.data.elasticsearch.core.domain.AggregatedPage; /** * @author Artur Konczak + * @author Petar Tahchiev */ public interface SearchResultMapper { diff --git a/src/main/java/org/springframework/data/elasticsearch/core/domain/AggregatedPage.java b/src/main/java/org/springframework/data/elasticsearch/core/domain/AggregatedPage.java new file mode 100644 index 000000000..2d7a00c93 --- /dev/null +++ b/src/main/java/org/springframework/data/elasticsearch/core/domain/AggregatedPage.java @@ -0,0 +1,16 @@ +package org.springframework.data.elasticsearch.core.domain; + +import org.elasticsearch.search.aggregations.Aggregation; +import org.elasticsearch.search.aggregations.Aggregations; + +/** + * @author Petar Tahchiev + */ +public interface AggregatedPage { + + boolean hasAggregations(); + + Aggregations getAggregations(); + + Aggregation getAggregation(String name); +} diff --git a/src/main/java/org/springframework/data/elasticsearch/core/domain/impl/AggregatedPageImpl.java b/src/main/java/org/springframework/data/elasticsearch/core/domain/impl/AggregatedPageImpl.java new file mode 100644 index 000000000..fba63ebf5 --- /dev/null +++ b/src/main/java/org/springframework/data/elasticsearch/core/domain/impl/AggregatedPageImpl.java @@ -0,0 +1,53 @@ +package org.springframework.data.elasticsearch.core.domain.impl; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.elasticsearch.search.aggregations.Aggregation; +import org.elasticsearch.search.aggregations.Aggregations; +import org.springframework.data.domain.PageImpl; +import org.springframework.data.domain.Pageable; +import org.springframework.data.elasticsearch.core.domain.AggregatedPage; + +/** + * @author Petar Tahchiev + */ +public class AggregatedPageImpl extends PageImpl implements AggregatedPage { + + private Aggregations aggregations; + private Map mapOfAggregations = new HashMap(); + + public AggregatedPageImpl(List content) { + super(content); + } + + public AggregatedPageImpl(List content, Pageable pageable, long total) { + super(content, pageable, total); + } + + public AggregatedPageImpl(List content, Pageable pageable, long total, Aggregations aggregations) { + super(content, pageable, total); + this.aggregations = aggregations; + if (aggregations != null) { + for (Aggregation aggregation : aggregations) { + mapOfAggregations.put(aggregation.getName(), aggregation); + } + } + } + + @Override + public boolean hasAggregations() { + return aggregations != null && mapOfAggregations.size() > 0; + } + + @Override + public Aggregations getAggregations() { + return aggregations; + } + + @Override + public Aggregation getAggregation(String name) { + return aggregations == null ? null : aggregations.get(name); + } +} diff --git a/src/test/java/org/springframework/data/elasticsearch/core/DefaultResultMapperTests.java b/src/test/java/org/springframework/data/elasticsearch/core/DefaultResultMapperTests.java index 5d6c2b379..a00fba7a1 100644 --- a/src/test/java/org/springframework/data/elasticsearch/core/DefaultResultMapperTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/core/DefaultResultMapperTests.java @@ -19,9 +19,7 @@ import static org.hamcrest.Matchers.*; import static org.junit.Assert.*; import static org.mockito.Mockito.*; -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; +import java.util.*; import com.fasterxml.jackson.databind.util.ArrayIterator; import org.elasticsearch.action.get.GetResponse; @@ -29,12 +27,15 @@ import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHitField; import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.aggregations.Aggregation; +import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.internal.InternalSearchHitField; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.springframework.data.domain.Page; +import org.springframework.data.elasticsearch.core.domain.AggregatedPage; import org.springframework.data.elasticsearch.entities.Car; /** @@ -54,6 +55,31 @@ public class DefaultResultMapperTests { resultMapper = new DefaultResultMapper(); } + @Test + public void shouldMapAggregationsToPage() { + //Given + SearchHit[] hits = {createCarHit("Ford", "Grat"), createCarHit("BMW", "Arrow")}; + SearchHits searchHits = mock(SearchHits.class); + when(searchHits.totalHits()).thenReturn(2L); + when(searchHits.iterator()).thenReturn(new ArrayIterator(hits)); + when(response.getHits()).thenReturn(searchHits); + + Aggregation aggregationToReturn = createCarAggregation(); + Aggregations aggregations = mock(Aggregations.class); + Iterator iter = Collections.singletonList(aggregationToReturn).iterator(); + + when(aggregations.iterator()).thenReturn(iter); + when(aggregations.get("engine")).thenReturn(aggregationToReturn); + when(response.getAggregations()).thenReturn(aggregations); + + //When + AggregatedPage page = (AggregatedPage) resultMapper.mapResults(response, Car.class, null); + + //Then + assertThat(page.hasAggregations(), is(true)); + assertThat(page.getAggregation("engine").getName(), is("Diesel")); + } + @Test public void shouldMapSearchRequestToPage() { //Given @@ -105,6 +131,12 @@ public class DefaultResultMapperTests { assertThat(result.getName(), is("Ford")); } + private Aggregation createCarAggregation() { + Aggregation aggregation = mock(Aggregation.class); + when(aggregation.getName()).thenReturn("Diesel"); + return aggregation; + } + private SearchHit createCarHit(String name, String model) { SearchHit hit = mock(SearchHit.class); when(hit.sourceAsString()).thenReturn(createJsonCar(name, model));