diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 5cb97c99ba0..d7cce0a621b 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -155,6 +155,9 @@ New Features * LUCENE-5297: Allow to range-facet on any ValueSource, not just NumericDocValues fields. (Shai Erera) + +* LUCENE-5337: Add Payload support to FileDictionary (Suggest) and make it more + configurable (Areek Zilluer via Erick Erickson) Bug Fixes diff --git a/lucene/suggest/src/java/org/apache/lucene/search/suggest/FileDictionary.java b/lucene/suggest/src/java/org/apache/lucene/search/suggest/FileDictionary.java index 16318b33b1d..5e5968574d0 100644 --- a/lucene/suggest/src/java/org/apache/lucene/search/suggest/FileDictionary.java +++ b/lucene/suggest/src/java/org/apache/lucene/search/suggest/FileDictionary.java @@ -28,46 +28,126 @@ import org.apache.lucene.util.IOUtils; /** * Dictionary represented by a text file. * - *

Format allowed: 1 string per line, optionally with a tab-separated integer value:
- * word1 TAB 100
- * word2 word3 TAB 101
- * word4 word5 TAB 102
+ *

Format allowed: 1 entry per line:
+ * An entry can be:
+ *

+ * where the default fieldDelimiter is {@value #DEFAULT_FIELD_DELIMITER}
+ *

+ * NOTE: + *

+ *

+ * Example:
+ * word1 word2 TAB 100 TAB payload1
+ * word3 TAB 101
+ * word4 word3 TAB 102
*/ public class FileDictionary implements Dictionary { + /** + * Tab-delimited fields are most common thus the default, but one can override this via the constructor + */ + public final static String DEFAULT_FIELD_DELIMITER = "\t"; private BufferedReader in; private String line; private boolean done = false; + private final String fieldDelimiter; /** * Creates a dictionary based on an inputstream. + * Using {@link #DEFAULT_FIELD_DELIMITER} as the + * field seperator in a line. *

* NOTE: content is treated as UTF-8 */ public FileDictionary(InputStream dictFile) { - in = new BufferedReader(IOUtils.getDecodingReader(dictFile, IOUtils.CHARSET_UTF_8)); + this(dictFile, DEFAULT_FIELD_DELIMITER); } /** * Creates a dictionary based on a reader. + * Using {@link #DEFAULT_FIELD_DELIMITER} as the + * field seperator in a line. */ public FileDictionary(Reader reader) { + this(reader, DEFAULT_FIELD_DELIMITER); + } + + /** + * Creates a dictionary based on a reader. + * Using fieldDelimiter to seperate out the + * fields in a line. + */ + public FileDictionary(Reader reader, String fieldDelimiter) { in = new BufferedReader(reader); + this.fieldDelimiter = fieldDelimiter; + } + + /** + * Creates a dictionary based on an inputstream. + * Using fieldDelimiter to seperate out the + * fields in a line. + *

+ * NOTE: content is treated as UTF-8 + */ + public FileDictionary(InputStream dictFile, String fieldDelimiter) { + in = new BufferedReader(IOUtils.getDecodingReader(dictFile, IOUtils.CHARSET_UTF_8)); + this.fieldDelimiter = fieldDelimiter; } @Override public InputIterator getWordsIterator() { - return new FileIterator(); + try { + return new FileIterator(); + } catch (IOException e) { + throw new RuntimeException(); + } } final class FileIterator implements InputIterator { - private long curFreq; + private long curWeight; private final BytesRef spare = new BytesRef(); + private BytesRef curPayload = new BytesRef(); + private boolean isFirstLine = true; + private boolean hasPayloads = false; + + private FileIterator() throws IOException { + line = in.readLine(); + if (line == null) { + done = true; + IOUtils.close(in); + } else { + String[] fields = line.split(fieldDelimiter); + if (fields.length > 3) { + throw new IllegalArgumentException("More than 3 fields in one line"); + } else if (fields.length == 3) { // term, weight, payload + hasPayloads = true; + spare.copyChars(fields[0]); + readWeight(fields[1]); + curPayload.copyChars(fields[2]); + } else if (fields.length == 2) { // term, weight + spare.copyChars(fields[0]); + readWeight(fields[1]); + } else { // only term + spare.copyChars(fields[0]); + curWeight = 1; + } + } + } - @Override public long weight() { - return curFreq; + return curWeight; } @Override @@ -75,20 +155,33 @@ public class FileDictionary implements Dictionary { if (done) { return null; } + if (isFirstLine) { + isFirstLine = false; + return spare; + } line = in.readLine(); if (line != null) { - String[] fields = line.split("\t"); - if (fields.length > 1) { - // keep reading floats for bw compat - try { - curFreq = Long.parseLong(fields[1]); - } catch (NumberFormatException e) { - curFreq = (long)Double.parseDouble(fields[1]); - } + String[] fields = line.split(fieldDelimiter); + if (fields.length > 3) { + throw new IllegalArgumentException("More than 3 fields in one line"); + } else if (fields.length == 3) { // term, weight and payload spare.copyChars(fields[0]); - } else { - spare.copyChars(line); - curFreq = 1; + readWeight(fields[1]); + if (hasPayloads) { + curPayload.copyChars(fields[2]); + } + } else if (fields.length == 2) { // term, weight + spare.copyChars(fields[0]); + readWeight(fields[1]); + if (hasPayloads) { // have an empty payload + curPayload = new BytesRef(); + } + } else { // only term + spare.copyChars(fields[0]); + curWeight = 1; + if (hasPayloads) { + curPayload = new BytesRef(); + } } return spare; } else { @@ -100,12 +193,21 @@ public class FileDictionary implements Dictionary { @Override public BytesRef payload() { - return null; + return (hasPayloads) ? curPayload : null; } @Override public boolean hasPayloads() { - return false; + return hasPayloads; + } + + private void readWeight(String weight) { + // keep reading floats for bw compat + try { + curWeight = Long.parseLong(weight); + } catch (NumberFormatException e) { + curWeight = (long)Double.parseDouble(weight); + } } } } diff --git a/lucene/suggest/src/test/org/apache/lucene/search/suggest/FileDictionaryTest.java b/lucene/suggest/src/test/org/apache/lucene/search/suggest/FileDictionaryTest.java new file mode 100644 index 00000000000..fbeb6df22a4 --- /dev/null +++ b/lucene/suggest/src/test/org/apache/lucene/search/suggest/FileDictionaryTest.java @@ -0,0 +1,196 @@ +package org.apache.lucene.search.suggest; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.AbstractMap.SimpleEntry; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.LuceneTestCase; +import org.apache.lucene.util._TestUtil; +import org.junit.Test; + + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +public class FileDictionaryTest extends LuceneTestCase { + + private Map.Entry, String> generateFileEntry(String fieldDelimiter, boolean hasWeight, boolean hasPayload) { + List entryValues = new ArrayList<>(); + StringBuilder sb = new StringBuilder(); + String term = _TestUtil.randomSimpleString(random(), 1, 300); + sb.append(term); + entryValues.add(term); + if (hasWeight) { + sb.append(fieldDelimiter); + long weight = _TestUtil.nextLong(random(), Long.MIN_VALUE, Long.MAX_VALUE); + sb.append(weight); + entryValues.add(String.valueOf(weight)); + } + if (hasPayload) { + sb.append(fieldDelimiter); + String payload = _TestUtil.randomSimpleString(random(), 1, 300); + sb.append(payload); + entryValues.add(payload); + } + sb.append("\n"); + return new SimpleEntry, String>(entryValues, sb.toString()); + } + + private Map.Entry>,String> generateFileInput(int count, String fieldDelimiter, boolean hasWeights, boolean hasPayloads) { + List> entries = new ArrayList<>(); + StringBuilder sb = new StringBuilder(); + boolean hasPayload = hasPayloads; + for (int i = 0; i < count; i++) { + if (hasPayloads) { + hasPayload = (i==0) ? true : random().nextBoolean(); + } + Map.Entry, String> entrySet = generateFileEntry(fieldDelimiter, (!hasPayloads && hasWeights) ? random().nextBoolean() : hasWeights, hasPayload); + entries.add(entrySet.getKey()); + sb.append(entrySet.getValue()); + } + return new SimpleEntry>, String>(entries, sb.toString()); + } + + @Test + public void testFileWithTerm() throws IOException { + Map.Entry>,String> fileInput = generateFileInput(atLeast(100), FileDictionary.DEFAULT_FIELD_DELIMITER, false, false); + InputStream inputReader = new ByteArrayInputStream(fileInput.getValue().getBytes("UTF-8")); + FileDictionary dictionary = new FileDictionary(inputReader); + List> entries = fileInput.getKey(); + InputIterator inputIter = dictionary.getWordsIterator(); + assertFalse(inputIter.hasPayloads()); + BytesRef term; + int count = 0; + while((term = inputIter.next()) != null) { + assertTrue(entries.size() > count); + List entry = entries.get(count); + assertTrue(entry.size() >= 1); // at least a term + assertEquals(entry.get(0), term.utf8ToString()); + assertEquals(1, inputIter.weight()); + assertNull(inputIter.payload()); + count++; + } + assertEquals(count, entries.size()); + } + + @Test + public void testFileWithWeight() throws IOException { + Map.Entry>,String> fileInput = generateFileInput(atLeast(100), FileDictionary.DEFAULT_FIELD_DELIMITER, true, false); + InputStream inputReader = new ByteArrayInputStream(fileInput.getValue().getBytes("UTF-8")); + FileDictionary dictionary = new FileDictionary(inputReader); + List> entries = fileInput.getKey(); + InputIterator inputIter = dictionary.getWordsIterator(); + assertFalse(inputIter.hasPayloads()); + BytesRef term; + int count = 0; + while((term = inputIter.next()) != null) { + assertTrue(entries.size() > count); + List entry = entries.get(count); + assertTrue(entry.size() >= 1); // at least a term + assertEquals(entry.get(0), term.utf8ToString()); + assertEquals((entry.size() == 2) ? Long.parseLong(entry.get(1)) : 1, inputIter.weight()); + assertNull(inputIter.payload()); + count++; + } + assertEquals(count, entries.size()); + } + + @Test + public void testFileWithWeightAndPayload() throws IOException { + Map.Entry>,String> fileInput = generateFileInput(atLeast(100), FileDictionary.DEFAULT_FIELD_DELIMITER, true, true); + InputStream inputReader = new ByteArrayInputStream(fileInput.getValue().getBytes("UTF-8")); + FileDictionary dictionary = new FileDictionary(inputReader); + List> entries = fileInput.getKey(); + InputIterator inputIter = dictionary.getWordsIterator(); + assertTrue(inputIter.hasPayloads()); + BytesRef term; + int count = 0; + while((term = inputIter.next()) != null) { + assertTrue(entries.size() > count); + List entry = entries.get(count); + assertTrue(entry.size() >= 2); // at least term and weight + assertEquals(entry.get(0), term.utf8ToString()); + assertEquals(Long.parseLong(entry.get(1)), inputIter.weight()); + if (entry.size() == 3) { + assertEquals(entry.get(2), inputIter.payload().utf8ToString()); + } else { + assertEquals(inputIter.payload().length, 0); + } + count++; + } + assertEquals(count, entries.size()); + } + + @Test + public void testFileWithOneEntry() throws IOException { + Map.Entry>,String> fileInput = generateFileInput(1, FileDictionary.DEFAULT_FIELD_DELIMITER, true, true); + InputStream inputReader = new ByteArrayInputStream(fileInput.getValue().getBytes("UTF-8")); + FileDictionary dictionary = new FileDictionary(inputReader); + List> entries = fileInput.getKey(); + InputIterator inputIter = dictionary.getWordsIterator(); + assertTrue(inputIter.hasPayloads()); + BytesRef term; + int count = 0; + while((term = inputIter.next()) != null) { + assertTrue(entries.size() > count); + List entry = entries.get(count); + assertTrue(entry.size() >= 2); // at least term and weight + assertEquals(entry.get(0), term.utf8ToString()); + assertEquals(Long.parseLong(entry.get(1)), inputIter.weight()); + if (entry.size() == 3) { + assertEquals(entry.get(2), inputIter.payload().utf8ToString()); + } else { + assertEquals(inputIter.payload().length, 0); + } + count++; + } + assertEquals(count, entries.size()); + } + + + @Test + public void testFileWithDifferentDelimiter() throws IOException { + Map.Entry>,String> fileInput = generateFileInput(atLeast(100), " , ", true, true); + InputStream inputReader = new ByteArrayInputStream(fileInput.getValue().getBytes("UTF-8")); + FileDictionary dictionary = new FileDictionary(inputReader, " , "); + List> entries = fileInput.getKey(); + InputIterator inputIter = dictionary.getWordsIterator(); + assertTrue(inputIter.hasPayloads()); + BytesRef term; + int count = 0; + while((term = inputIter.next()) != null) { + assertTrue(entries.size() > count); + List entry = entries.get(count); + assertTrue(entry.size() >= 2); // at least term and weight + assertEquals(entry.get(0), term.utf8ToString()); + assertEquals(Long.parseLong(entry.get(1)), inputIter.weight()); + if (entry.size() == 3) { + assertEquals(entry.get(2), inputIter.payload().utf8ToString()); + } else { + assertEquals(inputIter.payload().length, 0); + } + count++; + } + assertEquals(count, entries.size()); + } + +}