1) added support for user defined indices and types in ElasticsearchOperations

2) added support for custom ResultMapper in ElasticsearchOperations
3) Fixed pagination issues
4) Fixed findAll methods that were returning 10 results (which is the elasticsearch default).
This commit is contained in:
Rizwan Idrees 2013-02-19 11:07:51 +00:00
parent 9c2a541d95
commit 0bc8ae54fe
7 changed files with 216 additions and 20 deletions

View File

@ -61,6 +61,17 @@ public interface ElasticsearchOperations {
*/
<T> Page<T> queryForPage(SearchQuery query, Class<T> clazz);
/**
* Execute the query against elasticsearch and return result as {@link Page}
*
* @param query
* @param resultsMapper
* @return
*/
<T> Page<T> queryForPage(SearchQuery query, ResultsMapper<T> resultsMapper);
/**
* Execute the query against elasticsearch and return result as {@link Page}
*

View File

@ -13,6 +13,7 @@ import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.Requests;
import org.elasticsearch.common.collect.MapBuilder;
import org.elasticsearch.index.query.FilterBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.sort.SortOrder;
@ -46,7 +47,6 @@ public class ElasticsearchTemplate implements ElasticsearchOperations {
private Client client;
private ElasticsearchConverter elasticsearchConverter;
private ObjectMapper objectMapper = new ObjectMapper();
{
@ -98,12 +98,21 @@ public class ElasticsearchTemplate implements ElasticsearchOperations {
@Override
public <T> Page<T> queryForPage(SearchQuery query, Class<T> clazz) {
SearchRequestBuilder searchRequestBuilder = prepareSearch(query,clazz);
if(query.getElasticsearchFilter() != null){
searchRequestBuilder.setFilter(query.getElasticsearchFilter());
SearchResponse response = doSearch(prepareSearch(query,clazz), query.getElasticsearchQuery(), query.getElasticsearchFilter());
return mapResults(response, clazz, query.getPageable());
}
public <T> Page<T> queryForPage(SearchQuery query, ResultsMapper<T> resultsMapper) {
SearchResponse response = doSearch(prepareSearch(query), query.getElasticsearchQuery(), query.getElasticsearchFilter());
return resultsMapper.mapResults(response);
}
private SearchResponse doSearch(SearchRequestBuilder searchRequest, QueryBuilder query, FilterBuilder filter ){
if(filter != null){
searchRequest.setFilter(filter);
}
SearchResponse response = searchRequestBuilder.setQuery(query.getElasticsearchQuery()).execute().actionGet();
return mapResults(response, clazz, query.getPageable());
return searchRequest.setQuery(query).execute().actionGet();
}
@Override
@ -195,16 +204,31 @@ public class ElasticsearchTemplate implements ElasticsearchOperations {
}
private <T> SearchRequestBuilder prepareSearch(Query query, Class<T> clazz){
int startRecord=0;
if(query.getPageable() != null){
startRecord = ((query.getPageable().getPageNumber() - 1) * query.getPageable().getPageSize());
if(query.getIndices().isEmpty()){
query.addIndices(retrieveIndexNameFromPersistentEntity(clazz));
}
ElasticsearchPersistentEntity persistentEntity = getPersistentEntityFor(clazz);
SearchRequestBuilder searchRequestBuilder = client.prepareSearch(persistentEntity.getIndexName())
if(query.getTypes().isEmpty()){
query.addTypes(retrieveTypeFromPersistentEntity(clazz));
}
return prepareSearch(query);
}
private SearchRequestBuilder prepareSearch(Query query){
int startRecord = 0;
SearchRequestBuilder searchRequestBuilder = client.prepareSearch(toArray(query.getIndices()))
.setSearchType(DFS_QUERY_THEN_FETCH)
.setTypes(persistentEntity.getIndexType())
.setFrom(startRecord < 0 ? 0 : startRecord)
.setSize(query.getPageable() != null ? query.getPageable().getPageSize() : 10);
.setTypes(toArray(query.getTypes()));
if(query.getPageable() != null){
startRecord = query.getPageable().getPageNumber() * query.getPageable().getPageSize();
searchRequestBuilder.setSize(query.getPageable().getPageSize());
}
searchRequestBuilder.setFrom(startRecord);
if(!query.getFields().isEmpty()){
searchRequestBuilder.addFields(toArray(query.getFields()));
}
if(query.getSort() != null){
for(Sort.Order order : query.getSort()){
@ -216,8 +240,12 @@ public class ElasticsearchTemplate implements ElasticsearchOperations {
private IndexRequestBuilder prepareIndex(IndexQuery query){
try {
ElasticsearchPersistentEntity persistentEntity = getPersistentEntityFor(query.getObject().getClass());
IndexRequestBuilder indexRequestBuilder = client.prepareIndex(persistentEntity.getIndexName(), persistentEntity.getIndexType(), query.getId())
String indexName = isBlank(query.getIndexName())?
retrieveIndexNameFromPersistentEntity(query.getObject().getClass())[0] : query.getIndexName();
String type = isBlank(query.getType())?
retrieveTypeFromPersistentEntity(query.getObject().getClass())[0] : query.getType();
IndexRequestBuilder indexRequestBuilder = client.prepareIndex(indexName,type,query.getId())
.setSource(objectMapper.writeValueAsString(query.getObject()));
if(query.getVersion() != null){
indexRequestBuilder.setVersion(query.getVersion());
@ -246,6 +274,14 @@ public class ElasticsearchTemplate implements ElasticsearchOperations {
return elasticsearchConverter.getMappingContext().getPersistentEntity(clazz);
}
private String[] retrieveIndexNameFromPersistentEntity(Class clazz){
return new String[]{getPersistentEntityFor(clazz).getIndexName()};
}
private String[] retrieveTypeFromPersistentEntity(Class clazz){
return new String[]{getPersistentEntityFor(clazz).getIndexType()};
}
private <T> Page<T> mapResults(SearchResponse response, final Class<T> elementType,final Pageable pageable){
ResultsMapper<T> resultsMapper = new ResultsMapper<T>(){
@Override
@ -273,4 +309,12 @@ public class ElasticsearchTemplate implements ElasticsearchOperations {
throw new ElasticsearchException("failed to map source [ " + source + "] to class " + clazz.getSimpleName() , e);
}
}
private static String[] toArray(List<String> values){
String[] valuesAsArray = new String[values.size()];
return values.toArray(valuesAsArray);
}
}

View File

@ -20,16 +20,24 @@ import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.util.Assert;
import java.util.ArrayList;
import java.util.List;
import static org.apache.commons.collections.CollectionUtils.addAll;
/**
* AbstractQuery
*
*
*/
abstract class AbstractQuery implements Query{
abstract class AbstractQuery implements Query{
private static final Pageable DEFAULT_PAGE = new PageRequest(0, DEFAULT_PAGE_SIZE);
protected Pageable pageable = DEFAULT_PAGE;
protected Sort sort;
protected List<String> indices = new ArrayList<String>();
protected List<String> types = new ArrayList<String>();
protected List<String> fields = new ArrayList<String>();
@Override
public Sort getSort() {
@ -44,11 +52,40 @@ abstract class AbstractQuery implements Query{
@Override
public final <T extends Query> T setPageable(Pageable pageable) {
Assert.notNull(pageable);
this.pageable = pageable;
return (T) this.addSort(pageable.getSort());
}
@Override
public void addFields(String...fields) {
addAll(this.fields, fields);
}
@Override
public List<String> getFields() {
return fields;
}
@Override
public List<String> getIndices() {
return indices;
}
@Override
public void addIndices(String... indices) {
addAll(this.indices,indices);
}
@Override
public void addTypes(String... types) {
addAll(this.types,types);
}
@Override
public List<String> getTypes() {
return types;
}
@SuppressWarnings("unchecked")
public final <T extends Query> T addSort(Sort sort) {
if (sort == null) {
@ -63,5 +100,4 @@ abstract class AbstractQuery implements Query{
return (T) this;
}
}

View File

@ -6,6 +6,8 @@ public class IndexQuery{
private String id;
private Object object;
private Long version;
private String indexName;
private String type;
public String getId() {
return id;
@ -30,4 +32,20 @@ public class IndexQuery{
public void setVersion(Long version) {
this.version = version;
}
public String getIndexName() {
return indexName;
}
public void setIndexName(String indexName) {
this.indexName = indexName;
}
public String getType() {
return type;
}
public void setType(String type) {
this.type = type;
}
}

View File

@ -4,6 +4,8 @@ package org.springframework.data.elasticsearch.core.query;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import java.util.List;
public interface Query {
int DEFAULT_PAGE_SIZE = 10;
@ -48,4 +50,42 @@ public interface Query {
*/
Sort getSort();
/**
* Get Indices to be searched
* @return
*/
List<String> getIndices();
/**
* Add Indices to be added as part of search request
* @param indices
*/
void addIndices(String...indices);
/**
* Add types to be searched
* @param types
*/
void addTypes(String...types);
/**
* Get types to be searched
* @return
*/
List<String> getTypes();
/**
* Add fields to be added as part of search request
* @param fields
*/
void addFields(String...fields);
/**
* Get fields to be returned as part of search request
* @return
*/
List<String> getFields();
}

View File

@ -89,6 +89,7 @@ public class SimpleElasticsearchRepository<T> implements ElasticsearchRepository
public Page<T> findAll(Pageable pageable) {
SearchQuery query = new SearchQuery();
query.setElasticsearchQuery(matchAllQuery());
query.setPageable(pageable);
return elasticsearchOperations.queryForPage(query, getEntityClass());
}
@ -161,6 +162,11 @@ public class SimpleElasticsearchRepository<T> implements ElasticsearchRepository
@Override
public Iterable<T> search(QueryBuilder elasticsearchQuery) {
SearchQuery query = new SearchQuery();
int count = (int) elasticsearchOperations.count(query, getEntityClass());
if(count == 0){
return new PageImpl<T>(Collections.<T>emptyList());
}
query.setPageable(new PageRequest(0,count));
query.setElasticsearchQuery(elasticsearchQuery);
return elasticsearchOperations.queryForPage(query, getEntityClass());
}

View File

@ -1,12 +1,15 @@
package org.springframework.data.elasticsearch.core;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.search.SearchHit;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Sort;
import org.springframework.data.elasticsearch.SampleEntity;
@ -374,4 +377,42 @@ public class ElasticsearchTemplateTest {
assertThat(sampleEntity1, is(notNullValue()));
}
@Test
public void shouldReturnSpecifiedFields(){
//given
String documentId = randomNumeric(5);
String message = "some test message";
SampleEntity sampleEntity = new SampleEntity();
sampleEntity.setId(documentId);
sampleEntity.setMessage(message);
sampleEntity.setVersion(System.currentTimeMillis());
IndexQuery indexQuery = new IndexQuery();
indexQuery.setId(documentId);
indexQuery.setObject(sampleEntity);
elasticsearchTemplate.index(indexQuery);
elasticsearchTemplate.refresh(SampleEntity.class, true);
SearchQuery searchQuery = new SearchQuery();
searchQuery.setElasticsearchQuery(matchAllQuery());
searchQuery.addFields("message");
searchQuery.addIndices("test-index");
searchQuery.addTypes("test-type");
//when
Page<String> page = elasticsearchTemplate.queryForPage(searchQuery, new ResultsMapper<String>() {
@Override
public Page<String> mapResults(SearchResponse response) {
List<String> values = new ArrayList<String>();
for(SearchHit searchHit : response.getHits()){
values.add((String) searchHit.field("message").value());
}
return new PageImpl<String>(values);
}
});
//then
assertThat(page, is(notNullValue()));
assertThat(page.getTotalElements(), is(equalTo(1L)));
assertThat(page.getContent().get(0), is(message));
}
}