relocate method in BufferAggregator. (#4071)

*  relocate method in BufferAggregator.

* Unused import.

* Detailed javadoc.

* using Int2ObjectMap.

* batch relocate.

* Revert batch relocate.

* Unused import.

* code comments.

* code comment.
This commit is contained in:
Akash Dwivedi 2017-03-23 13:07:59 -07:00 committed by Gian Merlino
parent f68ba4128f
commit ff7f90b02d
8 changed files with 266 additions and 19 deletions

View File

@ -28,20 +28,19 @@ import com.yahoo.sketches.theta.Union;
import io.druid.query.aggregation.BufferAggregator;
import io.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import io.druid.segment.ObjectColumnSelector;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.IdentityHashMap;
public class SketchBufferAggregator implements BufferAggregator
{
private final ObjectColumnSelector selector;
private final int size;
private final int maxIntermediateSize;
private NativeMemory nm;
private final Map<Integer, Union> unions = new HashMap<>(); //position in BB -> Union Object
private final IdentityHashMap<ByteBuffer, Int2ObjectMap<Union>> unions = new IdentityHashMap<>();
private final IdentityHashMap<ByteBuffer, NativeMemory> nmCache = new IdentityHashMap<>();
public SketchBufferAggregator(ObjectColumnSelector selector, int size, int maxIntermediateSize)
{
@ -53,12 +52,7 @@ public class SketchBufferAggregator implements BufferAggregator
@Override
public void init(ByteBuffer buf, int position)
{
if (nm == null) {
nm = new NativeMemory(buf);
}
Memory mem = new MemoryRegion(nm, position, maxIntermediateSize);
unions.put(position, (Union) SetOperation.builder().initMemory(mem).build(size, Family.UNION));
createNewUnion(buf, position, false);
}
@Override
@ -87,12 +81,27 @@ public class SketchBufferAggregator implements BufferAggregator
//Note that this is not threadsafe and I don't think it needs to be
private Union getUnion(ByteBuffer buf, int position)
{
Union union = unions.get(position);
if (union == null) {
Memory mem = new MemoryRegion(nm, position, maxIntermediateSize);
union = (Union) SetOperation.wrap(mem);
unions.put(position, union);
Int2ObjectMap<Union> unionMap = unions.get(buf);
Union union = unionMap != null ? unionMap.get(position) : null;
if (union != null) {
return union;
}
return createNewUnion(buf, position, true);
}
/**
 * Creates a {@link Union} backed by the {@code maxIntermediateSize}-byte region of {@code buf} that starts at
 * {@code position}, and caches it under that buffer and position.
 *
 * @param buf       aggregation buffer that backs the union
 * @param position  offset of the union's region within {@code buf}
 * @param isWrapped if true the union is obtained with {@code SetOperation.wrap} over the region; if false the
 *                  region is initialized through the set-operation builder as a new union of nominal size
 *                  {@code size}
 *
 * @return the union that was created and cached
 */
private Union createNewUnion(ByteBuffer buf, int position, boolean isWrapped)
{
  final Memory region = new MemoryRegion(getNativeMemory(buf), position, maxIntermediateSize);

  final Union created;
  if (isWrapped) {
    created = (Union) SetOperation.wrap(region);
  } else {
    created = (Union) SetOperation.builder().initMemory(region).build(size, Family.UNION);
  }

  // Unions are cached per buffer identity first, then per position inside that buffer.
  Int2ObjectMap<Union> unionsForBuffer = unions.get(buf);
  if (unionsForBuffer == null) {
    unionsForBuffer = new Int2ObjectOpenHashMap<>();
    unions.put(buf, unionsForBuffer);
  }
  unionsForBuffer.put(position, created);
  return created;
}
@ -119,4 +128,29 @@ public class SketchBufferAggregator implements BufferAggregator
{
inspector.visit("selector", selector);
}
/**
 * Moves the cached union of one aggregation slot from ({@code oldBuffer}, {@code oldPosition}) to
 * ({@code newBuffer}, {@code newPosition}).
 *
 * A union wrapping the region at the new location is created and cached, and the cache entry for the old
 * location is dropped. When the old buffer has no cached unions left, its per-buffer map and its cached
 * NativeMemory are dropped as well.
 */
@Override
public void relocate(int oldPosition, int newPosition, ByteBuffer oldBuffer, ByteBuffer newBuffer)
{
  // Wrap (rather than re-initialize) the region at its new location.
  createNewUnion(newBuffer, newPosition, true);

  final Int2ObjectMap<Union> staleUnions = unions.get(oldBuffer);
  if (staleUnions == null) {
    return;
  }

  staleUnions.remove(oldPosition);
  if (!staleUnions.isEmpty()) {
    return;
  }

  // Nothing is cached for the old buffer any more, so stop referencing it.
  unions.remove(oldBuffer);
  nmCache.remove(oldBuffer);
}
/**
 * Returns the {@link NativeMemory} wrapper cached for the given buffer, creating and caching one on first use.
 *
 * @param buffer buffer to look up; the cache is keyed by buffer identity
 *
 * @return the cached or newly created wrapper for {@code buffer}
 */
private NativeMemory getNativeMemory(ByteBuffer buffer)
{
  final NativeMemory cached = nmCache.get(buffer);
  if (cached != null) {
    return cached;
  }

  final NativeMemory created = new NativeMemory(buffer);
  nmCache.put(buffer, created);
  return created;
}
}

View File

@ -290,4 +290,16 @@ public class SketchHolder
throw new IllegalArgumentException("Unknown sketch operation " + func);
}
}
/**
 * Two holders are equal when they are of exactly the same class and the objects returned by
 * {@code getSketch()} are equal according to that object's own {@code equals}.
 *
 * NOTE(review): no matching {@code hashCode()} override is visible in this part of the file; confirm one
 * exists elsewhere in the class before using holders as hash keys.
 */
@Override
public boolean equals(Object o)
{
  if (o == this) {
    return true;
  }

  final boolean sameClass = o != null && getClass() == o.getClass();
  if (!sameClass) {
    return false;
  }

  final SketchHolder that = (SketchHolder) o;
  return getSketch().equals(that.getSketch());
}
}

View File

@ -0,0 +1,97 @@
/*
* Licensed to Metamarkets Group Inc. (Metamarkets) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Metamarkets licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.druid.query.aggregation.datasketches.theta;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.yahoo.sketches.theta.Sketches;
import com.yahoo.sketches.theta.UpdateSketch;
import io.druid.data.input.MapBasedRow;
import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.CountAggregatorFactory;
import io.druid.query.groupby.epinephelinae.BufferGrouper;
import io.druid.query.groupby.epinephelinae.Grouper;
import io.druid.query.groupby.epinephelinae.GrouperTestUtil;
import io.druid.query.groupby.epinephelinae.TestColumnSelectorFactory;
import org.junit.Assert;
import org.junit.Test;
import java.nio.ByteBuffer;
/**
 * Exercises a {@link BufferGrouper} whose aggregators include a sketch-merge aggregator, aggregating the same
 * sketch holder into several keys and checking the estimate read back from the grouper.
 */
public class BufferGrouperUsingSketchMergeAggregatorFactoryTest
{
  /**
   * Builds and initializes a grouper over a heap buffer of {@code bufferSize} bytes with a sketch-merge
   * aggregator and a count aggregator.
   */
  private static BufferGrouper<Integer> makeGrouper(
      TestColumnSelectorFactory columnSelectorFactory,
      int bufferSize,
      int initialBuckets
  )
  {
    final AggregatorFactory[] aggregatorFactories = {
        new SketchMergeAggregatorFactory("sketch", "sketch", 16, false, true, 2),
        new CountAggregatorFactory("count")
    };

    final BufferGrouper<Integer> grouper = new BufferGrouper<>(
        Suppliers.ofInstance(ByteBuffer.allocate(bufferSize)),
        GrouperTestUtil.intKeySerde(),
        columnSelectorFactory,
        aggregatorFactories,
        Integer.MAX_VALUE,
        0.75f,
        initialBuckets
    );
    grouper.init();
    return grouper;
  }

  /**
   * Wraps the holder in a row whose only column is "sketch".
   */
  private static MapBasedRow sketchRow(SketchHolder holder)
  {
    return new MapBasedRow(0, ImmutableMap.<String, Object>of("sketch", holder));
  }

  /**
   * Aggregates the current row into keys 0 (inclusive) to {@code keyCount} (exclusive), asserting that every
   * aggregate call succeeds.
   */
  private static void aggregateIntoKeys(Grouper<Integer> grouper, int keyCount)
  {
    for (int key = 0; key < keyCount; key++) {
      Assert.assertTrue(String.valueOf(key), grouper.aggregate(key));
    }
  }

  @Test
  public void testGrowingBufferGrouper()
  {
    final TestColumnSelectorFactory columnSelectorFactory = GrouperTestUtil.newColumnSelectorFactory();
    // Starts with 2 buckets but aggregates 5 keys; presumably this forces the table to grow (see test name).
    final Grouper<Integer> grouper = makeGrouper(columnSelectorFactory, 100000, 2);
    try {
      final int expectedMaxSize = 5;

      final SketchHolder sketchHolder = SketchHolder.of(Sketches.updateSketchBuilder().build(16));
      final UpdateSketch updateSketch = (UpdateSketch) sketchHolder.getSketch();

      // First pass: the sketch holds one distinct value.
      updateSketch.update(1);
      columnSelectorFactory.setRow(sketchRow(sketchHolder));
      aggregateIntoKeys(grouper, expectedMaxSize);

      // Second pass: the same sketch now holds two distinct values.
      updateSketch.update(3);
      columnSelectorFactory.setRow(sketchRow(sketchHolder));
      aggregateIntoKeys(grouper, expectedMaxSize);

      final Object[] holders = Lists.newArrayList(grouper.iterator(true)).get(0).getValues();
      final SketchHolder aggregated = (SketchHolder) holders[0];
      Assert.assertEquals(2.0d, aggregated.getEstimate(), 0);
    }
    finally {
      grouper.close();
    }
  }
}

View File

@ -28,6 +28,7 @@ import com.yahoo.sketches.theta.SetOperation;
import com.yahoo.sketches.theta.Sketch;
import com.yahoo.sketches.theta.Sketches;
import com.yahoo.sketches.theta.Union;
import com.yahoo.sketches.theta.UpdateSketch;
import io.druid.data.input.MapBasedRow;
import io.druid.data.input.Row;
import io.druid.java.util.common.granularity.Granularities;
@ -39,6 +40,8 @@ import io.druid.query.aggregation.PostAggregator;
import io.druid.query.aggregation.post.FieldAccessPostAggregator;
import io.druid.query.groupby.GroupByQueryConfig;
import io.druid.query.groupby.GroupByQueryRunnerTest;
import io.druid.query.groupby.epinephelinae.GrouperTestUtil;
import io.druid.query.groupby.epinephelinae.TestColumnSelectorFactory;
import org.joda.time.DateTime;
import org.junit.Assert;
import org.junit.Rule;
@ -389,6 +392,23 @@ public class SketchAggregationTest
Assert.assertEquals(1, comparator.compare(SketchHolder.of(union2), SketchHolder.of(sketch1)));
}
/**
 * Checks that the estimate read from a sketch-merge buffer aggregator is the same before and after its
 * aggregation region is relocated to another buffer.
 */
@Test
public void testRelocation()
{
  final SketchHolder sketchHolder = SketchHolder.of(Sketches.updateSketchBuilder().build(16));
  final UpdateSketch updateSketch = (UpdateSketch) sketchHolder.getSketch();
  updateSketch.update(1);

  final TestColumnSelectorFactory columnSelectorFactory = GrouperTestUtil.newColumnSelectorFactory();
  final MapBasedRow row = new MapBasedRow(0, ImmutableMap.<String, Object>of("sketch", sketchHolder));
  columnSelectorFactory.setRow(row);

  final AggregatorFactory factory = new SketchMergeAggregatorFactory("sketch", "sketch", 16, false, true, 2);
  final SketchHolder[] holders = helper.runRelocateVerificationTest(
      factory,
      columnSelectorFactory,
      SketchHolder.class
  );

  // holders[0] is read before relocation, holders[1] after.
  Assert.assertEquals(holders[0].getEstimate(), holders[1].getEstimate(), 0);
}
private void assertPostAggregatorSerde(PostAggregator agg) throws Exception
{
Assert.assertEquals(

View File

@ -22,6 +22,8 @@ package io.druid.query.aggregation.datasketches.theta.oldapi;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.io.Files;
import com.yahoo.sketches.theta.Sketches;
import com.yahoo.sketches.theta.UpdateSketch;
import io.druid.data.input.MapBasedRow;
import io.druid.java.util.common.granularity.Granularities;
import io.druid.java.util.common.guava.Sequence;
@ -29,9 +31,12 @@ import io.druid.java.util.common.guava.Sequences;
import io.druid.query.aggregation.AggregationTestHelper;
import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.PostAggregator;
import io.druid.query.aggregation.datasketches.theta.SketchHolder;
import io.druid.query.aggregation.post.FieldAccessPostAggregator;
import io.druid.query.groupby.GroupByQueryConfig;
import io.druid.query.groupby.GroupByQueryRunnerTest;
import io.druid.query.groupby.epinephelinae.GrouperTestUtil;
import io.druid.query.groupby.epinephelinae.TestColumnSelectorFactory;
import org.joda.time.DateTime;
import org.junit.Assert;
import org.junit.Rule;
@ -194,6 +199,23 @@ public class OldApiSketchAggregationTest
);
}
/**
 * Checks that the estimate read from an old-API sketch-merge buffer aggregator is the same before and after
 * its aggregation region is relocated to another buffer.
 */
@Test
public void testRelocation()
{
  final SketchHolder sketchHolder = SketchHolder.of(Sketches.updateSketchBuilder().build(16));
  final UpdateSketch updateSketch = (UpdateSketch) sketchHolder.getSketch();
  updateSketch.update(1);

  final TestColumnSelectorFactory columnSelectorFactory = GrouperTestUtil.newColumnSelectorFactory();
  final MapBasedRow row = new MapBasedRow(0, ImmutableMap.<String, Object>of("sketch", sketchHolder));
  columnSelectorFactory.setRow(row);

  final AggregatorFactory factory = new OldSketchMergeAggregatorFactory("sketch", "sketch", 16, false);
  final SketchHolder[] holders = helper.runRelocateVerificationTest(
      factory,
      columnSelectorFactory,
      SketchHolder.class
  );

  // holders[0] is read before relocation, holders[1] after.
  Assert.assertEquals(holders[0].getEstimate(), holders[1].getEstimate(), 0);
}
private void assertPostAggregatorSerde(PostAggregator agg) throws Exception
{
Assert.assertEquals(

View File

@ -126,4 +126,28 @@ public interface BufferAggregator extends HotLoopCallee
default void inspectRuntimeShape(RuntimeShapeInspector inspector)
{
}
/**
 * Relocates any cached objects.
 * If the underlying ByteBuffer used as the aggregation buffer is relocated to a new ByteBuffer, positional caches
 * (if any) built on top of the old ByteBuffer cannot be used for further
 * {@link BufferAggregator#aggregate(ByteBuffer, int)} calls. This method tells the BufferAggregator that the cached
 * objects at a certain location have been relocated to a different location.
 *
 * Only relevant if the BufferAggregator implementation keeps positional caches/objects; the default
 * implementation does nothing.
 *
 * If a relocation spans multiple new ByteBuffers (say n ByteBuffers), this method should be called multiple times
 * (n times), such that in each call all the new positions/old positions exist in newBuffer/oldBuffer.
 *
 * <b>Implementations must not change the position, limit or mark of the given buffers.</b>
 *
 * @param oldPosition old position of a cached object before the aggregation buffer relocates to a new ByteBuffer.
 * @param newPosition new position of a cached object after the aggregation buffer relocates to a new ByteBuffer.
 * @param oldBuffer   old aggregation buffer.
 * @param newBuffer   new aggregation buffer.
 */
default void relocate(int oldPosition, int newPosition, ByteBuffer oldBuffer, ByteBuffer newBuffer)
{
}
}

View File

@ -430,8 +430,9 @@ public class BufferGrouper<KeyType> implements Grouper<KeyType>
for (int oldBucket = 0; oldBucket < buckets; oldBucket++) {
if (isUsed(oldBucket)) {
int oldPosition = oldBucket * bucketSize;
entryBuffer.limit((oldBucket + 1) * bucketSize);
entryBuffer.position(oldBucket * bucketSize);
entryBuffer.position(oldPosition);
keyBuffer.limit(entryBuffer.position() + HASH_SIZE + keySize);
keyBuffer.position(entryBuffer.position() + HASH_SIZE);
@ -442,9 +443,19 @@ public class BufferGrouper<KeyType> implements Grouper<KeyType>
throw new ISE("WTF?! Couldn't find a bucket while resizing?!");
}
newTableBuffer.position(newBucket * bucketSize);
int newPosition = newBucket * bucketSize;
newTableBuffer.position(newPosition);
newTableBuffer.put(entryBuffer);
for (int i = 0; i < aggregators.length; i++) {
aggregators[i].relocate(
oldPosition + aggregatorOffsets[i],
newPosition + aggregatorOffsets[i],
tableBuffer,
newTableBuffer
);
}
buffer.putInt(tableArenaSize + newSize * Ints.BYTES, newBucket * bucketSize);
newSize++;
}

View File

@ -67,6 +67,7 @@ import io.druid.query.timeseries.TimeseriesQueryRunnerFactory;
import io.druid.query.topn.TopNQueryConfig;
import io.druid.query.topn.TopNQueryQueryToolChest;
import io.druid.query.topn.TopNQueryRunnerFactory;
import io.druid.segment.ColumnSelectorFactory;
import io.druid.segment.IndexIO;
import io.druid.segment.IndexMerger;
import io.druid.segment.IndexSpec;
@ -84,6 +85,7 @@ import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Array;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Iterator;
@ -591,5 +593,30 @@ public class AggregationTestHelper
{
return mapper;
}
/**
 * Aggregates one row with a buffer aggregator, copies the aggregator's intermediate bytes into a second buffer
 * at a different offset, notifies the aggregator through {@code relocate}, and returns the values read before
 * and after the move.
 *
 * @param factory  factory used to create the buffer aggregator and to size the copied region
 * @param selector column selector factory handed to {@code factorizeBuffered}
 * @param clazz    component type of the returned array
 *
 * @return a two-element array: index 0 is the value read from the original buffer, index 1 the value read from
 *         the new buffer after relocation
 */
public <T> T[] runRelocateVerificationTest(
    AggregatorFactory factory,
    ColumnSelectorFactory selector,
    Class<T> clazz
)
{
  final T[] results = (T[]) Array.newInstance(clazz, 2);
  final BufferAggregator agg = factory.factorizeBuffered(selector);

  // Aggregate once at offset 0 of the original buffer and capture the result.
  final ByteBuffer sourceBuf = ByteBuffer.allocate(10040902);
  agg.init(sourceBuf, 0);
  agg.aggregate(sourceBuf, 0);
  results[0] = (T) agg.get(sourceBuf, 0);

  // Relative read from the buffer's current position of getMaxIntermediateSize() bytes.
  final byte[] intermediateBytes = new byte[factory.getMaxIntermediateSize()];
  sourceBuf.get(intermediateBytes);

  // Place the same bytes at a non-zero offset of a different buffer, then reset its position.
  final int relocatedPosition = 7574;
  final ByteBuffer targetBuf = ByteBuffer.allocate(941209);
  targetBuf.position(relocatedPosition);
  targetBuf.put(intermediateBytes);
  targetBuf.position(0);

  agg.relocate(0, relocatedPosition, sourceBuf, targetBuf);
  results[1] = (T) agg.get(targetBuf, relocatedPosition);
  return results;
}
}