Created AggregatedPage to hold the Aggregations

Create AggregatedPage and populate it in the DefaultResultMapper.
This commit is contained in:
Petar Tahchiev 2016-05-11 14:28:11 +02:00 committed by Artur Konczak
parent 8c77d314aa
commit fd06f8efd7
5 changed files with 137 additions and 32 deletions

View File

@ -25,22 +25,22 @@ import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import com.fasterxml.jackson.core.JsonEncoding;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import org.apache.commons.lang.StringUtils;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.get.MultiGetItemResponse;
import org.elasticsearch.action.get.MultiGetResponse;
import org.elasticsearch.action.search.SearchResponse;
import com.fasterxml.jackson.core.JsonEncoding;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHitField;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.Pageable;
import org.springframework.data.elasticsearch.ElasticsearchException;
import org.springframework.data.elasticsearch.annotations.Document;
import org.springframework.data.elasticsearch.annotations.ScriptedField;
import org.springframework.data.elasticsearch.core.domain.impl.AggregatedPageImpl;
import org.springframework.data.elasticsearch.core.mapping.ElasticsearchPersistentEntity;
import org.springframework.data.elasticsearch.core.mapping.ElasticsearchPersistentProperty;
import org.springframework.data.mapping.PersistentProperty;
@ -48,6 +48,7 @@ import org.springframework.data.mapping.context.MappingContext;
/**
* @author Artur Konczak
* @author Petar Tahchiev
*/
public class DefaultResultMapper extends AbstractResultMapper {
@ -90,7 +91,8 @@ public class DefaultResultMapper extends AbstractResultMapper {
results.add(result);
}
}
return new PageImpl<T>(results, pageable, totalHits);
return new AggregatedPageImpl<T>(results, pageable, totalHits, response.getAggregations());
}
private <T> void populateScriptFields(T result, SearchHit hit) {

View File

@ -18,9 +18,11 @@ package org.springframework.data.elasticsearch.core;
import org.elasticsearch.action.search.SearchResponse;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.elasticsearch.core.domain.AggregatedPage;
/**
* @author Artur Konczak
* @author Petar Tahchiev
*/
public interface SearchResultMapper {

View File

@ -0,0 +1,16 @@
package org.springframework.data.elasticsearch.core.domain;
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.Aggregations;
/**
* @author Petar Tahchiev
*/
public interface AggregatedPage<T> {
boolean hasAggregations();
Aggregations getAggregations();
Aggregation getAggregation(String name);
}

View File

@ -0,0 +1,53 @@
package org.springframework.data.elasticsearch.core.domain.impl;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.Aggregations;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.Pageable;
import org.springframework.data.elasticsearch.core.domain.AggregatedPage;
/**
* @author Petar Tahchiev
*/
public class AggregatedPageImpl<T> extends PageImpl<T> implements AggregatedPage<T> {
private Aggregations aggregations;
private Map<String, Aggregation> mapOfAggregations = new HashMap<String, Aggregation>();
public AggregatedPageImpl(List<T> content) {
super(content);
}
public AggregatedPageImpl(List<T> content, Pageable pageable, long total) {
super(content, pageable, total);
}
public AggregatedPageImpl(List<T> content, Pageable pageable, long total, Aggregations aggregations) {
super(content, pageable, total);
this.aggregations = aggregations;
if (aggregations != null) {
for (Aggregation aggregation : aggregations) {
mapOfAggregations.put(aggregation.getName(), aggregation);
}
}
}
@Override
public boolean hasAggregations() {
return aggregations != null && mapOfAggregations.size() > 0;
}
@Override
public Aggregations getAggregations() {
return aggregations;
}
@Override
public Aggregation getAggregation(String name) {
return aggregations == null ? null : aggregations.get(name);
}
}

View File

@ -19,9 +19,7 @@ import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.*;
import com.fasterxml.jackson.databind.util.ArrayIterator;
import org.elasticsearch.action.get.GetResponse;
@ -29,12 +27,15 @@ import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHitField;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.internal.InternalSearchHitField;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.springframework.data.domain.Page;
import org.springframework.data.elasticsearch.core.domain.AggregatedPage;
import org.springframework.data.elasticsearch.entities.Car;
/**
@ -54,6 +55,31 @@ public class DefaultResultMapperTests {
resultMapper = new DefaultResultMapper();
}
@Test
public void shouldMapAggregationsToPage() {
//Given
SearchHit[] hits = {createCarHit("Ford", "Grat"), createCarHit("BMW", "Arrow")};
SearchHits searchHits = mock(SearchHits.class);
when(searchHits.totalHits()).thenReturn(2L);
when(searchHits.iterator()).thenReturn(new ArrayIterator(hits));
when(response.getHits()).thenReturn(searchHits);
Aggregation aggregationToReturn = createCarAggregation();
Aggregations aggregations = mock(Aggregations.class);
Iterator<Aggregation> iter = Collections.singletonList(aggregationToReturn).iterator();
when(aggregations.iterator()).thenReturn(iter);
when(aggregations.get("engine")).thenReturn(aggregationToReturn);
when(response.getAggregations()).thenReturn(aggregations);
//When
AggregatedPage<Car> page = (AggregatedPage<Car>) resultMapper.mapResults(response, Car.class, null);
//Then
assertThat(page.hasAggregations(), is(true));
assertThat(page.getAggregation("engine").getName(), is("Diesel"));
}
@Test
public void shouldMapSearchRequestToPage() {
//Given
@ -105,6 +131,12 @@ public class DefaultResultMapperTests {
assertThat(result.getName(), is("Ford"));
}
private Aggregation createCarAggregation() {
Aggregation aggregation = mock(Aggregation.class);
when(aggregation.getName()).thenReturn("Diesel");
return aggregation;
}
private SearchHit createCarHit(String name, String model) {
SearchHit hit = mock(SearchHit.class);
when(hit.sourceAsString()).thenReturn(createJsonCar(name, model));