diff --git a/hbase-mapreduce/src/main/java/org/apache/hadoop/hbase/mapreduce/TableInputFormatBase.java b/hbase-mapreduce/src/main/java/org/apache/hadoop/hbase/mapreduce/TableInputFormatBase.java index d8031d91f8e..24973c941e7 100644 --- a/hbase-mapreduce/src/main/java/org/apache/hadoop/hbase/mapreduce/TableInputFormatBase.java +++ b/hbase-mapreduce/src/main/java/org/apache/hadoop/hbase/mapreduce/TableInputFormatBase.java @@ -53,6 +53,7 @@ import org.apache.hadoop.mapreduce.RecordReader; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.hadoop.net.DNS; import org.apache.hadoop.util.StringUtils; +import org.apache.hbase.thirdparty.com.google.common.annotations.VisibleForTesting; /** * A base for {@link TableInputFormat}s. Receives a {@link Connection}, a {@link TableName}, @@ -291,7 +292,7 @@ public abstract class TableInputFormatBase */ private List oneInputSplitPerRegion() throws IOException { RegionSizeCalculator sizeCalculator = - new RegionSizeCalculator(getRegionLocator(), getAdmin()); + createRegionSizeCalculator(getRegionLocator(), getAdmin()); TableName tableName = getTable().getName(); @@ -478,7 +479,8 @@ public abstract class TableInputFormatBase while (j < splits.size()) { TableSplit nextRegion = (TableSplit) splits.get(j); long nextRegionSize = nextRegion.getLength(); - if (totalSize + nextRegionSize <= averageRegionSize) { + if (totalSize + nextRegionSize <= averageRegionSize + && Bytes.equals(splitEndKey, nextRegion.getStartRow())) { totalSize = totalSize + nextRegionSize; splitEndKey = nextRegion.getEndRow(); j++; @@ -586,6 +588,12 @@ public abstract class TableInputFormatBase this.connection = connection; } + @VisibleForTesting + protected RegionSizeCalculator createRegionSizeCalculator(RegionLocator locator, Admin admin) + throws IOException { + return new RegionSizeCalculator(locator, admin); + } + /** * Gets the scan defining the actual details like columns etc. * diff --git a/hbase-mapreduce/src/test/java/org/apache/hadoop/hbase/mapreduce/TestTableInputFormatBase.java b/hbase-mapreduce/src/test/java/org/apache/hadoop/hbase/mapreduce/TestTableInputFormatBase.java index 5fa4b546915..29a92ee75c0 100644 --- a/hbase-mapreduce/src/test/java/org/apache/hadoop/hbase/mapreduce/TestTableInputFormatBase.java +++ b/hbase-mapreduce/src/test/java/org/apache/hadoop/hbase/mapreduce/TestTableInputFormatBase.java @@ -18,15 +18,45 @@ package org.apache.hadoop.hbase.mapreduce; import static org.junit.Assert.*; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyBoolean; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import java.io.IOException; import java.net.Inet6Address; import java.net.InetAddress; import java.net.UnknownHostException; +import java.util.Map; +import java.util.TreeMap; +import java.util.concurrent.ExecutorService; +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hbase.HBaseClassTestRule; +import org.apache.hadoop.hbase.HBaseConfiguration; +import org.apache.hadoop.hbase.HConstants; +import org.apache.hadoop.hbase.HRegionLocation; +import org.apache.hadoop.hbase.ServerName; +import org.apache.hadoop.hbase.TableName; +import org.apache.hadoop.hbase.client.Admin; +import org.apache.hadoop.hbase.client.BufferedMutator; +import org.apache.hadoop.hbase.client.BufferedMutatorParams; +import org.apache.hadoop.hbase.client.ClusterConnection; +import org.apache.hadoop.hbase.client.Connection; +import org.apache.hadoop.hbase.client.RegionInfo; +import org.apache.hadoop.hbase.client.RegionInfoBuilder; +import org.apache.hadoop.hbase.client.RegionLocator; +import org.apache.hadoop.hbase.client.Table; +import org.apache.hadoop.hbase.client.TableBuilder; +import org.apache.hadoop.hbase.security.User; import org.apache.hadoop.hbase.testclassification.SmallTests; +import org.apache.hadoop.hbase.util.Bytes; +import org.apache.hadoop.hbase.util.Pair; +import org.apache.hadoop.mapreduce.JobContext; import org.junit.ClassRule; import org.junit.Test; import org.junit.experimental.categories.Category; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; @Category({SmallTests.class}) public class TestTableInputFormatBase { @@ -55,4 +85,206 @@ public class TestTableInputFormatBase { assertEquals("Should retrun the hostname for this host. Expected : " + localhost + " Actual : " + actualHostName, localhost, actualHostName); } + + @Test + public void testNonSuccessiveSplitsAreNotMerged() throws IOException { + JobContext context = mock(JobContext.class); + Configuration conf = HBaseConfiguration.create(); + conf.set(ClusterConnection.HBASE_CLIENT_CONNECTION_IMPL, + ConnectionForMergeTesting.class.getName()); + conf.set(TableInputFormat.INPUT_TABLE, "testTable"); + conf.setBoolean(TableInputFormatBase.MAPREDUCE_INPUT_AUTOBALANCE, true); + when(context.getConfiguration()).thenReturn(conf); + + TableInputFormat tifExclude = new TableInputFormatForMergeTesting(); + tifExclude.setConf(conf); + // split["b", "c"] is excluded, split["o", "p"] and split["p", "q"] are merged, + // but split["a", "b"] and split["c", "d"] are not merged. + assertEquals(ConnectionForMergeTesting.START_KEYS.length - 1 - 1, + tifExclude.getSplits(context).size()); + } + + /** + * Subclass of {@link TableInputFormat} to use in {@link #testNonSuccessiveSplitsAreNotMerged}. + * This class overrides {@link TableInputFormatBase#includeRegionInSplit} + * to exclude specific splits. + */ + private static class TableInputFormatForMergeTesting extends TableInputFormat { + private byte[] prefixStartKey = Bytes.toBytes("b"); + private byte[] prefixEndKey = Bytes.toBytes("c"); + private RegionSizeCalculator sizeCalculator; + + /** + * Exclude regions which contain rows starting with "b". + */ + @Override + protected boolean includeRegionInSplit(final byte[] startKey, final byte [] endKey) { + if (Bytes.compareTo(startKey, prefixEndKey) < 0 + && (Bytes.compareTo(prefixStartKey, endKey) < 0 + || Bytes.equals(endKey, HConstants.EMPTY_END_ROW))) { + return false; + } else { + return true; + } + } + + @Override + protected void initializeTable(Connection connection, TableName tableName) throws IOException { + super.initializeTable(connection, tableName); + ConnectionForMergeTesting cft = (ConnectionForMergeTesting) connection; + sizeCalculator = cft.getRegionSizeCalculator(); + } + + @Override + protected RegionSizeCalculator createRegionSizeCalculator(RegionLocator locator, Admin admin) + throws IOException { + return sizeCalculator; + } + } + + /** + * Connection class to use in {@link #testNonSuccessiveSplitsAreNotMerged}. + * This class returns mocked {@link Table}, {@link RegionLocator}, {@link RegionSizeCalculator}, + * and {@link Admin}. + */ + private static class ConnectionForMergeTesting implements Connection { + public static final byte[][] SPLITS = new byte[][] { + Bytes.toBytes("a"), Bytes.toBytes("b"), Bytes.toBytes("c"), Bytes.toBytes("d"), + Bytes.toBytes("e"), Bytes.toBytes("f"), Bytes.toBytes("g"), Bytes.toBytes("h"), + Bytes.toBytes("i"), Bytes.toBytes("j"), Bytes.toBytes("k"), Bytes.toBytes("l"), + Bytes.toBytes("m"), Bytes.toBytes("n"), Bytes.toBytes("o"), Bytes.toBytes("p"), + Bytes.toBytes("q"), Bytes.toBytes("r"), Bytes.toBytes("s"), Bytes.toBytes("t"), + Bytes.toBytes("u"), Bytes.toBytes("v"), Bytes.toBytes("w"), Bytes.toBytes("x"), + Bytes.toBytes("y"), Bytes.toBytes("z") + }; + + public static final byte[][] START_KEYS; + public static final byte[][] END_KEYS; + static { + START_KEYS = new byte[SPLITS.length + 1][]; + START_KEYS[0] = HConstants.EMPTY_BYTE_ARRAY; + for (int i = 0; i < SPLITS.length; i++) { + START_KEYS[i + 1] = SPLITS[i]; + } + + END_KEYS = new byte[SPLITS.length + 1][]; + for (int i = 0; i < SPLITS.length; i++) { + END_KEYS[i] = SPLITS[i]; + } + END_KEYS[SPLITS.length] = HConstants.EMPTY_BYTE_ARRAY; + } + + public static final Map SIZE_MAP = new TreeMap<>(Bytes.BYTES_COMPARATOR); + static { + for (byte[] startKey : START_KEYS) { + SIZE_MAP.put(startKey, 1024L * 1024L * 1024L); + } + SIZE_MAP.put(Bytes.toBytes("a"), 200L * 1024L * 1024L); + SIZE_MAP.put(Bytes.toBytes("b"), 200L * 1024L * 1024L); + SIZE_MAP.put(Bytes.toBytes("c"), 200L * 1024L * 1024L); + SIZE_MAP.put(Bytes.toBytes("o"), 200L * 1024L * 1024L); + SIZE_MAP.put(Bytes.toBytes("p"), 200L * 1024L * 1024L); + } + + ConnectionForMergeTesting(Configuration conf, ExecutorService pool, User user) + throws IOException { + } + + @Override + public void abort(String why, Throwable e) { + } + + @Override + public boolean isAborted() { + return false; + } + + @Override + public Configuration getConfiguration() { + throw new UnsupportedOperationException(); + } + + @Override + public Table getTable(TableName tableName) throws IOException { + Table table = mock(Table.class); + when(table.getName()).thenReturn(tableName); + return table; + } + + @Override + public Table getTable(TableName tableName, ExecutorService pool) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public BufferedMutator getBufferedMutator(TableName tableName) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public BufferedMutator getBufferedMutator(BufferedMutatorParams params) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public RegionLocator getRegionLocator(TableName tableName) throws IOException { + final Map locationMap = new TreeMap<>(Bytes.BYTES_COMPARATOR); + for (byte[] startKey : START_KEYS) { + HRegionLocation hrl = new HRegionLocation( + RegionInfoBuilder.newBuilder(tableName).setStartKey(startKey).build(), + ServerName.valueOf("localhost", 0, 0)); + locationMap.put(startKey, hrl); + } + + RegionLocator locator = mock(RegionLocator.class); + when(locator.getRegionLocation(any(byte [].class), anyBoolean())). + thenAnswer(new Answer() { + @Override + public HRegionLocation answer(InvocationOnMock invocationOnMock) throws Throwable { + Object [] args = invocationOnMock.getArguments(); + byte [] key = (byte [])args[0]; + return locationMap.get(key); + } + }); + when(locator.getStartEndKeys()). + thenReturn(new Pair(START_KEYS, END_KEYS)); + return locator; + } + + public RegionSizeCalculator getRegionSizeCalculator() { + RegionSizeCalculator sizeCalculator = mock(RegionSizeCalculator.class); + when(sizeCalculator.getRegionSize(any(byte[].class))). + thenAnswer(new Answer() { + @Override + public Long answer(InvocationOnMock invocationOnMock) throws Throwable { + Object [] args = invocationOnMock.getArguments(); + byte [] regionId = (byte [])args[0]; + byte[] startKey = RegionInfo.getStartKey(regionId); + return SIZE_MAP.get(startKey); + } + }); + return sizeCalculator; + } + + @Override + public Admin getAdmin() throws IOException { + Admin admin = mock(Admin.class); + // return non-null admin to pass null checks + return admin; + } + + @Override + public void close() throws IOException { + } + + @Override + public boolean isClosed() { + return false; + } + + @Override + public TableBuilder getTableBuilder(TableName tableName, ExecutorService pool) { + throw new UnsupportedOperationException(); + } + } }