NIFI-7954 Wrapping HBase_*_ClientService calls in getUgi().doAs() (#4629)

* NIFI-7954 Wrapping HBase_*_ClientService calls in getUgi().doAs() and taking care of TGT renewal.

* NIFI-7954 Simplified SecurityUtil.callWithUgi a little.

* NIFI-7954 Simplified SecurityUtil.callWithUgi more.

* NIFI-7954 Removed unnecessary code.
This commit is contained in:
tpalfy 2020-11-09 15:00:20 +01:00 committed by GitHub
parent 14ec02f21d
commit 940bc3056c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 350 additions and 262 deletions

View File

@ -19,6 +19,8 @@ package org.apache.nifi.hadoop;
import org.apache.commons.lang3.Validate;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.nifi.logging.ComponentLog;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.security.krb.KerberosUser;
import javax.security.auth.Subject;
@ -146,4 +148,28 @@ public class SecurityUtil {
Validate.notNull(config);
return KERBEROS.equalsIgnoreCase(config.get(HADOOP_SECURITY_AUTHENTICATION));
}
public static <T> T callWithUgi(UserGroupInformation ugi, PrivilegedExceptionAction<T> action) throws IOException {
try {
return ugi.doAs(action);
} catch (InterruptedException e) {
throw new IOException(e);
}
}
public static void checkTGTAndRelogin(ComponentLog log, KerberosUser kerberosUser) {
log.trace("getting UGI instance");
if (kerberosUser != null) {
// if there's a KerberosUser associated with this UGI, check the TGT and relogin if it is close to expiring
log.debug("kerberosUser is " + kerberosUser);
try {
log.debug("checking TGT on kerberosUser " + kerberosUser);
kerberosUser.checkTGTAndRelogin();
} catch (LoginException e) {
throw new ProcessException("Unable to relogin with kerberos credentials for " + kerberosUser.getPrincipal(), e);
}
} else {
log.debug("kerberosUser was null, will not refresh TGT with KerberosUser");
}
}
}

View File

@ -62,7 +62,6 @@ import org.apache.nifi.hbase.scan.ResultCell;
import org.apache.nifi.hbase.scan.ResultHandler;
import org.apache.nifi.hbase.validate.ConfigFilesValidator;
import org.apache.nifi.kerberos.KerberosCredentialsService;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processor.util.StandardValidators;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.security.krb.KerberosKeytabUser;
@ -71,7 +70,6 @@ import org.apache.nifi.security.krb.KerberosUser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.security.auth.login.LoginException;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
@ -461,47 +459,55 @@ public class HBase_1_1_2_ClientService extends AbstractControllerService impleme
@Override
public void put(final String tableName, final Collection<PutFlowFile> puts) throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
// Create one Put per row....
final Map<String, List<PutColumn>> sorted = new HashMap<>();
final List<Put> newPuts = new ArrayList<>();
SecurityUtil.callWithUgi(getUgi(), () -> {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
// Create one Put per row....
final Map<String, List<PutColumn>> sorted = new HashMap<>();
final List<Put> newPuts = new ArrayList<>();
for (final PutFlowFile putFlowFile : puts) {
final String rowKeyString = new String(putFlowFile.getRow(), StandardCharsets.UTF_8);
List<PutColumn> columns = sorted.get(rowKeyString);
if (columns == null) {
columns = new ArrayList<>();
sorted.put(rowKeyString, columns);
for (final PutFlowFile putFlowFile : puts) {
final String rowKeyString = new String(putFlowFile.getRow(), StandardCharsets.UTF_8);
List<PutColumn> columns = sorted.get(rowKeyString);
if (columns == null) {
columns = new ArrayList<>();
sorted.put(rowKeyString, columns);
}
columns.addAll(putFlowFile.getColumns());
}
columns.addAll(putFlowFile.getColumns());
}
for (final Map.Entry<String, List<PutColumn>> entry : sorted.entrySet()) {
newPuts.addAll(buildPuts(entry.getKey().getBytes(StandardCharsets.UTF_8), entry.getValue()));
}
for (final Map.Entry<String, List<PutColumn>> entry : sorted.entrySet()) {
newPuts.addAll(buildPuts(entry.getKey().getBytes(StandardCharsets.UTF_8), entry.getValue()));
table.put(newPuts);
}
table.put(newPuts);
}
return null;
});
}
@Override
public void put(final String tableName, final byte[] rowId, final Collection<PutColumn> columns) throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
table.put(buildPuts(rowId, new ArrayList(columns)));
}
SecurityUtil.callWithUgi(getUgi(), () -> {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
table.put(buildPuts(rowId, new ArrayList(columns)));
}
return null;
});
}
@Override
public boolean checkAndPut(final String tableName, final byte[] rowId, final byte[] family, final byte[] qualifier, final byte[] value, final PutColumn column) throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
Put put = new Put(rowId);
put.addColumn(
column.getColumnFamily(),
column.getColumnQualifier(),
column.getBuffer());
return table.checkAndPut(rowId, family, qualifier, value, put);
}
return SecurityUtil.callWithUgi(getUgi(), () -> {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
Put put = new Put(rowId);
put.addColumn(
column.getColumnFamily(),
column.getColumnQualifier(),
column.getBuffer());
return table.checkAndPut(rowId, family, qualifier, value, put);
}
});
}
@Override
@ -511,13 +517,16 @@ public class HBase_1_1_2_ClientService extends AbstractControllerService impleme
@Override
public void delete(String tableName, byte[] rowId, String visibilityLabel) throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
Delete delete = new Delete(rowId);
if (!StringUtils.isEmpty(visibilityLabel)) {
delete.setCellVisibility(new CellVisibility(visibilityLabel));
SecurityUtil.callWithUgi(getUgi(), () -> {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
Delete delete = new Delete(rowId);
if (!StringUtils.isEmpty(visibilityLabel)) {
delete.setCellVisibility(new CellVisibility(visibilityLabel));
}
table.delete(delete);
}
table.delete(delete);
}
return null;
});
}
@Override
@ -554,9 +563,12 @@ public class HBase_1_1_2_ClientService extends AbstractControllerService impleme
}
private void batchDelete(String tableName, List<Delete> deletes) throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
table.delete(deletes);
}
SecurityUtil.callWithUgi(getUgi(), () -> {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
table.delete(deletes);
}
return null;
});
}
@Override
@ -567,64 +579,70 @@ public class HBase_1_1_2_ClientService extends AbstractControllerService impleme
@Override
public void scan(String tableName, Collection<Column> columns, String filterExpression, long minTime, List<String> visibilityLabels, ResultHandler handler) throws IOException {
Filter filter = null;
if (!StringUtils.isBlank(filterExpression)) {
ParseFilter parseFilter = new ParseFilter();
filter = parseFilter.parseFilterString(filterExpression);
}
try (final Table table = connection.getTable(TableName.valueOf(tableName));
final ResultScanner scanner = getResults(table, columns, filter, minTime, visibilityLabels)) {
for (final Result result : scanner) {
final byte[] rowKey = result.getRow();
final Cell[] cells = result.rawCells();
if (cells == null) {
continue;
}
// convert HBase cells to NiFi cells
final ResultCell[] resultCells = new ResultCell[cells.length];
for (int i=0; i < cells.length; i++) {
final Cell cell = cells[i];
final ResultCell resultCell = getResultCell(cell);
resultCells[i] = resultCell;
}
// delegate to the handler
handler.handle(rowKey, resultCells);
SecurityUtil.callWithUgi(getUgi(), () -> {
Filter filter = null;
if (!StringUtils.isBlank(filterExpression)) {
ParseFilter parseFilter = new ParseFilter();
filter = parseFilter.parseFilterString(filterExpression);
}
}
try (final Table table = connection.getTable(TableName.valueOf(tableName));
final ResultScanner scanner = getResults(table, columns, filter, minTime, visibilityLabels)) {
for (final Result result : scanner) {
final byte[] rowKey = result.getRow();
final Cell[] cells = result.rawCells();
if (cells == null) {
continue;
}
// convert HBase cells to NiFi cells
final ResultCell[] resultCells = new ResultCell[cells.length];
for (int i = 0; i < cells.length; i++) {
final Cell cell = cells[i];
final ResultCell resultCell = getResultCell(cell);
resultCells[i] = resultCell;
}
// delegate to the handler
handler.handle(rowKey, resultCells);
}
}
return null;
});
}
@Override
public void scan(final String tableName, final byte[] startRow, final byte[] endRow, final Collection<Column> columns, List<String> authorizations, final ResultHandler handler)
throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName));
final ResultScanner scanner = getResults(table, startRow, endRow, columns, authorizations)) {
SecurityUtil.callWithUgi(getUgi(), () -> {
try (final Table table = connection.getTable(TableName.valueOf(tableName));
final ResultScanner scanner = getResults(table, startRow, endRow, columns, authorizations)) {
for (final Result result : scanner) {
final byte[] rowKey = result.getRow();
final Cell[] cells = result.rawCells();
for (final Result result : scanner) {
final byte[] rowKey = result.getRow();
final Cell[] cells = result.rawCells();
if (cells == null) {
continue;
if (cells == null) {
continue;
}
// convert HBase cells to NiFi cells
final ResultCell[] resultCells = new ResultCell[cells.length];
for (int i = 0; i < cells.length; i++) {
final Cell cell = cells[i];
final ResultCell resultCell = getResultCell(cell);
resultCells[i] = resultCell;
}
// delegate to the handler
handler.handle(rowKey, resultCells);
}
// convert HBase cells to NiFi cells
final ResultCell[] resultCells = new ResultCell[cells.length];
for (int i=0; i < cells.length; i++) {
final Cell cell = cells[i];
final ResultCell resultCell = getResultCell(cell);
resultCells[i] = resultCell;
}
// delegate to the handler
handler.handle(rowKey, resultCells);
}
}
return null;
});
}
@Override
@ -632,37 +650,40 @@ public class HBase_1_1_2_ClientService extends AbstractControllerService impleme
final Long timerangeMin, final Long timerangeMax, final Integer limitRows, final Boolean isReversed,
final Boolean blockCache, final Collection<Column> columns, List<String> visibilityLabels, final ResultHandler handler) throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName));
final ResultScanner scanner = getResults(table, startRow, endRow, filterExpression, timerangeMin,
timerangeMax, limitRows, isReversed, blockCache, columns, visibilityLabels)) {
SecurityUtil.callWithUgi(getUgi(), () -> {
try (final Table table = connection.getTable(TableName.valueOf(tableName));
final ResultScanner scanner = getResults(table, startRow, endRow, filterExpression, timerangeMin,
timerangeMax, limitRows, isReversed, blockCache, columns, visibilityLabels)) {
int cnt = 0;
final int lim = limitRows != null ? limitRows : 0;
for (final Result result : scanner) {
int cnt = 0;
final int lim = limitRows != null ? limitRows : 0;
for (final Result result : scanner) {
if (lim > 0 && ++cnt > lim){
break;
if (lim > 0 && ++cnt > lim) {
break;
}
final byte[] rowKey = result.getRow();
final Cell[] cells = result.rawCells();
if (cells == null) {
continue;
}
// convert HBase cells to NiFi cells
final ResultCell[] resultCells = new ResultCell[cells.length];
for (int i = 0; i < cells.length; i++) {
final Cell cell = cells[i];
final ResultCell resultCell = getResultCell(cell);
resultCells[i] = resultCell;
}
// delegate to the handler
handler.handle(rowKey, resultCells);
}
final byte[] rowKey = result.getRow();
final Cell[] cells = result.rawCells();
if (cells == null) {
continue;
}
// convert HBase cells to NiFi cells
final ResultCell[] resultCells = new ResultCell[cells.length];
for (int i = 0; i < cells.length; i++) {
final Cell cell = cells[i];
final ResultCell resultCell = getResultCell(cell);
resultCells[i] = resultCell;
}
// delegate to the handler
handler.handle(rowKey, resultCells);
}
}
return null;
});
}
//
@ -868,20 +889,7 @@ public class HBase_1_1_2_ClientService extends AbstractControllerService impleme
}
UserGroupInformation getUgi() {
getLogger().trace("getting UGI instance");
if (kerberosUserReference.get() != null) {
// if there's a KerberosUser associated with this UGI, check the TGT and relogin if it is close to expiring
KerberosUser kerberosUser = kerberosUserReference.get();
getLogger().debug("kerberosUser is " + kerberosUser);
try {
getLogger().debug("checking TGT on kerberosUser [{}]", new Object[] {kerberosUser});
kerberosUser.checkTGTAndRelogin();
} catch (LoginException e) {
throw new ProcessException("Unable to relogin with kerberos credentials for " + kerberosUser.getPrincipal(), e);
}
} else {
getLogger().debug("kerberosUser was null, will not refresh TGT with KerberosUser");
}
SecurityUtil.checkTGTAndRelogin(getLogger(), kerberosUserReference.get());
return ugi;
}

View File

@ -23,6 +23,7 @@ import org.apache.hadoop.hbase.client.Result;
import org.apache.hadoop.hbase.client.ResultScanner;
import org.apache.hadoop.hbase.client.Table;
import org.apache.hadoop.hbase.filter.Filter;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.nifi.controller.ConfigurationContext;
import org.apache.nifi.hadoop.KerberosProperties;
import org.apache.nifi.hbase.put.PutColumn;
@ -33,6 +34,7 @@ import org.mockito.Mockito;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
@ -40,6 +42,9 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
@ -52,6 +57,19 @@ public class MockHBaseClientService extends HBase_1_1_2_ClientService {
private Map<String, Result> results = new HashMap<>();
private KerberosProperties kerberosProperties;
private boolean allowExplicitKeytab;
private UserGroupInformation mockUgi;
{
mockUgi = mock(UserGroupInformation.class);
try {
doAnswer(invocation -> {
PrivilegedExceptionAction<?> action = invocation.getArgument(0);
return action.run();
}).when(mockUgi).doAs(any(PrivilegedExceptionAction.class));
} catch (Exception e) {
throw new RuntimeException(e);
}
}
public MockHBaseClientService(final Table table, final String family, final KerberosProperties kerberosProperties) {
this(table, family, kerberosProperties, false);
@ -209,4 +227,9 @@ public class MockHBaseClientService extends HBase_1_1_2_ClientService {
boolean isAllowExplicitKeytab() {
return allowExplicitKeytab;
}
@Override
UserGroupInformation getUgi() {
return mockUgi;
}
}

View File

@ -62,7 +62,6 @@ import org.apache.nifi.hbase.scan.ResultCell;
import org.apache.nifi.hbase.scan.ResultHandler;
import org.apache.nifi.hbase.validate.ConfigFilesValidator;
import org.apache.nifi.kerberos.KerberosCredentialsService;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processor.util.StandardValidators;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.security.krb.KerberosKeytabUser;
@ -71,7 +70,6 @@ import org.apache.nifi.security.krb.KerberosUser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.security.auth.login.LoginException;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
@ -460,47 +458,56 @@ public class HBase_2_ClientService extends AbstractControllerService implements
@Override
public void put(final String tableName, final Collection<PutFlowFile> puts) throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
// Create one Put per row....
final Map<String, List<PutColumn>> sorted = new HashMap<>();
final List<Put> newPuts = new ArrayList<>();
SecurityUtil.callWithUgi(getUgi(), () -> {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
// Create one Put per row....
final Map<String, List<PutColumn>> sorted = new HashMap<>();
final List<Put> newPuts = new ArrayList<>();
for (final PutFlowFile putFlowFile : puts) {
final String rowKeyString = new String(putFlowFile.getRow(), StandardCharsets.UTF_8);
List<PutColumn> columns = sorted.get(rowKeyString);
if (columns == null) {
columns = new ArrayList<>();
sorted.put(rowKeyString, columns);
for (final PutFlowFile putFlowFile : puts) {
final String rowKeyString = new String(putFlowFile.getRow(), StandardCharsets.UTF_8);
List<PutColumn> columns = sorted.get(rowKeyString);
if (columns == null) {
columns = new ArrayList<>();
sorted.put(rowKeyString, columns);
}
columns.addAll(putFlowFile.getColumns());
}
columns.addAll(putFlowFile.getColumns());
for (final Map.Entry<String, List<PutColumn>> entry : sorted.entrySet()) {
newPuts.addAll(buildPuts(entry.getKey().getBytes(StandardCharsets.UTF_8), entry.getValue()));
}
table.put(newPuts);
}
for (final Map.Entry<String, List<PutColumn>> entry : sorted.entrySet()) {
newPuts.addAll(buildPuts(entry.getKey().getBytes(StandardCharsets.UTF_8), entry.getValue()));
}
table.put(newPuts);
}
return null;
});
}
@Override
public void put(final String tableName, final byte[] rowId, final Collection<PutColumn> columns) throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
table.put(buildPuts(rowId, new ArrayList(columns)));
}
SecurityUtil.callWithUgi(getUgi(), () -> {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
table.put(buildPuts(rowId, new ArrayList(columns)));
}
return null;
});
}
@Override
public boolean checkAndPut(final String tableName, final byte[] rowId, final byte[] family, final byte[] qualifier, final byte[] value, final PutColumn column) throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
Put put = new Put(rowId);
put.addColumn(
column.getColumnFamily(),
column.getColumnQualifier(),
column.getBuffer());
return table.checkAndPut(rowId, family, qualifier, value, put);
}
return SecurityUtil.callWithUgi(getUgi(), () -> {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
Put put = new Put(rowId);
put.addColumn(
column.getColumnFamily(),
column.getColumnQualifier(),
column.getBuffer());
return table.checkAndPut(rowId, family, qualifier, value, put);
}
});
}
@Override
@ -510,13 +517,16 @@ public class HBase_2_ClientService extends AbstractControllerService implements
@Override
public void delete(String tableName, byte[] rowId, String visibilityLabel) throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
Delete delete = new Delete(rowId);
if (!StringUtils.isEmpty(visibilityLabel)) {
delete.setCellVisibility(new CellVisibility(visibilityLabel));
SecurityUtil.callWithUgi(getUgi(), () -> {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
Delete delete = new Delete(rowId);
if (!StringUtils.isEmpty(visibilityLabel)) {
delete.setCellVisibility(new CellVisibility(visibilityLabel));
}
table.delete(delete);
}
table.delete(delete);
}
return null;
});
}
@Override
@ -553,9 +563,12 @@ public class HBase_2_ClientService extends AbstractControllerService implements
}
private void batchDelete(String tableName, List<Delete> deletes) throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
table.delete(deletes);
}
SecurityUtil.callWithUgi(getUgi(), () -> {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
table.delete(deletes);
}
return null;
});
}
@Override
@ -566,64 +579,70 @@ public class HBase_2_ClientService extends AbstractControllerService implements
@Override
public void scan(String tableName, Collection<Column> columns, String filterExpression, long minTime, List<String> visibilityLabels, ResultHandler handler) throws IOException {
Filter filter = null;
if (!StringUtils.isBlank(filterExpression)) {
ParseFilter parseFilter = new ParseFilter();
filter = parseFilter.parseFilterString(filterExpression);
}
try (final Table table = connection.getTable(TableName.valueOf(tableName));
final ResultScanner scanner = getResults(table, columns, filter, minTime, visibilityLabels)) {
for (final Result result : scanner) {
final byte[] rowKey = result.getRow();
final Cell[] cells = result.rawCells();
if (cells == null) {
continue;
}
// convert HBase cells to NiFi cells
final ResultCell[] resultCells = new ResultCell[cells.length];
for (int i=0; i < cells.length; i++) {
final Cell cell = cells[i];
final ResultCell resultCell = getResultCell(cell);
resultCells[i] = resultCell;
}
// delegate to the handler
handler.handle(rowKey, resultCells);
SecurityUtil.callWithUgi(getUgi(), () -> {
Filter filter = null;
if (!StringUtils.isBlank(filterExpression)) {
ParseFilter parseFilter = new ParseFilter();
filter = parseFilter.parseFilterString(filterExpression);
}
}
try (final Table table = connection.getTable(TableName.valueOf(tableName));
final ResultScanner scanner = getResults(table, columns, filter, minTime, visibilityLabels)) {
for (final Result result : scanner) {
final byte[] rowKey = result.getRow();
final Cell[] cells = result.rawCells();
if (cells == null) {
continue;
}
// convert HBase cells to NiFi cells
final ResultCell[] resultCells = new ResultCell[cells.length];
for (int i = 0; i < cells.length; i++) {
final Cell cell = cells[i];
final ResultCell resultCell = getResultCell(cell);
resultCells[i] = resultCell;
}
// delegate to the handler
handler.handle(rowKey, resultCells);
}
}
return null;
});
}
@Override
public void scan(final String tableName, final byte[] startRow, final byte[] endRow, final Collection<Column> columns, List<String> authorizations, final ResultHandler handler)
throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName));
final ResultScanner scanner = getResults(table, startRow, endRow, columns, authorizations)) {
SecurityUtil.callWithUgi(getUgi(), () -> {
try (final Table table = connection.getTable(TableName.valueOf(tableName));
final ResultScanner scanner = getResults(table, startRow, endRow, columns, authorizations)) {
for (final Result result : scanner) {
final byte[] rowKey = result.getRow();
final Cell[] cells = result.rawCells();
for (final Result result : scanner) {
final byte[] rowKey = result.getRow();
final Cell[] cells = result.rawCells();
if (cells == null) {
continue;
if (cells == null) {
continue;
}
// convert HBase cells to NiFi cells
final ResultCell[] resultCells = new ResultCell[cells.length];
for (int i = 0; i < cells.length; i++) {
final Cell cell = cells[i];
final ResultCell resultCell = getResultCell(cell);
resultCells[i] = resultCell;
}
// delegate to the handler
handler.handle(rowKey, resultCells);
}
// convert HBase cells to NiFi cells
final ResultCell[] resultCells = new ResultCell[cells.length];
for (int i=0; i < cells.length; i++) {
final Cell cell = cells[i];
final ResultCell resultCell = getResultCell(cell);
resultCells[i] = resultCell;
}
// delegate to the handler
handler.handle(rowKey, resultCells);
}
}
return null;
});
}
@Override
@ -631,37 +650,40 @@ public class HBase_2_ClientService extends AbstractControllerService implements
final Long timerangeMin, final Long timerangeMax, final Integer limitRows, final Boolean isReversed,
final Boolean blockCache, final Collection<Column> columns, List<String> visibilityLabels, final ResultHandler handler) throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName));
final ResultScanner scanner = getResults(table, startRow, endRow, filterExpression, timerangeMin,
timerangeMax, limitRows, isReversed, blockCache, columns, visibilityLabels)) {
SecurityUtil.callWithUgi(getUgi(), () -> {
try (final Table table = connection.getTable(TableName.valueOf(tableName));
final ResultScanner scanner = getResults(table, startRow, endRow, filterExpression, timerangeMin,
timerangeMax, limitRows, isReversed, blockCache, columns, visibilityLabels)) {
int cnt = 0;
final int lim = limitRows != null ? limitRows : 0;
for (final Result result : scanner) {
int cnt = 0;
final int lim = limitRows != null ? limitRows : 0;
for (final Result result : scanner) {
if (lim > 0 && ++cnt > lim){
break;
if (lim > 0 && ++cnt > lim) {
break;
}
final byte[] rowKey = result.getRow();
final Cell[] cells = result.rawCells();
if (cells == null) {
continue;
}
// convert HBase cells to NiFi cells
final ResultCell[] resultCells = new ResultCell[cells.length];
for (int i = 0; i < cells.length; i++) {
final Cell cell = cells[i];
final ResultCell resultCell = getResultCell(cell);
resultCells[i] = resultCell;
}
// delegate to the handler
handler.handle(rowKey, resultCells);
}
final byte[] rowKey = result.getRow();
final Cell[] cells = result.rawCells();
if (cells == null) {
continue;
}
// convert HBase cells to NiFi cells
final ResultCell[] resultCells = new ResultCell[cells.length];
for (int i = 0; i < cells.length; i++) {
final Cell cell = cells[i];
final ResultCell resultCell = getResultCell(cell);
resultCells[i] = resultCell;
}
// delegate to the handler
handler.handle(rowKey, resultCells);
}
}
return null;
});
}
//
@ -866,22 +888,8 @@ public class HBase_2_ClientService extends AbstractControllerService implements
return Boolean.parseBoolean(System.getenv(ALLOW_EXPLICIT_KEYTAB));
}
UserGroupInformation getUgi() {
getLogger().trace("getting UGI instance");
if (kerberosUserReference.get() != null) {
// if there's a KerberosUser associated with this UGI, check the TGT and relogin if it is close to expiring
KerberosUser kerberosUser = kerberosUserReference.get();
getLogger().debug("kerberosUser is " + kerberosUser);
try {
getLogger().debug("checking TGT on kerberosUser [{}]", new Object[] {kerberosUser});
kerberosUser.checkTGTAndRelogin();
} catch (LoginException e) {
throw new ProcessException("Unable to relogin with kerberos credentials for " + kerberosUser.getPrincipal(), e);
}
} else {
getLogger().debug("kerberosUser was null, will not refresh TGT with KerberosUser");
}
UserGroupInformation getUgi() throws IOException {
SecurityUtil.checkTGTAndRelogin(getLogger(), kerberosUserReference.get());
return ugi;
}
}

View File

@ -23,6 +23,7 @@ import org.apache.hadoop.hbase.client.Result;
import org.apache.hadoop.hbase.client.ResultScanner;
import org.apache.hadoop.hbase.client.Table;
import org.apache.hadoop.hbase.filter.Filter;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.nifi.controller.ConfigurationContext;
import org.apache.nifi.hadoop.KerberosProperties;
import org.apache.nifi.hbase.put.PutColumn;
@ -33,6 +34,7 @@ import org.mockito.Mockito;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
@ -40,6 +42,9 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
@ -52,6 +57,19 @@ public class MockHBaseClientService extends HBase_2_ClientService {
private Map<String, Result> results = new HashMap<>();
private KerberosProperties kerberosProperties;
private boolean allowExplicitKeytab;
private UserGroupInformation mockUgi;
{
mockUgi = mock(UserGroupInformation.class);
try {
doAnswer(invocation -> {
PrivilegedExceptionAction<?> action = invocation.getArgument(0);
return action.run();
}).when(mockUgi).doAs(any(PrivilegedExceptionAction.class));
} catch (Exception e) {
throw new RuntimeException(e);
}
}
public MockHBaseClientService(final Table table, final String family, final KerberosProperties kerberosProperties) {
this(table, family, kerberosProperties, false);
@ -79,7 +97,7 @@ public class MockHBaseClientService extends HBase_2_ClientService {
final Cell[] cellArray = new Cell[cells.size()];
int i = 0;
for (final Map.Entry<String, String> cellEntry : cells.entrySet()) {
final Cell cell = Mockito.mock(Cell.class);
final Cell cell = mock(Cell.class);
when(cell.getRowArray()).thenReturn(rowArray);
when(cell.getRowOffset()).thenReturn(0);
when(cell.getRowLength()).thenReturn((short) rowArray.length);
@ -106,7 +124,7 @@ public class MockHBaseClientService extends HBase_2_ClientService {
cellArray[i++] = cell;
}
final Result result = Mockito.mock(Result.class);
final Result result = mock(Result.class);
when(result.getRow()).thenReturn(rowArray);
when(result.rawCells()).thenReturn(cellArray);
results.put(rowKey, result);
@ -179,28 +197,28 @@ public class MockHBaseClientService extends HBase_2_ClientService {
}
protected ResultScanner getResults(Table table, byte[] startRow, byte[] endRow, Collection<Column> columns, List<String> labels) throws IOException {
final ResultScanner scanner = Mockito.mock(ResultScanner.class);
final ResultScanner scanner = mock(ResultScanner.class);
Mockito.when(scanner.iterator()).thenReturn(results.values().iterator());
return scanner;
}
@Override
protected ResultScanner getResults(Table table, Collection<Column> columns, Filter filter, long minTime, List<String> labels) throws IOException {
final ResultScanner scanner = Mockito.mock(ResultScanner.class);
final ResultScanner scanner = mock(ResultScanner.class);
Mockito.when(scanner.iterator()).thenReturn(results.values().iterator());
return scanner;
}
protected ResultScanner getResults(final Table table, final String startRow, final String endRow, final String filterExpression, final Long timerangeMin, final Long timerangeMax,
final Integer limitRows, final Boolean isReversed, final Collection<Column> columns) throws IOException {
final ResultScanner scanner = Mockito.mock(ResultScanner.class);
final ResultScanner scanner = mock(ResultScanner.class);
Mockito.when(scanner.iterator()).thenReturn(results.values().iterator());
return scanner;
}
@Override
protected Connection createConnection(ConfigurationContext context) throws IOException {
Connection connection = Mockito.mock(Connection.class);
Connection connection = mock(Connection.class);
Mockito.when(connection.getTable(table.getName())).thenReturn(table);
return connection;
}
@ -209,4 +227,9 @@ public class MockHBaseClientService extends HBase_2_ClientService {
boolean isAllowExplicitKeytab() {
return allowExplicitKeytab;
}
@Override
UserGroupInformation getUgi() throws IOException {
return mockUgi;
}
}