SOLR-1302: Implemented Haversine, plus Euclidean, Manhattan distances (p-norms) plus deg, rad functions

git-svn-id: https://svn.apache.org/repos/asf/lucene/solr/trunk@836216 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Grant Ingersoll 2009-11-14 17:10:41 +00:00
parent 128c7bf12d
commit cdbcf8aca3
14 changed files with 1108 additions and 151 deletions

View File

@ -34,6 +34,9 @@ Detailed Change List
New Features New Features
---------------------- ----------------------
1. SOLR-1302: Added several new distance based functions, including Great Circle (haversine), Manhattan and Euclidean.
Also added deg() and rad() convenience functions. (gsingers)
Optimizations Optimizations
---------------------- ----------------------

View File

@ -87,6 +87,28 @@ public class FunctionQParser extends QParser {
return value; return value;
} }
/**
* Parse a Double
* @return double
* @throws ParseException
*/
public double parseDouble() throws ParseException {
double value = sp.getDouble();
consumeArgumentDelimiter();
return value;
}
/**
* Parse an integer
* @return An int
* @throws ParseException
*/
public int parseInt() throws ParseException {
int value = sp.getInt();
consumeArgumentDelimiter();
return value;
}
public String parseArg() throws ParseException { public String parseArg() throws ParseException {
sp.eatws(); sp.eatws();
char ch = sp.peek(); char ch = sp.peek();

View File

@ -557,6 +557,46 @@ public class QueryParsing {
return Float.parseFloat(new String(arr,0,i)); return Float.parseFloat(new String(arr,0,i));
} }
double getDouble() throws ParseException {
eatws();
char[] arr = new char[end-pos];
int i;
for (i=0; i<arr.length; i++) {
char ch = val.charAt(pos);
if ( (ch>='0' && ch<='9')
|| ch=='+' || ch=='-'
|| ch=='.' || ch=='e' || ch=='E'
) {
pos++;
arr[i]=ch;
} else {
break;
}
}
return Double.parseDouble(new String(arr,0,i));
}
int getInt() throws ParseException {
eatws();
char[] arr = new char[end-pos];
int i;
for (i=0; i<arr.length; i++) {
char ch = val.charAt(pos);
if ( (ch>='0' && ch<='9')
|| ch=='+' || ch=='-'
) {
pos++;
arr[i]=ch;
} else {
break;
}
}
return Integer.parseInt(new String(arr,0,i));
}
String getId() throws ParseException { String getId() throws ParseException {
eatws(); eatws();
int id_start=pos; int id_start=pos;

View File

@ -16,36 +16,63 @@
*/ */
package org.apache.solr.search; package org.apache.solr.search;
import java.util.*; import org.apache.lucene.index.IndexReader;
import java.io.IOException;
import org.apache.lucene.queryParser.ParseException; import org.apache.lucene.queryParser.ParseException;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
import org.apache.lucene.index.IndexReader;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.common.SolrException; import org.apache.solr.common.SolrException;
import org.apache.solr.search.function.*; import org.apache.solr.common.util.NamedList;
import org.apache.solr.util.plugin.NamedListInitializedPlugin;
import org.apache.solr.schema.TrieDateField;
import org.apache.solr.schema.DateField; import org.apache.solr.schema.DateField;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.schema.LegacyDateField; import org.apache.solr.schema.LegacyDateField;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.schema.TrieDateField;
import org.apache.solr.search.function.BoostedQuery;
import org.apache.solr.search.function.DegreeFunction;
import org.apache.solr.search.function.DivFloatFunction;
import org.apache.solr.search.function.DocValues;
import org.apache.solr.search.function.DualFloatFunction;
import org.apache.solr.search.function.LinearFloatFunction;
import org.apache.solr.search.function.MaxFloatFunction;
import org.apache.solr.search.function.OrdFieldSource;
import org.apache.solr.search.function.PowFloatFunction;
import org.apache.solr.search.function.ProductFloatFunction;
import org.apache.solr.search.function.QueryValueSource;
import org.apache.solr.search.function.RadianFunction;
import org.apache.solr.search.function.RangeMapFloatFunction;
import org.apache.solr.search.function.ReciprocalFloatFunction;
import org.apache.solr.search.function.ReverseOrdFieldSource;
import org.apache.solr.search.function.ScaleFloatFunction;
import org.apache.solr.search.function.SimpleFloatFunction;
import org.apache.solr.search.function.SumFloatFunction;
import org.apache.solr.search.function.TopValueSource;
import org.apache.solr.search.function.ValueSource;
import org.apache.solr.search.function.distance.HaversineFunction;
import org.apache.solr.search.function.distance.SquaredEuclideanFunction;
import org.apache.solr.search.function.distance.VectorDistanceFunction;
import org.apache.solr.util.plugin.NamedListInitializedPlugin;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/** /**
* A factory that parses user queries to generate ValueSource instances. * A factory that parses user queries to generate ValueSource instances.
* Intented usage is to create pluggable, named functions for use in function queries. * Intented usage is to create pluggable, named functions for use in function queries.
*/ */
public abstract class ValueSourceParser implements NamedListInitializedPlugin public abstract class ValueSourceParser implements NamedListInitializedPlugin {
{
/** /**
* Initialize the plugin. * Initialize the plugin.
*/ */
public abstract void init( NamedList args ); public abstract void init(NamedList args);
/** /**
* Parse the user input into a ValueSource. * Parse the user input into a ValueSource.
* *
* @param fp * @param fp
* @throws ParseException * @throws ParseException
*/ */
@ -53,6 +80,7 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin
/* standard functions */ /* standard functions */
public static Map<String, ValueSourceParser> standardValueSourceParsers = new HashMap<String, ValueSourceParser>(); public static Map<String, ValueSourceParser> standardValueSourceParsers = new HashMap<String, ValueSourceParser>();
static { static {
standardValueSourceParsers.put("ord", new ValueSourceParser() { standardValueSourceParsers.put("ord", new ValueSourceParser() {
public ValueSource parse(FunctionQParser fp) throws ParseException { public ValueSource parse(FunctionQParser fp) throws ParseException {
@ -62,7 +90,7 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin
public void init(NamedList args) { public void init(NamedList args) {
} }
}); });
standardValueSourceParsers.put("rord", new ValueSourceParser() { standardValueSourceParsers.put("rord", new ValueSourceParser() {
public ValueSource parse(FunctionQParser fp) throws ParseException { public ValueSource parse(FunctionQParser fp) throws ParseException {
@ -72,7 +100,7 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin
public void init(NamedList args) { public void init(NamedList args) {
} }
}); });
standardValueSourceParsers.put("top", new ValueSourceParser() { standardValueSourceParsers.put("top", new ValueSourceParser() {
public ValueSource parse(FunctionQParser fp) throws ParseException { public ValueSource parse(FunctionQParser fp) throws ParseException {
@ -81,6 +109,7 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin
if (source instanceof TopValueSource) return source; if (source instanceof TopValueSource) return source;
return new TopValueSource(source); return new TopValueSource(source);
} }
public void init(NamedList args) { public void init(NamedList args) {
} }
}); });
@ -89,23 +118,23 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin
ValueSource source = fp.parseValueSource(); ValueSource source = fp.parseValueSource();
float slope = fp.parseFloat(); float slope = fp.parseFloat();
float intercept = fp.parseFloat(); float intercept = fp.parseFloat();
return new LinearFloatFunction(source,slope,intercept); return new LinearFloatFunction(source, slope, intercept);
} }
public void init(NamedList args) { public void init(NamedList args) {
} }
}); });
standardValueSourceParsers.put("max", new ValueSourceParser() { standardValueSourceParsers.put("max", new ValueSourceParser() {
public ValueSource parse(FunctionQParser fp) throws ParseException { public ValueSource parse(FunctionQParser fp) throws ParseException {
ValueSource source = fp.parseValueSource(); ValueSource source = fp.parseValueSource();
float val = fp.parseFloat(); float val = fp.parseFloat();
return new MaxFloatFunction(source,val); return new MaxFloatFunction(source, val);
} }
public void init(NamedList args) { public void init(NamedList args) {
} }
}); });
standardValueSourceParsers.put("recip", new ValueSourceParser() { standardValueSourceParsers.put("recip", new ValueSourceParser() {
public ValueSource parse(FunctionQParser fp) throws ParseException { public ValueSource parse(FunctionQParser fp) throws ParseException {
@ -113,46 +142,46 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin
float m = fp.parseFloat(); float m = fp.parseFloat();
float a = fp.parseFloat(); float a = fp.parseFloat();
float b = fp.parseFloat(); float b = fp.parseFloat();
return new ReciprocalFloatFunction(source,m,a,b); return new ReciprocalFloatFunction(source, m, a, b);
} }
public void init(NamedList args) { public void init(NamedList args) {
} }
}); });
standardValueSourceParsers.put("scale", new ValueSourceParser() { standardValueSourceParsers.put("scale", new ValueSourceParser() {
public ValueSource parse(FunctionQParser fp) throws ParseException { public ValueSource parse(FunctionQParser fp) throws ParseException {
ValueSource source = fp.parseValueSource(); ValueSource source = fp.parseValueSource();
float min = fp.parseFloat(); float min = fp.parseFloat();
float max = fp.parseFloat(); float max = fp.parseFloat();
return new TopValueSource(new ScaleFloatFunction(source,min,max)); return new TopValueSource(new ScaleFloatFunction(source, min, max));
} }
public void init(NamedList args) { public void init(NamedList args) {
} }
}); });
standardValueSourceParsers.put("pow", new ValueSourceParser() { standardValueSourceParsers.put("pow", new ValueSourceParser() {
public ValueSource parse(FunctionQParser fp) throws ParseException { public ValueSource parse(FunctionQParser fp) throws ParseException {
ValueSource a = fp.parseValueSource(); ValueSource a = fp.parseValueSource();
ValueSource b = fp.parseValueSource(); ValueSource b = fp.parseValueSource();
return new PowFloatFunction(a,b); return new PowFloatFunction(a, b);
} }
public void init(NamedList args) { public void init(NamedList args) {
} }
}); });
standardValueSourceParsers.put("div", new ValueSourceParser() { standardValueSourceParsers.put("div", new ValueSourceParser() {
public ValueSource parse(FunctionQParser fp) throws ParseException { public ValueSource parse(FunctionQParser fp) throws ParseException {
ValueSource a = fp.parseValueSource(); ValueSource a = fp.parseValueSource();
ValueSource b = fp.parseValueSource(); ValueSource b = fp.parseValueSource();
return new DivFloatFunction(a,b); return new DivFloatFunction(a, b);
} }
public void init(NamedList args) { public void init(NamedList args) {
} }
}); });
standardValueSourceParsers.put("map", new ValueSourceParser() { standardValueSourceParsers.put("map", new ValueSourceParser() {
public ValueSource parse(FunctionQParser fp) throws ParseException { public ValueSource parse(FunctionQParser fp) throws ParseException {
@ -161,12 +190,12 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin
float max = fp.parseFloat(); float max = fp.parseFloat();
float target = fp.parseFloat(); float target = fp.parseFloat();
Float def = fp.hasMoreArguments() ? fp.parseFloat() : null; Float def = fp.hasMoreArguments() ? fp.parseFloat() : null;
return new RangeMapFloatFunction(source,min,max,target,def); return new RangeMapFloatFunction(source, min, max, target, def);
} }
public void init(NamedList args) { public void init(NamedList args) {
} }
}); });
standardValueSourceParsers.put("sqrt", new ValueSourceParser() { standardValueSourceParsers.put("sqrt", new ValueSourceParser() {
public ValueSource parse(FunctionQParser fp) throws ParseException { public ValueSource parse(FunctionQParser fp) throws ParseException {
@ -175,11 +204,13 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin
protected String name() { protected String name() {
return "sqrt"; return "sqrt";
} }
protected float func(int doc, DocValues vals) { protected float func(int doc, DocValues vals) {
return (float)Math.sqrt(vals.floatVal(doc)); return (float) Math.sqrt(vals.floatVal(doc));
} }
}; };
} }
public void init(NamedList args) { public void init(NamedList args) {
} }
}); });
@ -190,15 +221,16 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin
protected String name() { protected String name() {
return "log"; return "log";
} }
protected float func(int doc, DocValues vals) { protected float func(int doc, DocValues vals) {
return (float)Math.log10(vals.floatVal(doc)); return (float) Math.log10(vals.floatVal(doc));
} }
}; };
} }
public void init(NamedList args) { public void init(NamedList args) {
} }
}); });
standardValueSourceParsers.put("abs", new ValueSourceParser() { standardValueSourceParsers.put("abs", new ValueSourceParser() {
public ValueSource parse(FunctionQParser fp) throws ParseException { public ValueSource parse(FunctionQParser fp) throws ParseException {
@ -207,15 +239,16 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin
protected String name() { protected String name() {
return "abs"; return "abs";
} }
protected float func(int doc, DocValues vals) { protected float func(int doc, DocValues vals) {
return (float)Math.abs(vals.floatVal(doc)); return (float) Math.abs(vals.floatVal(doc));
} }
}; };
} }
public void init(NamedList args) { public void init(NamedList args) {
} }
}); });
standardValueSourceParsers.put("sum", new ValueSourceParser() { standardValueSourceParsers.put("sum", new ValueSourceParser() {
public ValueSource parse(FunctionQParser fp) throws ParseException { public ValueSource parse(FunctionQParser fp) throws ParseException {
@ -225,7 +258,7 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin
public void init(NamedList args) { public void init(NamedList args) {
} }
}); });
standardValueSourceParsers.put("product", new ValueSourceParser() { standardValueSourceParsers.put("product", new ValueSourceParser() {
public ValueSource parse(FunctionQParser fp) throws ParseException { public ValueSource parse(FunctionQParser fp) throws ParseException {
@ -235,16 +268,17 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin
public void init(NamedList args) { public void init(NamedList args) {
} }
}); });
standardValueSourceParsers.put("sub", new ValueSourceParser() { standardValueSourceParsers.put("sub", new ValueSourceParser() {
public ValueSource parse(FunctionQParser fp) throws ParseException { public ValueSource parse(FunctionQParser fp) throws ParseException {
ValueSource a = fp.parseValueSource(); ValueSource a = fp.parseValueSource();
ValueSource b = fp.parseValueSource(); ValueSource b = fp.parseValueSource();
return new DualFloatFunction(a,b) { return new DualFloatFunction(a, b) {
protected String name() { protected String name() {
return "sub"; return "sub";
} }
protected float func(int doc, DocValues aVals, DocValues bVals) { protected float func(int doc, DocValues aVals, DocValues bVals) {
return aVals.floatVal(doc) - bVals.floatVal(doc); return aVals.floatVal(doc) - bVals.floatVal(doc);
} }
@ -268,7 +302,7 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin
public void init(NamedList args) { public void init(NamedList args) {
} }
}); });
standardValueSourceParsers.put("boost", new ValueSourceParser() { standardValueSourceParsers.put("boost", new ValueSourceParser() {
public ValueSource parse(FunctionQParser fp) throws ParseException { public ValueSource parse(FunctionQParser fp) throws ParseException {
@ -280,30 +314,115 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin
public void init(NamedList args) { public void init(NamedList args) {
} }
});
standardValueSourceParsers.put("hsin", new ValueSourceParser() {
public ValueSource parse(FunctionQParser fp) throws ParseException {
ValueSource x1 = fp.parseValueSource();
ValueSource y1 = fp.parseValueSource();
ValueSource x2 = fp.parseValueSource();
ValueSource y2 = fp.parseValueSource();
double radius = fp.parseDouble();
return new HaversineFunction(x1, y1, x2, y2, radius);
}
public void init(NamedList args) {
}
});
standardValueSourceParsers.put("rad", new ValueSourceParser() {
public ValueSource parse(FunctionQParser fp) throws ParseException {
return new RadianFunction(fp.parseValueSource());
}
public void init(NamedList args) {
}
});
standardValueSourceParsers.put("deg", new ValueSourceParser() {
public ValueSource parse(FunctionQParser fp) throws ParseException {
return new DegreeFunction(fp.parseValueSource());
}
public void init(NamedList args) {
}
});
standardValueSourceParsers.put("sqedist", new ValueSourceParser() {
public ValueSource parse(FunctionQParser fp) throws ParseException {
List<ValueSource> sources = fp.parseValueSourceList();
if (sources.size() % 2 != 0) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, "Illegal number of sources. There must be an even number of sources");
}
int dim = sources.size() / 2;
List<ValueSource> sources1 = new ArrayList<ValueSource>(dim);
List<ValueSource> sources2 = new ArrayList<ValueSource>(dim);
//Get dim value sources for the first vector
splitSources(dim, sources, sources1, sources2);
return new SquaredEuclideanFunction(sources1, sources2);
}
public void init(NamedList args) {
}
});
standardValueSourceParsers.put("dist", new ValueSourceParser() {
public ValueSource parse(FunctionQParser fp) throws ParseException {
float power = fp.parseFloat();
List<ValueSource> sources = fp.parseValueSourceList();
if (sources.size() % 2 != 0) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, "Illegal number of sources. There must be an even number of sources");
}
int dim = sources.size() / 2;
List<ValueSource> sources1 = new ArrayList<ValueSource>(dim);
List<ValueSource> sources2 = new ArrayList<ValueSource>(dim);
splitSources(dim, sources, sources1, sources2);
return new VectorDistanceFunction(power, sources1, sources2);
}
public void init(NamedList args) {
}
}); });
standardValueSourceParsers.put("ms", new DateValueSourceParser()); standardValueSourceParsers.put("ms", new DateValueSourceParser());
} }
protected void splitSources(int dim, List<ValueSource> sources, List<ValueSource> dest1, List<ValueSource> dest2) {
//Get dim value sources for the first vector
for (int i = 0; i < dim; i++) {
dest1.add(sources.get(i));
}
//Get dim value sources for the second vector
for (int i = dim; i < sources.size(); i++) {
dest2.add(sources.get(i));
}
}
} }
class DateValueSourceParser extends ValueSourceParser { class DateValueSourceParser extends ValueSourceParser {
DateField df = new TrieDateField(); DateField df = new TrieDateField();
public void init(NamedList args) {}
public void init(NamedList args) {
}
public Date getDate(FunctionQParser fp, String arg) { public Date getDate(FunctionQParser fp, String arg) {
if (arg==null) return null; if (arg == null) return null;
if (arg.startsWith("NOW") || (arg.length()>0 && Character.isDigit(arg.charAt(0)))) { if (arg.startsWith("NOW") || (arg.length() > 0 && Character.isDigit(arg.charAt(0)))) {
return df.parseMathLenient(null, arg, fp.req); return df.parseMathLenient(null, arg, fp.req);
} }
return null; return null;
} }
public ValueSource getValueSource(FunctionQParser fp, String arg) { public ValueSource getValueSource(FunctionQParser fp, String arg) {
if (arg==null) return null; if (arg == null) return null;
SchemaField f = fp.req.getSchema().getField(arg); SchemaField f = fp.req.getSchema().getField(arg);
if (f.getType().getClass() == DateField.class || f.getType().getClass() == LegacyDateField.class) { if (f.getType().getClass() == DateField.class || f.getType().getClass() == LegacyDateField.class) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, "Can't use ms() function on non-numeric legacy date field " + arg); throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, "Can't use ms() function on non-numeric legacy date field " + arg);
@ -314,13 +433,13 @@ class DateValueSourceParser extends ValueSourceParser {
public ValueSource parse(FunctionQParser fp) throws ParseException { public ValueSource parse(FunctionQParser fp) throws ParseException {
String first = fp.parseArg(); String first = fp.parseArg();
String second = fp.parseArg(); String second = fp.parseArg();
if (first==null) first="NOW"; if (first == null) first = "NOW";
Date d1=getDate(fp,first); Date d1 = getDate(fp, first);
ValueSource v1 = d1==null ? getValueSource(fp, first) : null; ValueSource v1 = d1 == null ? getValueSource(fp, first) : null;
Date d2=getDate(fp,second); Date d2 = getDate(fp, second);
ValueSource v2 = d2==null ? getValueSource(fp, second) : null; ValueSource v2 = d2 == null ? getValueSource(fp, second) : null;
// d constant // d constant
// v field // v field
@ -330,54 +449,57 @@ class DateValueSourceParser extends ValueSourceParser {
// vv subtract fields // vv subtract fields
final long ms1 = (d1 == null) ? 0 : d1.getTime(); final long ms1 = (d1 == null) ? 0 : d1.getTime();
final long ms2 = (d2 == null) ? 0 : d2.getTime(); final long ms2 = (d2 == null) ? 0 : d2.getTime();
// "d,dd" handle both constant cases // "d,dd" handle both constant cases
if (d1 != null && v2==null) { if (d1 != null && v2 == null) {
return new LongConstValueSource(ms1-ms2); return new LongConstValueSource(ms1 - ms2);
} }
// "v" just the date field // "v" just the date field
if (v1 != null && v2==null && d2==null) { if (v1 != null && v2 == null && d2 == null) {
return v1; return v1;
} }
// "dv" // "dv"
if (d1!=null && v2!=null) if (d1 != null && v2 != null)
return new DualFloatFunction(new LongConstValueSource(ms1), v2) { return new DualFloatFunction(new LongConstValueSource(ms1), v2) {
protected String name() { protected String name() {
return "ms"; return "ms";
} }
protected float func(int doc, DocValues aVals, DocValues bVals) { protected float func(int doc, DocValues aVals, DocValues bVals) {
return ms1 - bVals.longVal(doc); return ms1 - bVals.longVal(doc);
} }
}; };
// "vd" // "vd"
if (v1!=null && d2!=null) if (v1 != null && d2 != null)
return new DualFloatFunction(v1, new LongConstValueSource(ms2)) { return new DualFloatFunction(v1, new LongConstValueSource(ms2)) {
protected String name() { protected String name() {
return "ms"; return "ms";
} }
protected float func(int doc, DocValues aVals, DocValues bVals) { protected float func(int doc, DocValues aVals, DocValues bVals) {
return aVals.longVal(doc) - ms2; return aVals.longVal(doc) - ms2;
} }
}; };
// "vv" // "vv"
if (v1!=null && v2!=null) if (v1 != null && v2 != null)
return new DualFloatFunction(v1,v2) { return new DualFloatFunction(v1, v2) {
protected String name() { protected String name() {
return "ms"; return "ms";
} }
protected float func(int doc, DocValues aVals, DocValues bVals) { protected float func(int doc, DocValues aVals, DocValues bVals) {
return aVals.longVal(doc) - bVals.longVal(doc); return aVals.longVal(doc) - bVals.longVal(doc);
} }
}; };
return null; // shouldn't happen return null; // shouldn't happen
} }
} }
@ -400,18 +522,23 @@ class LongConstValueSource extends ValueSource {
public float floatVal(int doc) { public float floatVal(int doc) {
return constant; return constant;
} }
public int intVal(int doc) { public int intVal(int doc) {
return (int)constant; return (int) constant;
} }
public long longVal(int doc) { public long longVal(int doc) {
return constant; return constant;
} }
public double doubleVal(int doc) { public double doubleVal(int doc) {
return constant; return constant;
} }
public String strVal(int doc) { public String strVal(int doc) {
return Long.toString(constant); return Long.toString(constant);
} }
public String toString(int doc) { public String toString(int doc) {
return description(); return description();
} }
@ -419,12 +546,12 @@ class LongConstValueSource extends ValueSource {
} }
public int hashCode() { public int hashCode() {
return (int)constant + (int)(constant>>>32); return (int) constant + (int) (constant >>> 32);
} }
public boolean equals(Object o) { public boolean equals(Object o) {
if (LongConstValueSource.class != o.getClass()) return false; if (LongConstValueSource.class != o.getClass()) return false;
LongConstValueSource other = (LongConstValueSource)o; LongConstValueSource other = (LongConstValueSource) o;
return this.constant == other.constant; return this.constant == other.constant;
} }
} }

View File

@ -0,0 +1,63 @@
package org.apache.solr.search.function;
import org.apache.lucene.index.IndexReader;
import java.util.Map;
import java.io.IOException;
/**
*
*
**/
public class DegreeFunction extends ValueSource{
protected ValueSource valSource;
public DegreeFunction(ValueSource valSource) {
this.valSource = valSource;
}
public String description() {
return "deg(" + valSource.description() + ')';
}
public DocValues getValues(Map context, IndexReader reader) throws IOException {
final DocValues dv = valSource.getValues(context, reader);
return new DocValues() {
public float floatVal(int doc) {
return (float) doubleVal(doc);
}
public int intVal(int doc) {
return (int) doubleVal(doc);
}
public long longVal(int doc) {
return (long) doubleVal(doc);
}
public double doubleVal(int doc) {
return Math.toDegrees(dv.doubleVal(doc));
}
public String strVal(int doc) {
return Double.toString(doubleVal(doc));
}
public String toString(int doc) {
return description() + '=' + floatVal(doc);
}
};
}
public boolean equals(Object o) {
if (o.getClass() != DegreeFunction.class) return false;
DegreeFunction other = (DegreeFunction) o;
return description().equals(other.description()) && valSource.equals(other.valSource);
}
public int hashCode() {
return description().hashCode() + valSource.hashCode();
};
}

View File

@ -0,0 +1,113 @@
package org.apache.solr.search.function;
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Searcher;
import java.util.Map;
import java.util.Arrays;
import java.io.IOException;
/**
*
*
**/ // a simple function of multiple sources
public abstract class MultiFloatFunction extends ValueSource {
protected final ValueSource[] sources;
public MultiFloatFunction(ValueSource[] sources) {
this.sources = sources;
}
abstract protected String name();
abstract protected float func(int doc, DocValues[] valsArr);
public String description() {
StringBuilder sb = new StringBuilder();
sb.append(name()).append('(');
boolean firstTime=true;
for (ValueSource source : sources) {
if (firstTime) {
firstTime=false;
} else {
sb.append(',');
}
sb.append(source);
}
sb.append(')');
return sb.toString();
}
public DocValues getValues(Map context, IndexReader reader) throws IOException {
final DocValues[] valsArr = new DocValues[sources.length];
for (int i=0; i<sources.length; i++) {
valsArr[i] = sources[i].getValues(context, reader);
}
return new DocValues() {
public float floatVal(int doc) {
return func(doc, valsArr);
}
public int intVal(int doc) {
return (int)floatVal(doc);
}
public long longVal(int doc) {
return (long)floatVal(doc);
}
public double doubleVal(int doc) {
return (double)floatVal(doc);
}
public String strVal(int doc) {
return Float.toString(floatVal(doc));
}
public String toString(int doc) {
StringBuilder sb = new StringBuilder();
sb.append(name()).append('(');
boolean firstTime=true;
for (DocValues vals : valsArr) {
if (firstTime) {
firstTime=false;
} else {
sb.append(',');
}
sb.append(vals.toString(doc));
}
sb.append(')');
return sb.toString();
}
};
}
@Override
public void createWeight(Map context, Searcher searcher) throws IOException {
for (ValueSource source : sources)
source.createWeight(context, searcher);
}
public int hashCode() {
return Arrays.hashCode(sources) + name().hashCode();
}
public boolean equals(Object o) {
if (this.getClass() != o.getClass()) return false;
MultiFloatFunction other = (MultiFloatFunction)o;
return this.name().equals(other.name())
&& Arrays.equals(this.sources, other.sources);
}
}

View File

@ -0,0 +1,79 @@
package org.apache.solr.search.function;
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.lucene.index.IndexReader;
import java.io.IOException;
import java.util.Map;
/**
* Take a ValueSourc and produce convert the number to radians and
* return that value
*/
public class RadianFunction extends ValueSource {
protected ValueSource valSource;
public RadianFunction(ValueSource valSource) {
this.valSource = valSource;
}
public String description() {
return "rad(" + valSource.description() + ')';
}
public DocValues getValues(Map context, IndexReader reader) throws IOException {
final DocValues dv = valSource.getValues(context, reader);
return new DocValues() {
public float floatVal(int doc) {
return (float) doubleVal(doc);
}
public int intVal(int doc) {
return (int) doubleVal(doc);
}
public long longVal(int doc) {
return (long) doubleVal(doc);
}
public double doubleVal(int doc) {
return Math.toRadians(dv.doubleVal(doc));
}
public String strVal(int doc) {
return Double.toString(doubleVal(doc));
}
public String toString(int doc) {
return description() + '=' + floatVal(doc);
}
};
}
public boolean equals(Object o) {
if (o.getClass() != RadianFunction.class) return false;
RadianFunction other = (RadianFunction) o;
return description().equals(other.description()) && valSource.equals(other.valSource);
}
public int hashCode() {
return description().hashCode() + valSource.hashCode();
};
}

View File

@ -44,90 +44,4 @@ public class SumFloatFunction extends MultiFloatFunction {
} }
return val; return val;
} }
} }
// a simple function of multiple sources
abstract class MultiFloatFunction extends ValueSource {
protected final ValueSource[] sources;
public MultiFloatFunction(ValueSource[] sources) {
this.sources = sources;
}
abstract protected String name();
abstract protected float func(int doc, DocValues[] valsArr);
public String description() {
StringBuilder sb = new StringBuilder();
sb.append(name()+'(');
boolean firstTime=true;
for (ValueSource source : sources) {
if (firstTime) {
firstTime=false;
} else {
sb.append(',');
}
sb.append(source);
}
sb.append(')');
return sb.toString();
}
public DocValues getValues(Map context, IndexReader reader) throws IOException {
final DocValues[] valsArr = new DocValues[sources.length];
for (int i=0; i<sources.length; i++) {
valsArr[i] = sources[i].getValues(context, reader);
}
return new DocValues() {
public float floatVal(int doc) {
return func(doc, valsArr);
}
public int intVal(int doc) {
return (int)floatVal(doc);
}
public long longVal(int doc) {
return (long)floatVal(doc);
}
public double doubleVal(int doc) {
return (double)floatVal(doc);
}
public String strVal(int doc) {
return Float.toString(floatVal(doc));
}
public String toString(int doc) {
StringBuilder sb = new StringBuilder();
sb.append(name()+'(');
boolean firstTime=true;
for (DocValues vals : valsArr) {
if (firstTime) {
firstTime=false;
} else {
sb.append(',');
}
sb.append(vals.toString(doc));
}
sb.append(')');
return sb.toString();
}
};
}
@Override
public void createWeight(Map context, Searcher searcher) throws IOException {
for (ValueSource source : sources)
source.createWeight(context, searcher);
}
public int hashCode() {
return Arrays.hashCode(sources) + name().hashCode();
}
public boolean equals(Object o) {
if (this.getClass() != o.getClass()) return false;
MultiFloatFunction other = (MultiFloatFunction)o;
return this.name().equals(other.name())
&& Arrays.equals(this.sources, other.sources);
}
}

View File

@ -0,0 +1,27 @@
package org.apache.solr.search.function.distance;
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
*
*
**/
public interface Constants {
public static final double EARTH_RADIUS_KM = 6378.160187;
public static final double EARTH_RADIUS_MI = 3963.205;
}

View File

@ -0,0 +1,159 @@
package org.apache.solr.search.function.distance;
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Searcher;
import org.apache.solr.search.function.DocValues;
import org.apache.solr.search.function.ValueSource;
import java.io.IOException;
import java.util.Map;
/**
* Calculate the Haversine formula (distance) between any two points on a sphere
* Takes in four value sources: (latA, lonA); (latB, lonB).
* <p/>
* Assumes the value sources are in radians
* <p/>
* See http://en.wikipedia.org/wiki/Great-circle_distance and
* http://en.wikipedia.org/wiki/Haversine_formula for the actual formula and
* also http://www.movable-type.co.uk/scripts/latlong.html
*
* @see org.apache.solr.search.function.RadianFunction
*/
public class HaversineFunction extends ValueSource {
private ValueSource x1;
private ValueSource y1;
private ValueSource x2;
private ValueSource y2;
private double radius;
public HaversineFunction(ValueSource x1, ValueSource y1, ValueSource x2, ValueSource y2, double radius) {
this.x1 = x1;
this.y1 = y1;
this.x2 = x2;
this.y2 = y2;
this.radius = radius;
}
protected String name() {
return "hsin";
}
/**
* @param doc The doc to score
* @param x1DV
* @param y1DV
* @param x2DV
* @param y2DV
* @return The haversine distance formula
*/
protected double distance(int doc, DocValues x1DV, DocValues y1DV, DocValues x2DV, DocValues y2DV) {
double result = 0;
double x1 = x1DV.doubleVal(doc); //in radians
double y1 = y1DV.doubleVal(doc);
double x2 = x2DV.doubleVal(doc);
double y2 = y2DV.doubleVal(doc);
//make sure they aren't all the same, as then we can just return 0
if ((x1 != x2) || (y1 != y2)) {
double diffX = x1 - x2;
double diffY = y1 - y2;
double hsinX = Math.sin(diffX / 2);
double hsinY = Math.sin(diffY / 2);
double h = hsinX * hsinX +
(Math.cos(x1) * Math.cos(x2) * hsinY * hsinY);
result = (radius * 2 * Math.atan2(Math.sqrt(h), Math.sqrt(1 - h)));
}
return result;
}
@Override
public DocValues getValues(Map context, IndexReader reader) throws IOException {
final DocValues x1DV = x1.getValues(context, reader);
final DocValues y1DV = y1.getValues(context, reader);
final DocValues x2DV = x2.getValues(context, reader);
final DocValues y2DV = y2.getValues(context, reader);
return new DocValues() {
public float floatVal(int doc) {
return (float) doubleVal(doc);
}
public int intVal(int doc) {
return (int) doubleVal(doc);
}
public long longVal(int doc) {
return (long) doubleVal(doc);
}
public double doubleVal(int doc) {
return (double) distance(doc, x1DV, y1DV, x2DV, y2DV);
}
public String strVal(int doc) {
return Double.toString(doubleVal(doc));
}
@Override
public String toString(int doc) {
StringBuilder sb = new StringBuilder();
sb.append(name()).append('(');
sb.append(x1DV.toString(doc)).append(',').append(y1DV.toString(doc)).append(',')
.append(x2DV.toString(doc)).append(',').append(y2DV.toString(doc));
sb.append(')');
return sb.toString();
}
};
}
@Override
public void createWeight(Map context, Searcher searcher) throws IOException {
x1.createWeight(context, searcher);
x2.createWeight(context, searcher);
y1.createWeight(context, searcher);
y2.createWeight(context, searcher);
}
public boolean equals(Object o) {
if (this.getClass() != o.getClass()) return false;
HaversineFunction other = (HaversineFunction) o;
return this.name().equals(other.name())
&& x1.equals(other.x1) &&
y1.equals(other.y1) &&
x2.equals(other.x2) &&
y2.equals(other.y2);
}
public int hashCode() {
return x1.hashCode() + x2.hashCode() + y1.hashCode() + y2.hashCode() + name().hashCode();
}
public String description() {
StringBuilder sb = new StringBuilder();
sb.append(name() + '(');
sb.append(x1).append(',').append(y1).append(',').append(x2).append(',').append(y2);
sb.append(')');
return sb.toString();
}
}

View File

@ -0,0 +1,73 @@
package org.apache.solr.search.function.distance;
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.solr.search.function.DocValues;
import org.apache.solr.search.function.ValueSource;
import java.util.List;
/**
* While not strictly a distance, the Sq. Euclidean Distance is often all that is needed in many applications
* that require a distance, thus saving a sq. rt. calculation
*
**/
public class SquaredEuclideanFunction extends VectorDistanceFunction {
protected String name = "sqedist";
public SquaredEuclideanFunction(List<ValueSource> sources1, List<ValueSource> sources2) {
super(-1, sources1, sources2);//overriding distance, so power doesn't matter here
}
protected String name() {
return name;
}
/**
* @param doc The doc to score
*/
protected double distance(int doc, DocValues[] docValues1, DocValues[] docValues2) {
double result = 0;
for (int i = 0; i < docValues1.length; i++) {
result += Math.pow(docValues1[i].doubleVal(doc) - docValues2[i].doubleVal(doc), 2);
}
return result;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (!(o instanceof SquaredEuclideanFunction)) return false;
if (!super.equals(o)) return false;
SquaredEuclideanFunction that = (SquaredEuclideanFunction) o;
if (!name.equals(that.name)) return false;
return true;
}
@Override
public int hashCode() {
int result = super.hashCode();
result = 31 * result + name.hashCode();
return result;
}
}

View File

@ -0,0 +1,214 @@
package org.apache.solr.search.function.distance;
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Searcher;
import org.apache.solr.common.SolrException;
import org.apache.solr.search.function.DocValues;
import org.apache.solr.search.function.ValueSource;
import java.io.IOException;
import java.util.List;
import java.util.Map;
/**
* Calculate the p-norm for a Vector. See http://en.wikipedia.org/wiki/Lp_space
* <p/>
* Common cases:
* <ul>
* <li>0 = Sparseness calculation</li>
* <li>1 = Manhattan distance</li>
* <li>2 = Euclidean distance</li>
* <li>Integer.MAX_VALUE = infinite norm</li>
* </ul>
*
* @see SquaredEuclideanFunction for the special case
*/
public class VectorDistanceFunction extends ValueSource {
protected List<ValueSource> sources1, sources2;
protected float power;
protected float oneOverPower;
public VectorDistanceFunction(float power, List<ValueSource> sources1, List<ValueSource> sources2) {
this.power = power;
this.oneOverPower = 1 / power;
this.sources1 = sources1;
this.sources2 = sources2;
if ((sources1.size() != sources2.size())) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, "Illegal number of sources");
}
}
protected String name() {
return "dist";
};
/**
* Calculate the distance
*
* @param doc The current doc
* @param docValues1 The values from the first set of value sources
* @param docValues2 The values from the second set of value sources
* @return The distance
*/
protected double distance(int doc, DocValues[] docValues1, DocValues[] docValues2) {
double result = 0;
//Handle some special cases:
if (power == 0) {
for (int i = 0; i < docValues1.length; i++) {
//sparseness measure
result += docValues1[i].doubleVal(doc) - docValues2[i].doubleVal(doc) == 0 ? 0 : 1;
}
} else if (power == 1.0) {
for (int i = 0; i < docValues1.length; i++) {
result += docValues1[i].doubleVal(doc) - docValues2[i].doubleVal(doc);
}
} else if (power == 2.0) {
for (int i = 0; i < docValues1.length; i++) {
double v = docValues1[i].doubleVal(doc) - docValues2[i].doubleVal(doc);
result += v * v;
}
result = Math.sqrt(result);
} else if (power == Integer.MAX_VALUE || Double.isInfinite(power)) {//infininte norm?
for (int i = 0; i < docValues1.length; i++) {
//TODO: is this the correct infinite norm?
result = Math.max(docValues1[i].doubleVal(doc) - docValues2[i].doubleVal(doc), result);
}
} else {
for (int i = 0; i < docValues1.length; i++) {
result += Math.pow(docValues1[i].doubleVal(doc) - docValues2[i].doubleVal(doc), power);
}
result = Math.pow(result, oneOverPower);
}
return result;
}
@Override
public DocValues getValues(Map context, IndexReader reader) throws IOException {
final DocValues[] valsArr1 = new DocValues[sources1.size()];
int i = 0;
for (ValueSource source : sources1) {
valsArr1[i++] = source.getValues(context, reader);
}
final DocValues[] valsArr2 = new DocValues[sources2.size()];
i = 0;
for (ValueSource source : sources2) {
valsArr2[i++] = source.getValues(context, reader);
}
return new DocValues() {
public float floatVal(int doc) {
return (float) doubleVal(doc);
}
public int intVal(int doc) {
return (int) doubleVal(doc);
}
public long longVal(int doc) {
return (long) doubleVal(doc);
}
public double doubleVal(int doc) {
return (double) distance(doc, valsArr1, valsArr2);
}
public String strVal(int doc) {
return Double.toString(doubleVal(doc));
}
@Override
public String toString(int doc) {
StringBuilder sb = new StringBuilder();
sb.append(name()).append('(').append(power).append(',');
boolean firstTime = true;
for (DocValues vals : valsArr1) {
if (firstTime) {
firstTime = false;
} else {
sb.append(',');
}
sb.append(vals.toString(doc));
}
for (DocValues vals : valsArr2) {
sb.append(',');//we will always have valsArr1, else there is an error
sb.append(vals.toString(doc));
}
sb.append(')');
return sb.toString();
}
};
}
@Override
public void createWeight(Map context, Searcher searcher) throws IOException {
for (ValueSource source : sources1) {
source.createWeight(context, searcher);
}
for (ValueSource source : sources2) {
source.createWeight(context, searcher);
}
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (!(o instanceof VectorDistanceFunction)) return false;
VectorDistanceFunction that = (VectorDistanceFunction) o;
if (Float.compare(that.power, power) != 0) return false;
if (!sources1.equals(that.sources1)) return false;
if (!sources2.equals(that.sources2)) return false;
return true;
}
@Override
public int hashCode() {
int result = sources1.hashCode();
result = 31 * result + sources2.hashCode();
result = 31 * result + (power != +0.0f ? Float.floatToIntBits(power) : 0);
return result;
}
public String description() {
StringBuilder sb = new StringBuilder();
sb.append(name()).append('(').append(power).append(',');
boolean firstTime = true;
for (ValueSource source : sources1) {
if (firstTime) {
firstTime = false;
} else {
sb.append(',');
}
sb.append(source);
}
for (ValueSource source : sources2) {
sb.append(',');//we will always have sources1, else there is an error
sb.append(source);
}
sb.append(')');
return sb.toString();
}
}

View File

@ -327,4 +327,20 @@ public class TestFunctionQuery extends AbstractSolrTestCase {
assertQ(req("fl","*,score","q", q, "qq","text:batman", "fq",fq), "//float[@name='score']<'1.0'"); assertQ(req("fl","*,score","q", q, "qq","text:batman", "fq",fq), "//float[@name='score']<'1.0'");
assertQ(req("fl","*,score","q", q, "qq","text:superman", "fq",fq), "//float[@name='score']>'1.0'"); assertQ(req("fl","*,score","q", q, "qq","text:superman", "fq",fq), "//float[@name='score']>'1.0'");
} }
public void testDegreeRads() throws Exception {
assertU(adoc("id", "1", "x_td", "0", "y_td", "0"));
assertU(adoc("id", "2", "x_td", "90", "y_td", String.valueOf(Math.PI / 2)));
assertU(adoc("id", "3", "x_td", "45", "y_td", String.valueOf(Math.PI / 4)));
assertU(commit());
assertQ(req("fl", "*,score", "q", "{!func}rad(x_td)", "fq", "id:1"), "//float[@name='score']='0.0'");
assertQ(req("fl", "*,score", "q", "{!func}rad(x_td)", "fq", "id:2"), "//float[@name='score']='" + (float) (Math.PI / 2) + "'");
assertQ(req("fl", "*,score", "q", "{!func}rad(x_td)", "fq", "id:3"), "//float[@name='score']='" + (float) (Math.PI / 4) + "'");
assertQ(req("fl", "*,score", "q", "{!func}deg(y_td)", "fq", "id:1"), "//float[@name='score']='0.0'");
assertQ(req("fl", "*,score", "q", "{!func}deg(y_td)", "fq", "id:2"), "//float[@name='score']='90.0'");
assertQ(req("fl", "*,score", "q", "{!func}deg(y_td)", "fq", "id:3"), "//float[@name='score']='45.0'");
}
} }

View File

@ -0,0 +1,107 @@
package org.apache.solr.search.function.distance;
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.solr.common.SolrException;
import org.apache.solr.util.AbstractSolrTestCase;
/**
*
*
**/
public class DistanceFunctionTest extends AbstractSolrTestCase {
public String getSchemaFile() {
return "schema11.xml";
}
public String getSolrConfigFile() {
return "solrconfig-functionquery.xml";
}
public String getCoreName() {
return "basic";
}
public void testHaversine() throws Exception {
assertU(adoc("id", "1", "x_td", "0", "y_td", "0"));
assertU(adoc("id", "2", "x_td", "0", "y_td", String.valueOf(Math.PI / 2)));
assertU(adoc("id", "3", "x_td", String.valueOf(Math.PI / 2), "y_td", String.valueOf(Math.PI / 2)));
assertU(adoc("id", "4", "x_td", String.valueOf(Math.PI / 4), "y_td", String.valueOf(Math.PI / 4)));
assertU(commit());
//Get the haversine distance between the point 0,0 and the docs above assuming a radius of 1
assertQ(req("fl", "*,score", "q", "{!func}hsin(x_td, y_td, 0, 0, 1)", "fq", "id:1"), "//float[@name='score']='0.0'");
assertQ(req("fl", "*,score", "q", "{!func}hsin(x_td, y_td, 0, 0, 1)", "fq", "id:2"), "//float[@name='score']='" + (float) (Math.PI / 2) + "'");
assertQ(req("fl", "*,score", "q", "{!func}hsin(x_td, y_td, 0, 0, 1)", "fq", "id:3"), "//float[@name='score']='" + (float) (Math.PI / 2) + "'");
assertQ(req("fl", "*,score", "q", "{!func}hsin(x_td, y_td, 0, 0, 1)", "fq", "id:4"), "//float[@name='score']='1.0471976'");
}
public void testVector() throws Exception {
assertU(adoc("id", "1", "x_td", "0", "y_td", "0", "z_td", "0", "w_td", "0"));
assertU(adoc("id", "2", "x_td", "0", "y_td", "1", "z_td", "0", "w_td", "0"));
assertU(adoc("id", "3", "x_td", "1", "y_td", "1", "z_td", "1", "w_td", "1"));
assertU(adoc("id", "4", "x_td", "1", "y_td", "0", "z_td", "0", "w_td", "0"));
assertU(adoc("id", "5", "x_td", "2.3", "y_td", "5.5", "z_td", "7.9", "w_td", "-2.4"));
assertU(commit());
//two dimensions, notice how we only pass in 4 value sources
assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, 0, 0)", "fq", "id:1"), "//float[@name='score']='0.0'");
assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, 0, 0)", "fq", "id:2"), "//float[@name='score']='1.0'");
assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, 0, 0)", "fq", "id:3"), "//float[@name='score']='" + 2.0f + "'");
assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, 0, 0)", "fq", "id:4"), "//float[@name='score']='1.0'");
assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, 0, 0)", "fq", "id:5"), "//float[@name='score']='" + (float) (2.3 * 2.3 + 5.5 * 5.5) + "'");
//three dimensions, notice how we pass in 6 value sources
assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, z_td, 0, 0, 0)", "fq", "id:1"), "//float[@name='score']='0.0'");
assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, z_td, 0, 0, 0)", "fq", "id:2"), "//float[@name='score']='1.0'");
assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, z_td, 0, 0, 0)", "fq", "id:3"), "//float[@name='score']='" + 3.0f + "'");
assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, z_td, 0, 0, 0)", "fq", "id:4"), "//float[@name='score']='1.0'");
assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, z_td, 0, 0, 0)", "fq", "id:5"), "//float[@name='score']='" + (float) (2.3 * 2.3 + 5.5 * 5.5 + 7.9 * 7.9) + "'");
//four dimensions, notice how we pass in 8 value sources
assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, z_td, w_td, 0, 0, 0, 0)", "fq", "id:1"), "//float[@name='score']='0.0'");
assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, z_td, w_td, 0, 0, 0, 0)", "fq", "id:2"), "//float[@name='score']='1.0'");
assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, z_td, w_td, 0, 0, 0, 0)", "fq", "id:3"), "//float[@name='score']='" + 4.0f + "'");
assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, z_td, w_td, 0, 0, 0, 0)", "fq", "id:4"), "//float[@name='score']='1.0'");
assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, z_td, w_td, 0, 0, 0, 0)", "fq", "id:5"), "//float[@name='score']='" + (float) (2.3 * 2.3 + 5.5 * 5.5 + 7.9 * 7.9 + 2.4 * 2.4) + "'");
//Pass in imbalanced list, throw exception
try {
assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, z_td, w_td, 0, 0, 0)", "fq", "id:1"), "//float[@name='score']='0.0'");
assertTrue("should throw an exception", false);
} catch (Exception e) {
Throwable cause = e.getCause();
assertNotNull(cause);
assertTrue(cause instanceof SolrException);
}
//do one test of Euclidean
//two dimensions, notice how we only pass in 4 value sources
assertQ(req("fl", "*,score", "q", "{!func}dist(2, x_td, y_td, 0, 0)", "fq", "id:1"), "//float[@name='score']='0.0'");
assertQ(req("fl", "*,score", "q", "{!func}dist(2, x_td, y_td, 0, 0)", "fq", "id:2"), "//float[@name='score']='1.0'");
assertQ(req("fl", "*,score", "q", "{!func}dist(2, x_td, y_td, 0, 0)", "fq", "id:3"), "//float[@name='score']='" + (float) Math.sqrt(2.0) + "'");
assertQ(req("fl", "*,score", "q", "{!func}dist(2, x_td, y_td, 0, 0)", "fq", "id:4"), "//float[@name='score']='1.0'");
assertQ(req("fl", "*,score", "q", "{!func}dist(2, x_td, y_td, 0, 0)", "fq", "id:5"), "//float[@name='score']='" + (float) Math.sqrt((2.3 * 2.3 + 5.5 * 5.5)) + "'");
//do one test of Manhattan
//two dimensions, notice how we only pass in 4 value sources
assertQ(req("fl", "*,score", "q", "{!func}dist(1, x_td, y_td, 0, 0)", "fq", "id:1"), "//float[@name='score']='0.0'");
assertQ(req("fl", "*,score", "q", "{!func}dist(1, x_td, y_td, 0, 0)", "fq", "id:2"), "//float[@name='score']='1.0'");
assertQ(req("fl", "*,score", "q", "{!func}dist(1, x_td, y_td, 0, 0)", "fq", "id:3"), "//float[@name='score']='" + (float) 2.0 + "'");
assertQ(req("fl", "*,score", "q", "{!func}dist(1, x_td, y_td, 0, 0)", "fq", "id:4"), "//float[@name='score']='1.0'");
assertQ(req("fl", "*,score", "q", "{!func}dist(1, x_td, y_td, 0, 0)", "fq", "id:5"), "//float[@name='score']='" + (float) (2.3 + 5.5) + "'");
}
}