diff --git a/src/main/java/org/springframework/data/elasticsearch/annotations/Score.java b/src/main/java/org/springframework/data/elasticsearch/annotations/Score.java new file mode 100644 index 000000000..b64bfe120 --- /dev/null +++ b/src/main/java/org/springframework/data/elasticsearch/annotations/Score.java @@ -0,0 +1,23 @@ +package org.springframework.data.elasticsearch.annotations; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Inherited; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.data.annotation.ReadOnlyProperty; + +/** + * Specifies that this field is used for storing the document score. + * + * @author Sascha Woo + */ +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.FIELD) +@Documented +@Inherited +@ReadOnlyProperty +public @interface Score { +} diff --git a/src/main/java/org/springframework/data/elasticsearch/core/DefaultResultMapper.java b/src/main/java/org/springframework/data/elasticsearch/core/DefaultResultMapper.java index 5932027a6..3a20182a3 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/DefaultResultMapper.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/DefaultResultMapper.java @@ -54,6 +54,7 @@ import com.fasterxml.jackson.core.JsonGenerator; * @author Chris White * @author Mark Paluch * @author Ilkang Na + * @author Sascha Woo */ public class DefaultResultMapper extends AbstractResultMapper { @@ -82,6 +83,8 @@ public class DefaultResultMapper extends AbstractResultMapper { @Override public AggregatedPage mapResults(SearchResponse response, Class clazz, Pageable pageable) { long totalHits = response.getHits().getTotalHits(); + float maxScore = response.getHits().getMaxScore(); + List results = new ArrayList<>(); for (SearchHit hit : response.getHits()) { if (hit != null) { @@ -91,14 +94,17 @@ public class DefaultResultMapper extends AbstractResultMapper { } else { result = mapEntity(hit.getFields().values(), clazz); } + setPersistentEntityId(result, hit.getId(), clazz); setPersistentEntityVersion(result, hit.getVersion(), clazz); + setPersistentEntityScore(result, hit.getScore(), clazz); populateScriptFields(result, hit); results.add(result); } } - return new AggregatedPageImpl(results, pageable, totalHits, response.getAggregations(), response.getScrollId()); + return new AggregatedPageImpl(results, pageable, totalHits, response.getAggregations(), response.getScrollId(), + maxScore); } private void populateScriptFields(T result, SearchHit hit) { @@ -113,8 +119,8 @@ public class DefaultResultMapper extends AbstractResultMapper { try { field.set(result, searchHitField.getValue()); } catch (IllegalArgumentException e) { - throw new ElasticsearchException("failed to set scripted field: " + name + " with value: " - + searchHitField.getValue(), e); + throw new ElasticsearchException( + "failed to set scripted field: " + name + " with value: " + searchHitField.getValue(), e); } catch (IllegalAccessException e) { throw new ElasticsearchException("failed to access scripted field: " + name, e); } @@ -178,9 +184,7 @@ public class DefaultResultMapper extends AbstractResultMapper { } private void setPersistentEntityId(T result, String id, Class clazz) { - if (mappingContext != null && clazz.isAnnotationPresent(Document.class)) { - ElasticsearchPersistentEntity persistentEntity = mappingContext.getRequiredPersistentEntity(clazz); ElasticsearchPersistentProperty idProperty = persistentEntity.getIdProperty(); @@ -188,13 +192,11 @@ public class DefaultResultMapper extends AbstractResultMapper { if (idProperty != null && idProperty.getType().isAssignableFrom(String.class)) { persistentEntity.getPropertyAccessor(result).setProperty(idProperty, id); } - } } private void setPersistentEntityVersion(T result, long version, Class clazz) { if (mappingContext != null && clazz.isAnnotationPresent(Document.class)) { - ElasticsearchPersistentEntity persistentEntity = mappingContext.getPersistentEntity(clazz); ElasticsearchPersistentProperty versionProperty = persistentEntity.getVersionProperty(); @@ -207,4 +209,16 @@ public class DefaultResultMapper extends AbstractResultMapper { } } } + + private void setPersistentEntityScore(T result, float score, Class clazz) { + if (mappingContext != null && clazz.isAnnotationPresent(Document.class)) { + ElasticsearchPersistentEntity persistentEntity = mappingContext.getRequiredPersistentEntity(clazz); + ElasticsearchPersistentProperty scoreProperty = persistentEntity.getScoreProperty(); + Class type = scoreProperty.getType(); + + if (scoreProperty != null && (type == Float.class || type == Float.TYPE)) { + persistentEntity.getPropertyAccessor(result).setProperty(scoreProperty, score); + } + } + } } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/ScoredPage.java b/src/main/java/org/springframework/data/elasticsearch/core/ScoredPage.java new file mode 100644 index 000000000..64e61a221 --- /dev/null +++ b/src/main/java/org/springframework/data/elasticsearch/core/ScoredPage.java @@ -0,0 +1,16 @@ + +package org.springframework.data.elasticsearch.core; + +import org.springframework.data.domain.Page; + +/** + * A score-aware page gaining information about max score. + * + * @param + * @author Sascha Woo + */ +public interface ScoredPage extends Page { + + float getMaxScore(); + +} diff --git a/src/main/java/org/springframework/data/elasticsearch/core/aggregation/AggregatedPage.java b/src/main/java/org/springframework/data/elasticsearch/core/aggregation/AggregatedPage.java index 0e446e145..c1ff7500d 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/aggregation/AggregatedPage.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/aggregation/AggregatedPage.java @@ -3,12 +3,14 @@ package org.springframework.data.elasticsearch.core.aggregation; import org.elasticsearch.search.aggregations.Aggregation; import org.elasticsearch.search.aggregations.Aggregations; import org.springframework.data.elasticsearch.core.FacetedPage; +import org.springframework.data.elasticsearch.core.ScoredPage; import org.springframework.data.elasticsearch.core.ScrolledPage; /** * @author Petar Tahchiev + * @author Sascha Woo */ -public interface AggregatedPage extends FacetedPage, ScrolledPage { +public interface AggregatedPage extends FacetedPage, ScrolledPage, ScoredPage { boolean hasAggregations(); diff --git a/src/main/java/org/springframework/data/elasticsearch/core/aggregation/impl/AggregatedPageImpl.java b/src/main/java/org/springframework/data/elasticsearch/core/aggregation/impl/AggregatedPageImpl.java index aeb301c55..caf8b8456 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/aggregation/impl/AggregatedPageImpl.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/aggregation/impl/AggregatedPageImpl.java @@ -15,9 +15,7 @@ */ package org.springframework.data.elasticsearch.core.aggregation.impl; -import java.util.HashMap; import java.util.List; -import java.util.Map; import org.elasticsearch.search.aggregations.Aggregation; import org.elasticsearch.search.aggregations.Aggregations; @@ -29,55 +27,77 @@ import org.springframework.data.elasticsearch.core.aggregation.AggregatedPage; * @author Petar Tahchiev * @author Artur Konczak * @author Mohsin Husen + * @author Sascha Woo */ public class AggregatedPageImpl extends FacetedPageImpl implements AggregatedPage { private Aggregations aggregations; - private Map mapOfAggregations = new HashMap<>(); - private String scrollId; + private String scrollId; + private float maxScore; public AggregatedPageImpl(List content) { super(content); } + public AggregatedPageImpl(List content, float maxScore) { + super(content); + this.maxScore = maxScore; + } + public AggregatedPageImpl(List content, String scrollId) { super(content); this.scrollId = scrollId; } + public AggregatedPageImpl(List content, String scrollId, float maxScore) { + this(content, scrollId); + this.maxScore = maxScore; + } + public AggregatedPageImpl(List content, Pageable pageable, long total) { super(content, pageable, total); } + public AggregatedPageImpl(List content, Pageable pageable, long total, float maxScore) { + super(content, pageable, total); + this.maxScore = maxScore; + } + public AggregatedPageImpl(List content, Pageable pageable, long total, String scrollId) { super(content, pageable, total); this.scrollId = scrollId; } + public AggregatedPageImpl(List content, Pageable pageable, long total, String scrollId, float maxScore) { + this(content, pageable, total, scrollId); + this.maxScore = maxScore; + } + public AggregatedPageImpl(List content, Pageable pageable, long total, Aggregations aggregations) { super(content, pageable, total); this.aggregations = aggregations; - if (aggregations != null) { - for (Aggregation aggregation : aggregations) { - mapOfAggregations.put(aggregation.getName(), aggregation); - } - } } - public AggregatedPageImpl(List content, Pageable pageable, long total, Aggregations aggregations, String scrollId) { - super(content, pageable, total); - this.aggregations = aggregations; + public AggregatedPageImpl(List content, Pageable pageable, long total, Aggregations aggregations, float maxScore) { + this(content, pageable, total, aggregations); + this.maxScore = maxScore; + } + + public AggregatedPageImpl(List content, Pageable pageable, long total, Aggregations aggregations, + String scrollId) { + this(content, pageable, total, aggregations); this.scrollId = scrollId; - if (aggregations != null) { - for (Aggregation aggregation : aggregations) { - mapOfAggregations.put(aggregation.getName(), aggregation); - } - } + } + + public AggregatedPageImpl(List content, Pageable pageable, long total, Aggregations aggregations, String scrollId, + float maxScore) { + this(content, pageable, total, aggregations, scrollId); + this.maxScore = maxScore; } @Override public boolean hasAggregations() { - return aggregations != null && mapOfAggregations.size() > 0; + return aggregations != null; } @Override @@ -94,4 +114,9 @@ public class AggregatedPageImpl extends FacetedPageImpl implements Aggrega public String getScrollId() { return scrollId; } + + @Override + public float getMaxScore() { + return maxScore; + } } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/mapping/ElasticsearchPersistentEntity.java b/src/main/java/org/springframework/data/elasticsearch/core/mapping/ElasticsearchPersistentEntity.java index 787d1ac79..d78263acc 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/mapping/ElasticsearchPersistentEntity.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/mapping/ElasticsearchPersistentEntity.java @@ -16,6 +16,7 @@ package org.springframework.data.elasticsearch.core.mapping; import org.springframework.data.mapping.PersistentEntity; +import org.springframework.lang.Nullable; /** * ElasticsearchPersistentEntity @@ -23,6 +24,7 @@ import org.springframework.data.mapping.PersistentEntity; * @author Rizwan Idrees * @author Mohsin Husen * @author Mark Paluch + * @author Sascha Woo */ public interface ElasticsearchPersistentEntity extends PersistentEntity { @@ -49,4 +51,22 @@ public interface ElasticsearchPersistentEntity extends PersistentEntity { String getFieldName(); + /** + * Returns whether the current property is a potential score property of the owning + * {@link ElasticsearchPersistentEntity}. This method is mainly used by {@link ElasticsearchPersistentEntity} + * implementation to discover score property candidates on {@link ElasticsearchPersistentEntity} creation you should + * rather call {@link ElasticsearchPersistentEntity#isScoreProperty(PersistentProperty)} to determine whether the + * current property is the version property of that {@link ElasticsearchPersistentEntity} under consideration. + * + * @return + */ + boolean isScoreProperty(); + public enum PropertyToFieldNameConverter implements Converter { INSTANCE; diff --git a/src/main/java/org/springframework/data/elasticsearch/core/mapping/SimpleElasticsearchPersistentEntity.java b/src/main/java/org/springframework/data/elasticsearch/core/mapping/SimpleElasticsearchPersistentEntity.java index 18e685ddb..3948c8637 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/mapping/SimpleElasticsearchPersistentEntity.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/mapping/SimpleElasticsearchPersistentEntity.java @@ -18,7 +18,6 @@ package org.springframework.data.elasticsearch.core.mapping; import static org.springframework.util.StringUtils.*; import java.util.Locale; -import java.util.Optional; import org.springframework.beans.BeansException; import org.springframework.context.ApplicationContext; @@ -28,12 +27,14 @@ import org.springframework.context.expression.BeanFactoryResolver; import org.springframework.data.elasticsearch.annotations.Document; import org.springframework.data.elasticsearch.annotations.Parent; import org.springframework.data.elasticsearch.annotations.Setting; +import org.springframework.data.mapping.MappingException; import org.springframework.data.mapping.model.BasicPersistentEntity; import org.springframework.data.util.TypeInformation; import org.springframework.expression.Expression; import org.springframework.expression.ParserContext; import org.springframework.expression.spel.standard.SpelExpressionParser; import org.springframework.expression.spel.support.StandardEvaluationContext; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** @@ -43,6 +44,7 @@ import org.springframework.util.Assert; * @author Rizwan Idrees * @author Mohsin Husen * @author Mark Paluch + * @author Sascha Woo */ public class SimpleElasticsearchPersistentEntity extends BasicPersistentEntity implements ElasticsearchPersistentEntity, ApplicationContextAware { @@ -59,6 +61,7 @@ public class SimpleElasticsearchPersistentEntity extends BasicPersistentEntit private String indexStoreType; private String parentType; private ElasticsearchPersistentProperty parentIdProperty; + private ElasticsearchPersistentProperty scoreProperty; private String settingPath; private boolean createIndexAndMapping; @@ -150,6 +153,17 @@ public class SimpleElasticsearchPersistentEntity extends BasicPersistentEntit return createIndexAndMapping; } + @Override + public boolean hasScoreProperty() { + return scoreProperty != null; + } + + @Nullable + @Override + public ElasticsearchPersistentProperty getScoreProperty() { + return scoreProperty; + } + @Override public void addPersistentProperty(ElasticsearchPersistentProperty property) { super.addPersistentProperty(property); @@ -165,7 +179,21 @@ public class SimpleElasticsearchPersistentEntity extends BasicPersistentEntit } if (property.isVersionProperty()) { - Assert.isTrue(property.getType() == Long.class, "Version property should be Long"); + Assert.isTrue(property.getType() == Long.class, "Version property must be of type Long!"); + } + + if (property.isScoreProperty()) { + ElasticsearchPersistentProperty scoreProperty = this.scoreProperty; + + if (scoreProperty != null) { + throw new MappingException( + String.format("Attempt to add score property %s but already have property %s registered " + + "as version. Check your mapping configuration!", property.getField(), scoreProperty.getField())); + } + + Assert.isTrue(property.getType() == Float.class || property.getType() == Float.TYPE, "Score property must be of type float!"); + + this.scoreProperty = property; } } } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/mapping/SimpleElasticsearchPersistentProperty.java b/src/main/java/org/springframework/data/elasticsearch/core/mapping/SimpleElasticsearchPersistentProperty.java index b79b4b200..21afb1e27 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/mapping/SimpleElasticsearchPersistentProperty.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/mapping/SimpleElasticsearchPersistentProperty.java @@ -18,11 +18,13 @@ package org.springframework.data.elasticsearch.core.mapping; import java.util.HashSet; import java.util.Set; +import org.springframework.data.elasticsearch.annotations.Score; import org.springframework.data.mapping.Association; import org.springframework.data.mapping.PersistentEntity; import org.springframework.data.mapping.model.AnnotationBasedPersistentProperty; import org.springframework.data.mapping.model.Property; import org.springframework.data.mapping.model.SimpleTypeHolder; +import org.springframework.data.util.Lazy; /** * Elasticsearch specific {@link org.springframework.data.mapping.PersistentProperty} implementation processing @@ -30,12 +32,15 @@ import org.springframework.data.mapping.model.SimpleTypeHolder; * @author Rizwan Idrees * @author Mohsin Husen * @author Mark Paluch + * @author Sascha Woo */ public class SimpleElasticsearchPersistentProperty extends AnnotationBasedPersistentProperty implements ElasticsearchPersistentProperty { private static final Set> SUPPORTED_ID_TYPES = new HashSet<>(); private static final Set SUPPORTED_ID_PROPERTY_NAMES = new HashSet<>(); + + private final Lazy isScore = Lazy.of(() -> isAnnotationPresent(Score.class)); static { SUPPORTED_ID_TYPES.add(String.class); @@ -62,4 +67,9 @@ public class SimpleElasticsearchPersistentProperty extends protected Association createAssociation() { return null; } + + @Override + public boolean isScoreProperty() { + return isScore.get(); + } } diff --git a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java index 7ac78bee7..08ef1b390 100755 --- a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java @@ -36,6 +36,8 @@ import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder; import org.elasticsearch.search.fetch.subphase.highlight.HighlightField; import org.elasticsearch.search.sort.FieldSortBuilder; +import org.elasticsearch.search.sort.SortBuilder; +import org.elasticsearch.search.sort.SortBuilders; import org.elasticsearch.search.sort.SortOrder; import org.hamcrest.Matchers; import org.junit.Before; @@ -1512,6 +1514,31 @@ public class ElasticsearchTemplateTests { assertThat(page.getContent().get(0).getMessage(), is("ab")); } + @Test // DATAES-462 + public void shouldReturnScores() { + // given + List indexQueries = new ArrayList<>(); + + indexQueries.add(buildIndex(SampleEntity.builder().id("1").message("ab xz").build())); + indexQueries.add(buildIndex(SampleEntity.builder().id("2").message("bc").build())); + indexQueries.add(buildIndex(SampleEntity.builder().id("3").message("ac xz hi").build())); + + elasticsearchTemplate.bulkIndex(indexQueries); + elasticsearchTemplate.refresh(SampleEntity.class); + + // when + SearchQuery searchQuery = new NativeSearchQueryBuilder() + .withQuery(termQuery("message", "xz")) + .withSort(SortBuilders.fieldSort("message")) + .withTrackScores(true) + .build(); + + AggregatedPage page = elasticsearchTemplate.queryForPage(searchQuery, SampleEntity.class); + + // then + assertThat(page.getMaxScore(), greaterThan(0f)); + assertThat(page.getContent().get(0).getScore(), greaterThan(0f)); + } @Test public void shouldDoIndexWithoutId() { diff --git a/src/test/java/org/springframework/data/elasticsearch/entities/SampleEntity.java b/src/test/java/org/springframework/data/elasticsearch/entities/SampleEntity.java index 751cf8c3d..344735f06 100644 --- a/src/test/java/org/springframework/data/elasticsearch/entities/SampleEntity.java +++ b/src/test/java/org/springframework/data/elasticsearch/entities/SampleEntity.java @@ -15,29 +15,29 @@ */ package org.springframework.data.elasticsearch.entities; -import java.lang.Double; -import java.lang.Long; -import java.lang.Object; +import static org.springframework.data.elasticsearch.annotations.FieldType.*; + import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; import lombok.ToString; + import org.springframework.data.annotation.Id; import org.springframework.data.annotation.Version; import org.springframework.data.elasticsearch.annotations.Document; import org.springframework.data.elasticsearch.annotations.Field; +import org.springframework.data.elasticsearch.annotations.Score; import org.springframework.data.elasticsearch.annotations.ScriptedField; import org.springframework.data.elasticsearch.core.geo.GeoPoint; -import static org.springframework.data.elasticsearch.annotations.FieldType.*; /** * @author Rizwan Idrees * @author Mohsin Husen * @author Chris White + * @author Sascha Woo */ - @Setter @Getter @NoArgsConstructor @@ -58,11 +58,11 @@ public class SampleEntity { private Double scriptedRate; private boolean available; private String highlightedMessage; - private GeoPoint location; - @Version private Long version; + @Score + private float score; @Override public boolean equals(Object o) {