diff --git a/src/main/java/org/springframework/data/elasticsearch/repository/config/EnableElasticsearchRepositories.java b/src/main/java/org/springframework/data/elasticsearch/repository/config/EnableElasticsearchRepositories.java index e18297dcb..ca1c0438b 100644 --- a/src/main/java/org/springframework/data/elasticsearch/repository/config/EnableElasticsearchRepositories.java +++ b/src/main/java/org/springframework/data/elasticsearch/repository/config/EnableElasticsearchRepositories.java @@ -21,6 +21,7 @@ import org.springframework.context.annotation.ComponentScan.Filter; import org.springframework.context.annotation.Import; import org.springframework.data.elasticsearch.core.ElasticsearchTemplate; import org.springframework.data.elasticsearch.repository.support.ElasticsearchRepositoryFactoryBean; +import org.springframework.data.repository.config.DefaultRepositoryBaseClass; import org.springframework.data.repository.query.QueryLookupStrategy.Key; /** @@ -103,6 +104,13 @@ public @interface EnableElasticsearchRepositories { */ Class repositoryFactoryBeanClass() default ElasticsearchRepositoryFactoryBean.class; + /** + * Configure the repository base class to be used to create repository proxies for this particular configuration. + * + * @return + */ + Class repositoryBaseClass() default DefaultRepositoryBaseClass.class; + // Elasticsearch specific configuration /** diff --git a/src/main/java/org/springframework/data/elasticsearch/repository/support/ElasticsearchRepositoryFactory.java b/src/main/java/org/springframework/data/elasticsearch/repository/support/ElasticsearchRepositoryFactory.java index b881a5d8e..547477678 100644 --- a/src/main/java/org/springframework/data/elasticsearch/repository/support/ElasticsearchRepositoryFactory.java +++ b/src/main/java/org/springframework/data/elasticsearch/repository/support/ElasticsearchRepositoryFactory.java @@ -27,6 +27,7 @@ import org.springframework.data.elasticsearch.repository.query.ElasticsearchQuer import org.springframework.data.elasticsearch.repository.query.ElasticsearchStringQuery; import org.springframework.data.querydsl.QueryDslPredicateExecutor; import org.springframework.data.repository.core.NamedQueries; +import org.springframework.data.repository.core.RepositoryInformation; import org.springframework.data.repository.core.RepositoryMetadata; import org.springframework.data.repository.core.support.RepositoryFactorySupport; import org.springframework.data.repository.query.QueryLookupStrategy; @@ -59,28 +60,8 @@ public class ElasticsearchRepositoryFactory extends RepositoryFactorySupport { @Override @SuppressWarnings({"rawtypes", "unchecked"}) - protected Object getTargetRepository(RepositoryMetadata metadata) { - - ElasticsearchEntityInformation entityInformation = getEntityInformation(metadata.getDomainType()); - - AbstractElasticsearchRepository repository; - - // Probably a better way to store and look these up. - if (Integer.class.isAssignableFrom(entityInformation.getIdType()) - || Long.class.isAssignableFrom(entityInformation.getIdType()) - || Double.class.isAssignableFrom(entityInformation.getIdType())) { - // logger.debug("Using NumberKeyedRepository for " + metadata.getRepositoryInterface()); - repository = new NumberKeyedRepository(getEntityInformation(metadata.getDomainType()), elasticsearchOperations); - } else if (entityInformation.getIdType() == String.class) { - // logger.debug("Using SimpleElasticsearchRepository for " + metadata.getRepositoryInterface()); - repository = new SimpleElasticsearchRepository(getEntityInformation(metadata.getDomainType()), - elasticsearchOperations); - } else { - throw new IllegalArgumentException("Unsuppored ID type " + entityInformation.getIdType()); - } - repository.setEntityClass(metadata.getDomainType()); - - return repository; + protected Object getTargetRepository(RepositoryInformation metadata) { + return getTargetRepositoryViaReflection(metadata,getEntityInformation(metadata.getDomainType()), elasticsearchOperations); } @Override @@ -88,7 +69,15 @@ public class ElasticsearchRepositoryFactory extends RepositoryFactorySupport { if (isQueryDslRepository(metadata.getRepositoryInterface())) { throw new IllegalArgumentException("QueryDsl Support has not been implemented yet."); } - return SimpleElasticsearchRepository.class; + if (Integer.class.isAssignableFrom(metadata.getIdType()) + || Long.class.isAssignableFrom(metadata.getIdType()) + || Double.class.isAssignableFrom(metadata.getIdType())) { + return NumberKeyedRepository.class; + } else if (metadata.getIdType() == String.class) { + return SimpleElasticsearchRepository.class; + } else { + throw new IllegalArgumentException("Unsuppored ID type " + metadata.getIdType()); + } } private static boolean isQueryDslRepository(Class repositoryInterface) {