diff --git a/CHANGES.txt b/CHANGES.txt index df8ff15feca..c233324ed2a 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -34,6 +34,9 @@ Detailed Change List New Features ---------------------- +1. SOLR-1302: Added several new distance based functions, including Great Circle (haversine), Manhattan and Euclidean. + Also added deg() and rad() convenience functions. (gsingers) + Optimizations ---------------------- diff --git a/src/java/org/apache/solr/search/FunctionQParser.java b/src/java/org/apache/solr/search/FunctionQParser.java index 40af4141919..e7cf228691d 100755 --- a/src/java/org/apache/solr/search/FunctionQParser.java +++ b/src/java/org/apache/solr/search/FunctionQParser.java @@ -87,6 +87,28 @@ public class FunctionQParser extends QParser { return value; } + /** + * Parse a Double + * @return double + * @throws ParseException + */ + public double parseDouble() throws ParseException { + double value = sp.getDouble(); + consumeArgumentDelimiter(); + return value; + } + + /** + * Parse an integer + * @return An int + * @throws ParseException + */ + public int parseInt() throws ParseException { + int value = sp.getInt(); + consumeArgumentDelimiter(); + return value; + } + public String parseArg() throws ParseException { sp.eatws(); char ch = sp.peek(); diff --git a/src/java/org/apache/solr/search/QueryParsing.java b/src/java/org/apache/solr/search/QueryParsing.java index 35d6e6ed28b..77a1dd749d4 100644 --- a/src/java/org/apache/solr/search/QueryParsing.java +++ b/src/java/org/apache/solr/search/QueryParsing.java @@ -557,6 +557,46 @@ public class QueryParsing { return Float.parseFloat(new String(arr,0,i)); } + double getDouble() throws ParseException { + eatws(); + char[] arr = new char[end-pos]; + int i; + for (i=0; i='0' && ch<='9') + || ch=='+' || ch=='-' + || ch=='.' || ch=='e' || ch=='E' + ) { + pos++; + arr[i]=ch; + } else { + break; + } + } + + return Double.parseDouble(new String(arr,0,i)); + } + + int getInt() throws ParseException { + eatws(); + char[] arr = new char[end-pos]; + int i; + for (i=0; i='0' && ch<='9') + || ch=='+' || ch=='-' + ) { + pos++; + arr[i]=ch; + } else { + break; + } + } + + return Integer.parseInt(new String(arr,0,i)); + } + + String getId() throws ParseException { eatws(); int id_start=pos; diff --git a/src/java/org/apache/solr/search/ValueSourceParser.java b/src/java/org/apache/solr/search/ValueSourceParser.java index 9a9b58f0b38..fb306ffe88a 100755 --- a/src/java/org/apache/solr/search/ValueSourceParser.java +++ b/src/java/org/apache/solr/search/ValueSourceParser.java @@ -16,36 +16,63 @@ */ package org.apache.solr.search; -import java.util.*; -import java.io.IOException; - +import org.apache.lucene.index.IndexReader; import org.apache.lucene.queryParser.ParseException; import org.apache.lucene.search.Query; -import org.apache.lucene.index.IndexReader; -import org.apache.solr.common.util.NamedList; import org.apache.solr.common.SolrException; -import org.apache.solr.search.function.*; -import org.apache.solr.util.plugin.NamedListInitializedPlugin; -import org.apache.solr.schema.TrieDateField; +import org.apache.solr.common.util.NamedList; import org.apache.solr.schema.DateField; -import org.apache.solr.schema.SchemaField; import org.apache.solr.schema.LegacyDateField; +import org.apache.solr.schema.SchemaField; +import org.apache.solr.schema.TrieDateField; +import org.apache.solr.search.function.BoostedQuery; +import org.apache.solr.search.function.DegreeFunction; +import org.apache.solr.search.function.DivFloatFunction; +import org.apache.solr.search.function.DocValues; +import org.apache.solr.search.function.DualFloatFunction; +import org.apache.solr.search.function.LinearFloatFunction; +import org.apache.solr.search.function.MaxFloatFunction; +import org.apache.solr.search.function.OrdFieldSource; +import org.apache.solr.search.function.PowFloatFunction; +import org.apache.solr.search.function.ProductFloatFunction; +import org.apache.solr.search.function.QueryValueSource; +import org.apache.solr.search.function.RadianFunction; +import org.apache.solr.search.function.RangeMapFloatFunction; +import org.apache.solr.search.function.ReciprocalFloatFunction; +import org.apache.solr.search.function.ReverseOrdFieldSource; +import org.apache.solr.search.function.ScaleFloatFunction; +import org.apache.solr.search.function.SimpleFloatFunction; +import org.apache.solr.search.function.SumFloatFunction; +import org.apache.solr.search.function.TopValueSource; +import org.apache.solr.search.function.ValueSource; + +import org.apache.solr.search.function.distance.HaversineFunction; + +import org.apache.solr.search.function.distance.SquaredEuclideanFunction; +import org.apache.solr.search.function.distance.VectorDistanceFunction; +import org.apache.solr.util.plugin.NamedListInitializedPlugin; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Date; +import java.util.HashMap; +import java.util.List; +import java.util.Map; /** * A factory that parses user queries to generate ValueSource instances. * Intented usage is to create pluggable, named functions for use in function queries. */ -public abstract class ValueSourceParser implements NamedListInitializedPlugin -{ - +public abstract class ValueSourceParser implements NamedListInitializedPlugin { + /** * Initialize the plugin. */ - public abstract void init( NamedList args ); - + public abstract void init(NamedList args); + /** * Parse the user input into a ValueSource. - * + * * @param fp * @throws ParseException */ @@ -53,6 +80,7 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin /* standard functions */ public static Map standardValueSourceParsers = new HashMap(); + static { standardValueSourceParsers.put("ord", new ValueSourceParser() { public ValueSource parse(FunctionQParser fp) throws ParseException { @@ -62,7 +90,7 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin public void init(NamedList args) { } - + }); standardValueSourceParsers.put("rord", new ValueSourceParser() { public ValueSource parse(FunctionQParser fp) throws ParseException { @@ -72,7 +100,7 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin public void init(NamedList args) { } - + }); standardValueSourceParsers.put("top", new ValueSourceParser() { public ValueSource parse(FunctionQParser fp) throws ParseException { @@ -81,6 +109,7 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin if (source instanceof TopValueSource) return source; return new TopValueSource(source); } + public void init(NamedList args) { } }); @@ -89,23 +118,23 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin ValueSource source = fp.parseValueSource(); float slope = fp.parseFloat(); float intercept = fp.parseFloat(); - return new LinearFloatFunction(source,slope,intercept); + return new LinearFloatFunction(source, slope, intercept); } public void init(NamedList args) { } - + }); standardValueSourceParsers.put("max", new ValueSourceParser() { public ValueSource parse(FunctionQParser fp) throws ParseException { ValueSource source = fp.parseValueSource(); float val = fp.parseFloat(); - return new MaxFloatFunction(source,val); + return new MaxFloatFunction(source, val); } public void init(NamedList args) { } - + }); standardValueSourceParsers.put("recip", new ValueSourceParser() { public ValueSource parse(FunctionQParser fp) throws ParseException { @@ -113,46 +142,46 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin float m = fp.parseFloat(); float a = fp.parseFloat(); float b = fp.parseFloat(); - return new ReciprocalFloatFunction(source,m,a,b); + return new ReciprocalFloatFunction(source, m, a, b); } public void init(NamedList args) { } - + }); standardValueSourceParsers.put("scale", new ValueSourceParser() { public ValueSource parse(FunctionQParser fp) throws ParseException { ValueSource source = fp.parseValueSource(); float min = fp.parseFloat(); float max = fp.parseFloat(); - return new TopValueSource(new ScaleFloatFunction(source,min,max)); + return new TopValueSource(new ScaleFloatFunction(source, min, max)); } public void init(NamedList args) { } - + }); standardValueSourceParsers.put("pow", new ValueSourceParser() { public ValueSource parse(FunctionQParser fp) throws ParseException { ValueSource a = fp.parseValueSource(); ValueSource b = fp.parseValueSource(); - return new PowFloatFunction(a,b); + return new PowFloatFunction(a, b); } public void init(NamedList args) { } - + }); standardValueSourceParsers.put("div", new ValueSourceParser() { public ValueSource parse(FunctionQParser fp) throws ParseException { ValueSource a = fp.parseValueSource(); ValueSource b = fp.parseValueSource(); - return new DivFloatFunction(a,b); + return new DivFloatFunction(a, b); } public void init(NamedList args) { } - + }); standardValueSourceParsers.put("map", new ValueSourceParser() { public ValueSource parse(FunctionQParser fp) throws ParseException { @@ -161,12 +190,12 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin float max = fp.parseFloat(); float target = fp.parseFloat(); Float def = fp.hasMoreArguments() ? fp.parseFloat() : null; - return new RangeMapFloatFunction(source,min,max,target,def); + return new RangeMapFloatFunction(source, min, max, target, def); } public void init(NamedList args) { } - + }); standardValueSourceParsers.put("sqrt", new ValueSourceParser() { public ValueSource parse(FunctionQParser fp) throws ParseException { @@ -175,11 +204,13 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin protected String name() { return "sqrt"; } + protected float func(int doc, DocValues vals) { - return (float)Math.sqrt(vals.floatVal(doc)); + return (float) Math.sqrt(vals.floatVal(doc)); } }; } + public void init(NamedList args) { } }); @@ -190,15 +221,16 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin protected String name() { return "log"; } + protected float func(int doc, DocValues vals) { - return (float)Math.log10(vals.floatVal(doc)); + return (float) Math.log10(vals.floatVal(doc)); } }; } public void init(NamedList args) { } - + }); standardValueSourceParsers.put("abs", new ValueSourceParser() { public ValueSource parse(FunctionQParser fp) throws ParseException { @@ -207,15 +239,16 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin protected String name() { return "abs"; } + protected float func(int doc, DocValues vals) { - return (float)Math.abs(vals.floatVal(doc)); + return (float) Math.abs(vals.floatVal(doc)); } }; } public void init(NamedList args) { } - + }); standardValueSourceParsers.put("sum", new ValueSourceParser() { public ValueSource parse(FunctionQParser fp) throws ParseException { @@ -225,7 +258,7 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin public void init(NamedList args) { } - + }); standardValueSourceParsers.put("product", new ValueSourceParser() { public ValueSource parse(FunctionQParser fp) throws ParseException { @@ -235,16 +268,17 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin public void init(NamedList args) { } - + }); standardValueSourceParsers.put("sub", new ValueSourceParser() { public ValueSource parse(FunctionQParser fp) throws ParseException { ValueSource a = fp.parseValueSource(); ValueSource b = fp.parseValueSource(); - return new DualFloatFunction(a,b) { + return new DualFloatFunction(a, b) { protected String name() { return "sub"; } + protected float func(int doc, DocValues aVals, DocValues bVals) { return aVals.floatVal(doc) - bVals.floatVal(doc); } @@ -268,7 +302,7 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin public void init(NamedList args) { } - + }); standardValueSourceParsers.put("boost", new ValueSourceParser() { public ValueSource parse(FunctionQParser fp) throws ParseException { @@ -280,30 +314,115 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin public void init(NamedList args) { } - + + }); + standardValueSourceParsers.put("hsin", new ValueSourceParser() { + public ValueSource parse(FunctionQParser fp) throws ParseException { + + ValueSource x1 = fp.parseValueSource(); + ValueSource y1 = fp.parseValueSource(); + ValueSource x2 = fp.parseValueSource(); + ValueSource y2 = fp.parseValueSource(); + double radius = fp.parseDouble(); + + return new HaversineFunction(x1, y1, x2, y2, radius); + } + + public void init(NamedList args) { + } + + }); + + standardValueSourceParsers.put("rad", new ValueSourceParser() { + public ValueSource parse(FunctionQParser fp) throws ParseException { + return new RadianFunction(fp.parseValueSource()); + } + + public void init(NamedList args) { + } + + }); + + standardValueSourceParsers.put("deg", new ValueSourceParser() { + public ValueSource parse(FunctionQParser fp) throws ParseException { + return new DegreeFunction(fp.parseValueSource()); + } + + public void init(NamedList args) { + } + + }); + + standardValueSourceParsers.put("sqedist", new ValueSourceParser() { + public ValueSource parse(FunctionQParser fp) throws ParseException { + List sources = fp.parseValueSourceList(); + if (sources.size() % 2 != 0) { + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, "Illegal number of sources. There must be an even number of sources"); + } + int dim = sources.size() / 2; + List sources1 = new ArrayList(dim); + List sources2 = new ArrayList(dim); + //Get dim value sources for the first vector + splitSources(dim, sources, sources1, sources2); + return new SquaredEuclideanFunction(sources1, sources2); + } + + public void init(NamedList args) { + } + + }); + + standardValueSourceParsers.put("dist", new ValueSourceParser() { + public ValueSource parse(FunctionQParser fp) throws ParseException { + float power = fp.parseFloat(); + List sources = fp.parseValueSourceList(); + if (sources.size() % 2 != 0) { + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, "Illegal number of sources. There must be an even number of sources"); + } + int dim = sources.size() / 2; + List sources1 = new ArrayList(dim); + List sources2 = new ArrayList(dim); + splitSources(dim, sources, sources1, sources2); + return new VectorDistanceFunction(power, sources1, sources2); + } + + public void init(NamedList args) { + } + }); standardValueSourceParsers.put("ms", new DateValueSourceParser()); } + protected void splitSources(int dim, List sources, List dest1, List dest2) { + //Get dim value sources for the first vector + for (int i = 0; i < dim; i++) { + dest1.add(sources.get(i)); + } + //Get dim value sources for the second vector + for (int i = dim; i < sources.size(); i++) { + dest2.add(sources.get(i)); + } + } + } - - class DateValueSourceParser extends ValueSourceParser { DateField df = new TrieDateField(); - public void init(NamedList args) {} + + public void init(NamedList args) { + } public Date getDate(FunctionQParser fp, String arg) { - if (arg==null) return null; - if (arg.startsWith("NOW") || (arg.length()>0 && Character.isDigit(arg.charAt(0)))) { + if (arg == null) return null; + if (arg.startsWith("NOW") || (arg.length() > 0 && Character.isDigit(arg.charAt(0)))) { return df.parseMathLenient(null, arg, fp.req); } return null; } public ValueSource getValueSource(FunctionQParser fp, String arg) { - if (arg==null) return null; + if (arg == null) return null; SchemaField f = fp.req.getSchema().getField(arg); if (f.getType().getClass() == DateField.class || f.getType().getClass() == LegacyDateField.class) { throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, "Can't use ms() function on non-numeric legacy date field " + arg); @@ -314,13 +433,13 @@ class DateValueSourceParser extends ValueSourceParser { public ValueSource parse(FunctionQParser fp) throws ParseException { String first = fp.parseArg(); String second = fp.parseArg(); - if (first==null) first="NOW"; + if (first == null) first = "NOW"; - Date d1=getDate(fp,first); - ValueSource v1 = d1==null ? getValueSource(fp, first) : null; + Date d1 = getDate(fp, first); + ValueSource v1 = d1 == null ? getValueSource(fp, first) : null; - Date d2=getDate(fp,second); - ValueSource v2 = d2==null ? getValueSource(fp, second) : null; + Date d2 = getDate(fp, second); + ValueSource v2 = d2 == null ? getValueSource(fp, second) : null; // d constant // v field @@ -330,54 +449,57 @@ class DateValueSourceParser extends ValueSourceParser { // vv subtract fields final long ms1 = (d1 == null) ? 0 : d1.getTime(); - final long ms2 = (d2 == null) ? 0 : d2.getTime(); + final long ms2 = (d2 == null) ? 0 : d2.getTime(); // "d,dd" handle both constant cases - if (d1 != null && v2==null) { - return new LongConstValueSource(ms1-ms2); + if (d1 != null && v2 == null) { + return new LongConstValueSource(ms1 - ms2); } // "v" just the date field - if (v1 != null && v2==null && d2==null) { + if (v1 != null && v2 == null && d2 == null) { return v1; } // "dv" - if (d1!=null && v2!=null) + if (d1 != null && v2 != null) return new DualFloatFunction(new LongConstValueSource(ms1), v2) { protected String name() { return "ms"; } + protected float func(int doc, DocValues aVals, DocValues bVals) { return ms1 - bVals.longVal(doc); } }; // "vd" - if (v1!=null && d2!=null) + if (v1 != null && d2 != null) return new DualFloatFunction(v1, new LongConstValueSource(ms2)) { protected String name() { return "ms"; } + protected float func(int doc, DocValues aVals, DocValues bVals) { return aVals.longVal(doc) - ms2; } }; // "vv" - if (v1!=null && v2!=null) - return new DualFloatFunction(v1,v2) { + if (v1 != null && v2 != null) + return new DualFloatFunction(v1, v2) { protected String name() { return "ms"; } + protected float func(int doc, DocValues aVals, DocValues bVals) { return aVals.longVal(doc) - bVals.longVal(doc); } }; - return null; // shouldn't happen + return null; // shouldn't happen } } @@ -400,18 +522,23 @@ class LongConstValueSource extends ValueSource { public float floatVal(int doc) { return constant; } + public int intVal(int doc) { - return (int)constant; + return (int) constant; } + public long longVal(int doc) { return constant; } + public double doubleVal(int doc) { return constant; } + public String strVal(int doc) { return Long.toString(constant); } + public String toString(int doc) { return description(); } @@ -419,12 +546,12 @@ class LongConstValueSource extends ValueSource { } public int hashCode() { - return (int)constant + (int)(constant>>>32); + return (int) constant + (int) (constant >>> 32); } public boolean equals(Object o) { if (LongConstValueSource.class != o.getClass()) return false; - LongConstValueSource other = (LongConstValueSource)o; + LongConstValueSource other = (LongConstValueSource) o; return this.constant == other.constant; } } \ No newline at end of file diff --git a/src/java/org/apache/solr/search/function/DegreeFunction.java b/src/java/org/apache/solr/search/function/DegreeFunction.java new file mode 100644 index 00000000000..02c6e00120a --- /dev/null +++ b/src/java/org/apache/solr/search/function/DegreeFunction.java @@ -0,0 +1,63 @@ +package org.apache.solr.search.function; + +import org.apache.lucene.index.IndexReader; + +import java.util.Map; +import java.io.IOException; + + +/** + * + * + **/ +public class DegreeFunction extends ValueSource{ + protected ValueSource valSource; + + public DegreeFunction(ValueSource valSource) { + this.valSource = valSource; + } + + public String description() { + return "deg(" + valSource.description() + ')'; + } + + public DocValues getValues(Map context, IndexReader reader) throws IOException { + final DocValues dv = valSource.getValues(context, reader); + return new DocValues() { + public float floatVal(int doc) { + return (float) doubleVal(doc); + } + + public int intVal(int doc) { + return (int) doubleVal(doc); + } + + public long longVal(int doc) { + return (long) doubleVal(doc); + } + + public double doubleVal(int doc) { + return Math.toDegrees(dv.doubleVal(doc)); + } + + public String strVal(int doc) { + return Double.toString(doubleVal(doc)); + } + + public String toString(int doc) { + return description() + '=' + floatVal(doc); + } + }; + } + + public boolean equals(Object o) { + if (o.getClass() != DegreeFunction.class) return false; + DegreeFunction other = (DegreeFunction) o; + return description().equals(other.description()) && valSource.equals(other.valSource); + } + + public int hashCode() { + return description().hashCode() + valSource.hashCode(); + }; + +} diff --git a/src/java/org/apache/solr/search/function/MultiFloatFunction.java b/src/java/org/apache/solr/search/function/MultiFloatFunction.java new file mode 100644 index 00000000000..29dbbe79d4f --- /dev/null +++ b/src/java/org/apache/solr/search/function/MultiFloatFunction.java @@ -0,0 +1,113 @@ +package org.apache.solr.search.function; +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.Searcher; + +import java.util.Map; +import java.util.Arrays; +import java.io.IOException; + + +/** + * + * + **/ // a simple function of multiple sources +public abstract class MultiFloatFunction extends ValueSource { + protected final ValueSource[] sources; + + public MultiFloatFunction(ValueSource[] sources) { + this.sources = sources; + } + + abstract protected String name(); + abstract protected float func(int doc, DocValues[] valsArr); + + public String description() { + StringBuilder sb = new StringBuilder(); + sb.append(name()).append('('); + boolean firstTime=true; + for (ValueSource source : sources) { + if (firstTime) { + firstTime=false; + } else { + sb.append(','); + } + sb.append(source); + } + sb.append(')'); + return sb.toString(); + } + + public DocValues getValues(Map context, IndexReader reader) throws IOException { + final DocValues[] valsArr = new DocValues[sources.length]; + for (int i=0; i + * Assumes the value sources are in radians + *

+ * See http://en.wikipedia.org/wiki/Great-circle_distance and + * http://en.wikipedia.org/wiki/Haversine_formula for the actual formula and + * also http://www.movable-type.co.uk/scripts/latlong.html + * + * @see org.apache.solr.search.function.RadianFunction + */ +public class HaversineFunction extends ValueSource { + + private ValueSource x1; + private ValueSource y1; + private ValueSource x2; + private ValueSource y2; + private double radius; + + public HaversineFunction(ValueSource x1, ValueSource y1, ValueSource x2, ValueSource y2, double radius) { + this.x1 = x1; + this.y1 = y1; + this.x2 = x2; + this.y2 = y2; + this.radius = radius; + } + + protected String name() { + return "hsin"; + } + + /** + * @param doc The doc to score + * @param x1DV + * @param y1DV + * @param x2DV + * @param y2DV + * @return The haversine distance formula + */ + protected double distance(int doc, DocValues x1DV, DocValues y1DV, DocValues x2DV, DocValues y2DV) { + double result = 0; + double x1 = x1DV.doubleVal(doc); //in radians + double y1 = y1DV.doubleVal(doc); + double x2 = x2DV.doubleVal(doc); + double y2 = y2DV.doubleVal(doc); + + //make sure they aren't all the same, as then we can just return 0 + if ((x1 != x2) || (y1 != y2)) { + double diffX = x1 - x2; + double diffY = y1 - y2; + double hsinX = Math.sin(diffX / 2); + double hsinY = Math.sin(diffY / 2); + double h = hsinX * hsinX + + (Math.cos(x1) * Math.cos(x2) * hsinY * hsinY); + result = (radius * 2 * Math.atan2(Math.sqrt(h), Math.sqrt(1 - h))); + } + + return result; + } + + + @Override + public DocValues getValues(Map context, IndexReader reader) throws IOException { + final DocValues x1DV = x1.getValues(context, reader); + final DocValues y1DV = y1.getValues(context, reader); + final DocValues x2DV = x2.getValues(context, reader); + final DocValues y2DV = y2.getValues(context, reader); + return new DocValues() { + public float floatVal(int doc) { + return (float) doubleVal(doc); + } + + public int intVal(int doc) { + return (int) doubleVal(doc); + } + + public long longVal(int doc) { + return (long) doubleVal(doc); + } + + public double doubleVal(int doc) { + return (double) distance(doc, x1DV, y1DV, x2DV, y2DV); + } + + public String strVal(int doc) { + return Double.toString(doubleVal(doc)); + } + + @Override + public String toString(int doc) { + StringBuilder sb = new StringBuilder(); + sb.append(name()).append('('); + sb.append(x1DV.toString(doc)).append(',').append(y1DV.toString(doc)).append(',') + .append(x2DV.toString(doc)).append(',').append(y2DV.toString(doc)); + sb.append(')'); + return sb.toString(); + } + }; + } + + @Override + public void createWeight(Map context, Searcher searcher) throws IOException { + x1.createWeight(context, searcher); + x2.createWeight(context, searcher); + y1.createWeight(context, searcher); + y2.createWeight(context, searcher); + } + + public boolean equals(Object o) { + if (this.getClass() != o.getClass()) return false; + HaversineFunction other = (HaversineFunction) o; + return this.name().equals(other.name()) + && x1.equals(other.x1) && + y1.equals(other.y1) && + x2.equals(other.x2) && + y2.equals(other.y2); + } + + public int hashCode() { + + return x1.hashCode() + x2.hashCode() + y1.hashCode() + y2.hashCode() + name().hashCode(); + } + + public String description() { + StringBuilder sb = new StringBuilder(); + sb.append(name() + '('); + sb.append(x1).append(',').append(y1).append(',').append(x2).append(',').append(y2); + sb.append(')'); + return sb.toString(); + } +} diff --git a/src/java/org/apache/solr/search/function/distance/SquaredEuclideanFunction.java b/src/java/org/apache/solr/search/function/distance/SquaredEuclideanFunction.java new file mode 100644 index 00000000000..17932a8114c --- /dev/null +++ b/src/java/org/apache/solr/search/function/distance/SquaredEuclideanFunction.java @@ -0,0 +1,73 @@ +package org.apache.solr.search.function.distance; +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.solr.search.function.DocValues; +import org.apache.solr.search.function.ValueSource; + +import java.util.List; + + +/** + * While not strictly a distance, the Sq. Euclidean Distance is often all that is needed in many applications + * that require a distance, thus saving a sq. rt. calculation + * + **/ +public class SquaredEuclideanFunction extends VectorDistanceFunction { + protected String name = "sqedist"; + + public SquaredEuclideanFunction(List sources1, List sources2) { + super(-1, sources1, sources2);//overriding distance, so power doesn't matter here + } + + + protected String name() { + + return name; + } + + /** + * @param doc The doc to score + */ + protected double distance(int doc, DocValues[] docValues1, DocValues[] docValues2) { + double result = 0; + for (int i = 0; i < docValues1.length; i++) { + result += Math.pow(docValues1[i].doubleVal(doc) - docValues2[i].doubleVal(doc), 2); + } + return result; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof SquaredEuclideanFunction)) return false; + if (!super.equals(o)) return false; + + SquaredEuclideanFunction that = (SquaredEuclideanFunction) o; + + if (!name.equals(that.name)) return false; + + return true; + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = 31 * result + name.hashCode(); + return result; + } +} diff --git a/src/java/org/apache/solr/search/function/distance/VectorDistanceFunction.java b/src/java/org/apache/solr/search/function/distance/VectorDistanceFunction.java new file mode 100644 index 00000000000..2a0cc5cc62c --- /dev/null +++ b/src/java/org/apache/solr/search/function/distance/VectorDistanceFunction.java @@ -0,0 +1,214 @@ +package org.apache.solr.search.function.distance; +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.Searcher; +import org.apache.solr.common.SolrException; +import org.apache.solr.search.function.DocValues; +import org.apache.solr.search.function.ValueSource; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + + +/** + * Calculate the p-norm for a Vector. See http://en.wikipedia.org/wiki/Lp_space + *

+ * Common cases: + *

    + *
  • 0 = Sparseness calculation
  • + *
  • 1 = Manhattan distance
  • + *
  • 2 = Euclidean distance
  • + *
  • Integer.MAX_VALUE = infinite norm
  • + *
+ * + * @see SquaredEuclideanFunction for the special case + */ +public class VectorDistanceFunction extends ValueSource { + protected List sources1, sources2; + protected float power; + protected float oneOverPower; + + public VectorDistanceFunction(float power, List sources1, List sources2) { + this.power = power; + this.oneOverPower = 1 / power; + this.sources1 = sources1; + this.sources2 = sources2; + if ((sources1.size() != sources2.size())) { + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, "Illegal number of sources"); + } + } + + protected String name() { + return "dist"; + }; + + /** + * Calculate the distance + * + * @param doc The current doc + * @param docValues1 The values from the first set of value sources + * @param docValues2 The values from the second set of value sources + * @return The distance + */ + protected double distance(int doc, DocValues[] docValues1, DocValues[] docValues2) { + double result = 0; + //Handle some special cases: + if (power == 0) { + for (int i = 0; i < docValues1.length; i++) { + //sparseness measure + result += docValues1[i].doubleVal(doc) - docValues2[i].doubleVal(doc) == 0 ? 0 : 1; + } + } else if (power == 1.0) { + for (int i = 0; i < docValues1.length; i++) { + result += docValues1[i].doubleVal(doc) - docValues2[i].doubleVal(doc); + } + } else if (power == 2.0) { + for (int i = 0; i < docValues1.length; i++) { + double v = docValues1[i].doubleVal(doc) - docValues2[i].doubleVal(doc); + result += v * v; + } + result = Math.sqrt(result); + } else if (power == Integer.MAX_VALUE || Double.isInfinite(power)) {//infininte norm? + for (int i = 0; i < docValues1.length; i++) { + //TODO: is this the correct infinite norm? + result = Math.max(docValues1[i].doubleVal(doc) - docValues2[i].doubleVal(doc), result); + } + + } else { + for (int i = 0; i < docValues1.length; i++) { + result += Math.pow(docValues1[i].doubleVal(doc) - docValues2[i].doubleVal(doc), power); + } + result = Math.pow(result, oneOverPower); + } + + return result; + } + + @Override + public DocValues getValues(Map context, IndexReader reader) throws IOException { + final DocValues[] valsArr1 = new DocValues[sources1.size()]; + int i = 0; + for (ValueSource source : sources1) { + valsArr1[i++] = source.getValues(context, reader); + } + final DocValues[] valsArr2 = new DocValues[sources2.size()]; + i = 0; + for (ValueSource source : sources2) { + valsArr2[i++] = source.getValues(context, reader); + } + + + return new DocValues() { + public float floatVal(int doc) { + return (float) doubleVal(doc); + } + + public int intVal(int doc) { + return (int) doubleVal(doc); + } + + public long longVal(int doc) { + return (long) doubleVal(doc); + } + + public double doubleVal(int doc) { + return (double) distance(doc, valsArr1, valsArr2); + } + + public String strVal(int doc) { + return Double.toString(doubleVal(doc)); + } + + @Override + public String toString(int doc) { + StringBuilder sb = new StringBuilder(); + sb.append(name()).append('(').append(power).append(','); + boolean firstTime = true; + for (DocValues vals : valsArr1) { + if (firstTime) { + firstTime = false; + } else { + sb.append(','); + } + sb.append(vals.toString(doc)); + } + for (DocValues vals : valsArr2) { + sb.append(',');//we will always have valsArr1, else there is an error + sb.append(vals.toString(doc)); + } + sb.append(')'); + return sb.toString(); + } + }; + } + + @Override + public void createWeight(Map context, Searcher searcher) throws IOException { + for (ValueSource source : sources1) { + source.createWeight(context, searcher); + } + for (ValueSource source : sources2) { + source.createWeight(context, searcher); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof VectorDistanceFunction)) return false; + + VectorDistanceFunction that = (VectorDistanceFunction) o; + + if (Float.compare(that.power, power) != 0) return false; + if (!sources1.equals(that.sources1)) return false; + if (!sources2.equals(that.sources2)) return false; + + return true; + } + + @Override + public int hashCode() { + int result = sources1.hashCode(); + result = 31 * result + sources2.hashCode(); + result = 31 * result + (power != +0.0f ? Float.floatToIntBits(power) : 0); + return result; + } + + public String description() { + StringBuilder sb = new StringBuilder(); + sb.append(name()).append('(').append(power).append(','); + boolean firstTime = true; + for (ValueSource source : sources1) { + if (firstTime) { + firstTime = false; + } else { + sb.append(','); + } + sb.append(source); + } + for (ValueSource source : sources2) { + sb.append(',');//we will always have sources1, else there is an error + sb.append(source); + } + sb.append(')'); + return sb.toString(); + } + +} diff --git a/src/test/org/apache/solr/search/function/TestFunctionQuery.java b/src/test/org/apache/solr/search/function/TestFunctionQuery.java index 53aae807c9c..7e9cf92ffdd 100755 --- a/src/test/org/apache/solr/search/function/TestFunctionQuery.java +++ b/src/test/org/apache/solr/search/function/TestFunctionQuery.java @@ -327,4 +327,20 @@ public class TestFunctionQuery extends AbstractSolrTestCase { assertQ(req("fl","*,score","q", q, "qq","text:batman", "fq",fq), "//float[@name='score']<'1.0'"); assertQ(req("fl","*,score","q", q, "qq","text:superman", "fq",fq), "//float[@name='score']>'1.0'"); } + + public void testDegreeRads() throws Exception { + assertU(adoc("id", "1", "x_td", "0", "y_td", "0")); + assertU(adoc("id", "2", "x_td", "90", "y_td", String.valueOf(Math.PI / 2))); + assertU(adoc("id", "3", "x_td", "45", "y_td", String.valueOf(Math.PI / 4))); + + + assertU(commit()); + assertQ(req("fl", "*,score", "q", "{!func}rad(x_td)", "fq", "id:1"), "//float[@name='score']='0.0'"); + assertQ(req("fl", "*,score", "q", "{!func}rad(x_td)", "fq", "id:2"), "//float[@name='score']='" + (float) (Math.PI / 2) + "'"); + assertQ(req("fl", "*,score", "q", "{!func}rad(x_td)", "fq", "id:3"), "//float[@name='score']='" + (float) (Math.PI / 4) + "'"); + + assertQ(req("fl", "*,score", "q", "{!func}deg(y_td)", "fq", "id:1"), "//float[@name='score']='0.0'"); + assertQ(req("fl", "*,score", "q", "{!func}deg(y_td)", "fq", "id:2"), "//float[@name='score']='90.0'"); + assertQ(req("fl", "*,score", "q", "{!func}deg(y_td)", "fq", "id:3"), "//float[@name='score']='45.0'"); + } } \ No newline at end of file diff --git a/src/test/org/apache/solr/search/function/distance/DistanceFunctionTest.java b/src/test/org/apache/solr/search/function/distance/DistanceFunctionTest.java new file mode 100644 index 00000000000..05ef6606669 --- /dev/null +++ b/src/test/org/apache/solr/search/function/distance/DistanceFunctionTest.java @@ -0,0 +1,107 @@ +package org.apache.solr.search.function.distance; +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.solr.common.SolrException; +import org.apache.solr.util.AbstractSolrTestCase; + + +/** + * + * + **/ +public class DistanceFunctionTest extends AbstractSolrTestCase { + public String getSchemaFile() { + return "schema11.xml"; + } + + public String getSolrConfigFile() { + return "solrconfig-functionquery.xml"; + } + + public String getCoreName() { + return "basic"; + } + + + public void testHaversine() throws Exception { + assertU(adoc("id", "1", "x_td", "0", "y_td", "0")); + assertU(adoc("id", "2", "x_td", "0", "y_td", String.valueOf(Math.PI / 2))); + assertU(adoc("id", "3", "x_td", String.valueOf(Math.PI / 2), "y_td", String.valueOf(Math.PI / 2))); + assertU(adoc("id", "4", "x_td", String.valueOf(Math.PI / 4), "y_td", String.valueOf(Math.PI / 4))); + assertU(commit()); + //Get the haversine distance between the point 0,0 and the docs above assuming a radius of 1 + assertQ(req("fl", "*,score", "q", "{!func}hsin(x_td, y_td, 0, 0, 1)", "fq", "id:1"), "//float[@name='score']='0.0'"); + assertQ(req("fl", "*,score", "q", "{!func}hsin(x_td, y_td, 0, 0, 1)", "fq", "id:2"), "//float[@name='score']='" + (float) (Math.PI / 2) + "'"); + assertQ(req("fl", "*,score", "q", "{!func}hsin(x_td, y_td, 0, 0, 1)", "fq", "id:3"), "//float[@name='score']='" + (float) (Math.PI / 2) + "'"); + assertQ(req("fl", "*,score", "q", "{!func}hsin(x_td, y_td, 0, 0, 1)", "fq", "id:4"), "//float[@name='score']='1.0471976'"); + } + + public void testVector() throws Exception { + assertU(adoc("id", "1", "x_td", "0", "y_td", "0", "z_td", "0", "w_td", "0")); + assertU(adoc("id", "2", "x_td", "0", "y_td", "1", "z_td", "0", "w_td", "0")); + assertU(adoc("id", "3", "x_td", "1", "y_td", "1", "z_td", "1", "w_td", "1")); + assertU(adoc("id", "4", "x_td", "1", "y_td", "0", "z_td", "0", "w_td", "0")); + assertU(adoc("id", "5", "x_td", "2.3", "y_td", "5.5", "z_td", "7.9", "w_td", "-2.4")); + assertU(commit()); + //two dimensions, notice how we only pass in 4 value sources + assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, 0, 0)", "fq", "id:1"), "//float[@name='score']='0.0'"); + assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, 0, 0)", "fq", "id:2"), "//float[@name='score']='1.0'"); + assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, 0, 0)", "fq", "id:3"), "//float[@name='score']='" + 2.0f + "'"); + assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, 0, 0)", "fq", "id:4"), "//float[@name='score']='1.0'"); + assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, 0, 0)", "fq", "id:5"), "//float[@name='score']='" + (float) (2.3 * 2.3 + 5.5 * 5.5) + "'"); + + //three dimensions, notice how we pass in 6 value sources + assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, z_td, 0, 0, 0)", "fq", "id:1"), "//float[@name='score']='0.0'"); + assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, z_td, 0, 0, 0)", "fq", "id:2"), "//float[@name='score']='1.0'"); + assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, z_td, 0, 0, 0)", "fq", "id:3"), "//float[@name='score']='" + 3.0f + "'"); + assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, z_td, 0, 0, 0)", "fq", "id:4"), "//float[@name='score']='1.0'"); + assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, z_td, 0, 0, 0)", "fq", "id:5"), "//float[@name='score']='" + (float) (2.3 * 2.3 + 5.5 * 5.5 + 7.9 * 7.9) + "'"); + + //four dimensions, notice how we pass in 8 value sources + assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, z_td, w_td, 0, 0, 0, 0)", "fq", "id:1"), "//float[@name='score']='0.0'"); + assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, z_td, w_td, 0, 0, 0, 0)", "fq", "id:2"), "//float[@name='score']='1.0'"); + assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, z_td, w_td, 0, 0, 0, 0)", "fq", "id:3"), "//float[@name='score']='" + 4.0f + "'"); + assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, z_td, w_td, 0, 0, 0, 0)", "fq", "id:4"), "//float[@name='score']='1.0'"); + assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, z_td, w_td, 0, 0, 0, 0)", "fq", "id:5"), "//float[@name='score']='" + (float) (2.3 * 2.3 + 5.5 * 5.5 + 7.9 * 7.9 + 2.4 * 2.4) + "'"); + //Pass in imbalanced list, throw exception + try { + assertQ(req("fl", "*,score", "q", "{!func}sqedist(x_td, y_td, z_td, w_td, 0, 0, 0)", "fq", "id:1"), "//float[@name='score']='0.0'"); + assertTrue("should throw an exception", false); + } catch (Exception e) { + Throwable cause = e.getCause(); + assertNotNull(cause); + assertTrue(cause instanceof SolrException); + } + //do one test of Euclidean + //two dimensions, notice how we only pass in 4 value sources + assertQ(req("fl", "*,score", "q", "{!func}dist(2, x_td, y_td, 0, 0)", "fq", "id:1"), "//float[@name='score']='0.0'"); + assertQ(req("fl", "*,score", "q", "{!func}dist(2, x_td, y_td, 0, 0)", "fq", "id:2"), "//float[@name='score']='1.0'"); + assertQ(req("fl", "*,score", "q", "{!func}dist(2, x_td, y_td, 0, 0)", "fq", "id:3"), "//float[@name='score']='" + (float) Math.sqrt(2.0) + "'"); + assertQ(req("fl", "*,score", "q", "{!func}dist(2, x_td, y_td, 0, 0)", "fq", "id:4"), "//float[@name='score']='1.0'"); + assertQ(req("fl", "*,score", "q", "{!func}dist(2, x_td, y_td, 0, 0)", "fq", "id:5"), "//float[@name='score']='" + (float) Math.sqrt((2.3 * 2.3 + 5.5 * 5.5)) + "'"); + + //do one test of Manhattan + //two dimensions, notice how we only pass in 4 value sources + assertQ(req("fl", "*,score", "q", "{!func}dist(1, x_td, y_td, 0, 0)", "fq", "id:1"), "//float[@name='score']='0.0'"); + assertQ(req("fl", "*,score", "q", "{!func}dist(1, x_td, y_td, 0, 0)", "fq", "id:2"), "//float[@name='score']='1.0'"); + assertQ(req("fl", "*,score", "q", "{!func}dist(1, x_td, y_td, 0, 0)", "fq", "id:3"), "//float[@name='score']='" + (float) 2.0 + "'"); + assertQ(req("fl", "*,score", "q", "{!func}dist(1, x_td, y_td, 0, 0)", "fq", "id:4"), "//float[@name='score']='1.0'"); + assertQ(req("fl", "*,score", "q", "{!func}dist(1, x_td, y_td, 0, 0)", "fq", "id:5"), "//float[@name='score']='" + (float) (2.3 + 5.5) + "'"); + } + +}