NIFI-7954 Wrapping HBase_*_ClientService calls in getUgi().doAs() (#4629)

* NIFI-7954 Wrapping HBase_*_ClientService calls in getUgi().doAs() and taking care of TGT renewal.

* NIFI-7954 Simplified SecurityUtil.callWithUgi a little.

* NIFI-7954 Simplified SecurityUtil.callWithUgi more.

* NIFI-7954 Removed unnecessary code.
This commit is contained in:
tpalfy 2020-11-09 15:00:20 +01:00 committed by GitHub
parent 14ec02f21d
commit 940bc3056c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 350 additions and 262 deletions

View File

@ -19,6 +19,8 @@ package org.apache.nifi.hadoop;
import org.apache.commons.lang3.Validate; import org.apache.commons.lang3.Validate;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.UserGroupInformation;
import org.apache.nifi.logging.ComponentLog;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.security.krb.KerberosUser; import org.apache.nifi.security.krb.KerberosUser;
import javax.security.auth.Subject; import javax.security.auth.Subject;
@ -146,4 +148,28 @@ public class SecurityUtil {
Validate.notNull(config); Validate.notNull(config);
return KERBEROS.equalsIgnoreCase(config.get(HADOOP_SECURITY_AUTHENTICATION)); return KERBEROS.equalsIgnoreCase(config.get(HADOOP_SECURITY_AUTHENTICATION));
} }
public static <T> T callWithUgi(UserGroupInformation ugi, PrivilegedExceptionAction<T> action) throws IOException {
try {
return ugi.doAs(action);
} catch (InterruptedException e) {
throw new IOException(e);
}
}
public static void checkTGTAndRelogin(ComponentLog log, KerberosUser kerberosUser) {
log.trace("getting UGI instance");
if (kerberosUser != null) {
// if there's a KerberosUser associated with this UGI, check the TGT and relogin if it is close to expiring
log.debug("kerberosUser is " + kerberosUser);
try {
log.debug("checking TGT on kerberosUser " + kerberosUser);
kerberosUser.checkTGTAndRelogin();
} catch (LoginException e) {
throw new ProcessException("Unable to relogin with kerberos credentials for " + kerberosUser.getPrincipal(), e);
}
} else {
log.debug("kerberosUser was null, will not refresh TGT with KerberosUser");
}
}
} }

View File

@ -62,7 +62,6 @@ import org.apache.nifi.hbase.scan.ResultCell;
import org.apache.nifi.hbase.scan.ResultHandler; import org.apache.nifi.hbase.scan.ResultHandler;
import org.apache.nifi.hbase.validate.ConfigFilesValidator; import org.apache.nifi.hbase.validate.ConfigFilesValidator;
import org.apache.nifi.kerberos.KerberosCredentialsService; import org.apache.nifi.kerberos.KerberosCredentialsService;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processor.util.StandardValidators; import org.apache.nifi.processor.util.StandardValidators;
import org.apache.nifi.reporting.InitializationException; import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.security.krb.KerberosKeytabUser; import org.apache.nifi.security.krb.KerberosKeytabUser;
@ -71,7 +70,6 @@ import org.apache.nifi.security.krb.KerberosUser;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import javax.security.auth.login.LoginException;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
@ -461,47 +459,55 @@ public class HBase_1_1_2_ClientService extends AbstractControllerService impleme
@Override @Override
public void put(final String tableName, final Collection<PutFlowFile> puts) throws IOException { public void put(final String tableName, final Collection<PutFlowFile> puts) throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) { SecurityUtil.callWithUgi(getUgi(), () -> {
// Create one Put per row.... try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
final Map<String, List<PutColumn>> sorted = new HashMap<>(); // Create one Put per row....
final List<Put> newPuts = new ArrayList<>(); final Map<String, List<PutColumn>> sorted = new HashMap<>();
final List<Put> newPuts = new ArrayList<>();
for (final PutFlowFile putFlowFile : puts) { for (final PutFlowFile putFlowFile : puts) {
final String rowKeyString = new String(putFlowFile.getRow(), StandardCharsets.UTF_8); final String rowKeyString = new String(putFlowFile.getRow(), StandardCharsets.UTF_8);
List<PutColumn> columns = sorted.get(rowKeyString); List<PutColumn> columns = sorted.get(rowKeyString);
if (columns == null) { if (columns == null) {
columns = new ArrayList<>(); columns = new ArrayList<>();
sorted.put(rowKeyString, columns); sorted.put(rowKeyString, columns);
}
columns.addAll(putFlowFile.getColumns());
} }
columns.addAll(putFlowFile.getColumns()); for (final Map.Entry<String, List<PutColumn>> entry : sorted.entrySet()) {
} newPuts.addAll(buildPuts(entry.getKey().getBytes(StandardCharsets.UTF_8), entry.getValue()));
}
for (final Map.Entry<String, List<PutColumn>> entry : sorted.entrySet()) { table.put(newPuts);
newPuts.addAll(buildPuts(entry.getKey().getBytes(StandardCharsets.UTF_8), entry.getValue()));
} }
return null;
table.put(newPuts); });
}
} }
@Override @Override
public void put(final String tableName, final byte[] rowId, final Collection<PutColumn> columns) throws IOException { public void put(final String tableName, final byte[] rowId, final Collection<PutColumn> columns) throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) { SecurityUtil.callWithUgi(getUgi(), () -> {
table.put(buildPuts(rowId, new ArrayList(columns))); try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
} table.put(buildPuts(rowId, new ArrayList(columns)));
}
return null;
});
} }
@Override @Override
public boolean checkAndPut(final String tableName, final byte[] rowId, final byte[] family, final byte[] qualifier, final byte[] value, final PutColumn column) throws IOException { public boolean checkAndPut(final String tableName, final byte[] rowId, final byte[] family, final byte[] qualifier, final byte[] value, final PutColumn column) throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) { return SecurityUtil.callWithUgi(getUgi(), () -> {
Put put = new Put(rowId); try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
put.addColumn( Put put = new Put(rowId);
column.getColumnFamily(), put.addColumn(
column.getColumnQualifier(), column.getColumnFamily(),
column.getBuffer()); column.getColumnQualifier(),
return table.checkAndPut(rowId, family, qualifier, value, put); column.getBuffer());
} return table.checkAndPut(rowId, family, qualifier, value, put);
}
});
} }
@Override @Override
@ -511,13 +517,16 @@ public class HBase_1_1_2_ClientService extends AbstractControllerService impleme
@Override @Override
public void delete(String tableName, byte[] rowId, String visibilityLabel) throws IOException { public void delete(String tableName, byte[] rowId, String visibilityLabel) throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) { SecurityUtil.callWithUgi(getUgi(), () -> {
Delete delete = new Delete(rowId); try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
if (!StringUtils.isEmpty(visibilityLabel)) { Delete delete = new Delete(rowId);
delete.setCellVisibility(new CellVisibility(visibilityLabel)); if (!StringUtils.isEmpty(visibilityLabel)) {
delete.setCellVisibility(new CellVisibility(visibilityLabel));
}
table.delete(delete);
} }
table.delete(delete); return null;
} });
} }
@Override @Override
@ -554,9 +563,12 @@ public class HBase_1_1_2_ClientService extends AbstractControllerService impleme
} }
private void batchDelete(String tableName, List<Delete> deletes) throws IOException { private void batchDelete(String tableName, List<Delete> deletes) throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) { SecurityUtil.callWithUgi(getUgi(), () -> {
table.delete(deletes); try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
} table.delete(deletes);
}
return null;
});
} }
@Override @Override
@ -567,64 +579,70 @@ public class HBase_1_1_2_ClientService extends AbstractControllerService impleme
@Override @Override
public void scan(String tableName, Collection<Column> columns, String filterExpression, long minTime, List<String> visibilityLabels, ResultHandler handler) throws IOException { public void scan(String tableName, Collection<Column> columns, String filterExpression, long minTime, List<String> visibilityLabels, ResultHandler handler) throws IOException {
Filter filter = null; SecurityUtil.callWithUgi(getUgi(), () -> {
if (!StringUtils.isBlank(filterExpression)) { Filter filter = null;
ParseFilter parseFilter = new ParseFilter(); if (!StringUtils.isBlank(filterExpression)) {
filter = parseFilter.parseFilterString(filterExpression); ParseFilter parseFilter = new ParseFilter();
} filter = parseFilter.parseFilterString(filterExpression);
try (final Table table = connection.getTable(TableName.valueOf(tableName));
final ResultScanner scanner = getResults(table, columns, filter, minTime, visibilityLabels)) {
for (final Result result : scanner) {
final byte[] rowKey = result.getRow();
final Cell[] cells = result.rawCells();
if (cells == null) {
continue;
}
// convert HBase cells to NiFi cells
final ResultCell[] resultCells = new ResultCell[cells.length];
for (int i=0; i < cells.length; i++) {
final Cell cell = cells[i];
final ResultCell resultCell = getResultCell(cell);
resultCells[i] = resultCell;
}
// delegate to the handler
handler.handle(rowKey, resultCells);
} }
}
try (final Table table = connection.getTable(TableName.valueOf(tableName));
final ResultScanner scanner = getResults(table, columns, filter, minTime, visibilityLabels)) {
for (final Result result : scanner) {
final byte[] rowKey = result.getRow();
final Cell[] cells = result.rawCells();
if (cells == null) {
continue;
}
// convert HBase cells to NiFi cells
final ResultCell[] resultCells = new ResultCell[cells.length];
for (int i = 0; i < cells.length; i++) {
final Cell cell = cells[i];
final ResultCell resultCell = getResultCell(cell);
resultCells[i] = resultCell;
}
// delegate to the handler
handler.handle(rowKey, resultCells);
}
}
return null;
});
} }
@Override @Override
public void scan(final String tableName, final byte[] startRow, final byte[] endRow, final Collection<Column> columns, List<String> authorizations, final ResultHandler handler) public void scan(final String tableName, final byte[] startRow, final byte[] endRow, final Collection<Column> columns, List<String> authorizations, final ResultHandler handler)
throws IOException { throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName)); SecurityUtil.callWithUgi(getUgi(), () -> {
final ResultScanner scanner = getResults(table, startRow, endRow, columns, authorizations)) { try (final Table table = connection.getTable(TableName.valueOf(tableName));
final ResultScanner scanner = getResults(table, startRow, endRow, columns, authorizations)) {
for (final Result result : scanner) { for (final Result result : scanner) {
final byte[] rowKey = result.getRow(); final byte[] rowKey = result.getRow();
final Cell[] cells = result.rawCells(); final Cell[] cells = result.rawCells();
if (cells == null) { if (cells == null) {
continue; continue;
}
// convert HBase cells to NiFi cells
final ResultCell[] resultCells = new ResultCell[cells.length];
for (int i = 0; i < cells.length; i++) {
final Cell cell = cells[i];
final ResultCell resultCell = getResultCell(cell);
resultCells[i] = resultCell;
}
// delegate to the handler
handler.handle(rowKey, resultCells);
} }
// convert HBase cells to NiFi cells
final ResultCell[] resultCells = new ResultCell[cells.length];
for (int i=0; i < cells.length; i++) {
final Cell cell = cells[i];
final ResultCell resultCell = getResultCell(cell);
resultCells[i] = resultCell;
}
// delegate to the handler
handler.handle(rowKey, resultCells);
} }
} return null;
});
} }
@Override @Override
@ -632,37 +650,40 @@ public class HBase_1_1_2_ClientService extends AbstractControllerService impleme
final Long timerangeMin, final Long timerangeMax, final Integer limitRows, final Boolean isReversed, final Long timerangeMin, final Long timerangeMax, final Integer limitRows, final Boolean isReversed,
final Boolean blockCache, final Collection<Column> columns, List<String> visibilityLabels, final ResultHandler handler) throws IOException { final Boolean blockCache, final Collection<Column> columns, List<String> visibilityLabels, final ResultHandler handler) throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName)); SecurityUtil.callWithUgi(getUgi(), () -> {
final ResultScanner scanner = getResults(table, startRow, endRow, filterExpression, timerangeMin, try (final Table table = connection.getTable(TableName.valueOf(tableName));
timerangeMax, limitRows, isReversed, blockCache, columns, visibilityLabels)) { final ResultScanner scanner = getResults(table, startRow, endRow, filterExpression, timerangeMin,
timerangeMax, limitRows, isReversed, blockCache, columns, visibilityLabels)) {
int cnt = 0; int cnt = 0;
final int lim = limitRows != null ? limitRows : 0; final int lim = limitRows != null ? limitRows : 0;
for (final Result result : scanner) { for (final Result result : scanner) {
if (lim > 0 && ++cnt > lim){ if (lim > 0 && ++cnt > lim) {
break; break;
}
final byte[] rowKey = result.getRow();
final Cell[] cells = result.rawCells();
if (cells == null) {
continue;
}
// convert HBase cells to NiFi cells
final ResultCell[] resultCells = new ResultCell[cells.length];
for (int i = 0; i < cells.length; i++) {
final Cell cell = cells[i];
final ResultCell resultCell = getResultCell(cell);
resultCells[i] = resultCell;
}
// delegate to the handler
handler.handle(rowKey, resultCells);
} }
final byte[] rowKey = result.getRow();
final Cell[] cells = result.rawCells();
if (cells == null) {
continue;
}
// convert HBase cells to NiFi cells
final ResultCell[] resultCells = new ResultCell[cells.length];
for (int i = 0; i < cells.length; i++) {
final Cell cell = cells[i];
final ResultCell resultCell = getResultCell(cell);
resultCells[i] = resultCell;
}
// delegate to the handler
handler.handle(rowKey, resultCells);
} }
} return null;
});
} }
// //
@ -868,20 +889,7 @@ public class HBase_1_1_2_ClientService extends AbstractControllerService impleme
} }
UserGroupInformation getUgi() { UserGroupInformation getUgi() {
getLogger().trace("getting UGI instance"); SecurityUtil.checkTGTAndRelogin(getLogger(), kerberosUserReference.get());
if (kerberosUserReference.get() != null) {
// if there's a KerberosUser associated with this UGI, check the TGT and relogin if it is close to expiring
KerberosUser kerberosUser = kerberosUserReference.get();
getLogger().debug("kerberosUser is " + kerberosUser);
try {
getLogger().debug("checking TGT on kerberosUser [{}]", new Object[] {kerberosUser});
kerberosUser.checkTGTAndRelogin();
} catch (LoginException e) {
throw new ProcessException("Unable to relogin with kerberos credentials for " + kerberosUser.getPrincipal(), e);
}
} else {
getLogger().debug("kerberosUser was null, will not refresh TGT with KerberosUser");
}
return ugi; return ugi;
} }

View File

@ -23,6 +23,7 @@ import org.apache.hadoop.hbase.client.Result;
import org.apache.hadoop.hbase.client.ResultScanner; import org.apache.hadoop.hbase.client.ResultScanner;
import org.apache.hadoop.hbase.client.Table; import org.apache.hadoop.hbase.client.Table;
import org.apache.hadoop.hbase.filter.Filter; import org.apache.hadoop.hbase.filter.Filter;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.nifi.controller.ConfigurationContext; import org.apache.nifi.controller.ConfigurationContext;
import org.apache.nifi.hadoop.KerberosProperties; import org.apache.nifi.hadoop.KerberosProperties;
import org.apache.nifi.hbase.put.PutColumn; import org.apache.nifi.hbase.put.PutColumn;
@ -33,6 +34,7 @@ import org.mockito.Mockito;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
@ -40,6 +42,9 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
/** /**
@ -52,6 +57,19 @@ public class MockHBaseClientService extends HBase_1_1_2_ClientService {
private Map<String, Result> results = new HashMap<>(); private Map<String, Result> results = new HashMap<>();
private KerberosProperties kerberosProperties; private KerberosProperties kerberosProperties;
private boolean allowExplicitKeytab; private boolean allowExplicitKeytab;
private UserGroupInformation mockUgi;
{
mockUgi = mock(UserGroupInformation.class);
try {
doAnswer(invocation -> {
PrivilegedExceptionAction<?> action = invocation.getArgument(0);
return action.run();
}).when(mockUgi).doAs(any(PrivilegedExceptionAction.class));
} catch (Exception e) {
throw new RuntimeException(e);
}
}
public MockHBaseClientService(final Table table, final String family, final KerberosProperties kerberosProperties) { public MockHBaseClientService(final Table table, final String family, final KerberosProperties kerberosProperties) {
this(table, family, kerberosProperties, false); this(table, family, kerberosProperties, false);
@ -209,4 +227,9 @@ public class MockHBaseClientService extends HBase_1_1_2_ClientService {
boolean isAllowExplicitKeytab() { boolean isAllowExplicitKeytab() {
return allowExplicitKeytab; return allowExplicitKeytab;
} }
@Override
UserGroupInformation getUgi() {
return mockUgi;
}
} }

View File

@ -62,7 +62,6 @@ import org.apache.nifi.hbase.scan.ResultCell;
import org.apache.nifi.hbase.scan.ResultHandler; import org.apache.nifi.hbase.scan.ResultHandler;
import org.apache.nifi.hbase.validate.ConfigFilesValidator; import org.apache.nifi.hbase.validate.ConfigFilesValidator;
import org.apache.nifi.kerberos.KerberosCredentialsService; import org.apache.nifi.kerberos.KerberosCredentialsService;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processor.util.StandardValidators; import org.apache.nifi.processor.util.StandardValidators;
import org.apache.nifi.reporting.InitializationException; import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.security.krb.KerberosKeytabUser; import org.apache.nifi.security.krb.KerberosKeytabUser;
@ -71,7 +70,6 @@ import org.apache.nifi.security.krb.KerberosUser;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import javax.security.auth.login.LoginException;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
@ -460,47 +458,56 @@ public class HBase_2_ClientService extends AbstractControllerService implements
@Override @Override
public void put(final String tableName, final Collection<PutFlowFile> puts) throws IOException { public void put(final String tableName, final Collection<PutFlowFile> puts) throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) { SecurityUtil.callWithUgi(getUgi(), () -> {
// Create one Put per row.... try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
final Map<String, List<PutColumn>> sorted = new HashMap<>(); // Create one Put per row....
final List<Put> newPuts = new ArrayList<>(); final Map<String, List<PutColumn>> sorted = new HashMap<>();
final List<Put> newPuts = new ArrayList<>();
for (final PutFlowFile putFlowFile : puts) { for (final PutFlowFile putFlowFile : puts) {
final String rowKeyString = new String(putFlowFile.getRow(), StandardCharsets.UTF_8); final String rowKeyString = new String(putFlowFile.getRow(), StandardCharsets.UTF_8);
List<PutColumn> columns = sorted.get(rowKeyString); List<PutColumn> columns = sorted.get(rowKeyString);
if (columns == null) { if (columns == null) {
columns = new ArrayList<>(); columns = new ArrayList<>();
sorted.put(rowKeyString, columns); sorted.put(rowKeyString, columns);
}
columns.addAll(putFlowFile.getColumns());
} }
columns.addAll(putFlowFile.getColumns()); for (final Map.Entry<String, List<PutColumn>> entry : sorted.entrySet()) {
newPuts.addAll(buildPuts(entry.getKey().getBytes(StandardCharsets.UTF_8), entry.getValue()));
}
table.put(newPuts);
} }
for (final Map.Entry<String, List<PutColumn>> entry : sorted.entrySet()) { return null;
newPuts.addAll(buildPuts(entry.getKey().getBytes(StandardCharsets.UTF_8), entry.getValue())); });
}
table.put(newPuts);
}
} }
@Override @Override
public void put(final String tableName, final byte[] rowId, final Collection<PutColumn> columns) throws IOException { public void put(final String tableName, final byte[] rowId, final Collection<PutColumn> columns) throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) { SecurityUtil.callWithUgi(getUgi(), () -> {
table.put(buildPuts(rowId, new ArrayList(columns))); try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
} table.put(buildPuts(rowId, new ArrayList(columns)));
}
return null;
});
} }
@Override @Override
public boolean checkAndPut(final String tableName, final byte[] rowId, final byte[] family, final byte[] qualifier, final byte[] value, final PutColumn column) throws IOException { public boolean checkAndPut(final String tableName, final byte[] rowId, final byte[] family, final byte[] qualifier, final byte[] value, final PutColumn column) throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) { return SecurityUtil.callWithUgi(getUgi(), () -> {
Put put = new Put(rowId); try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
put.addColumn( Put put = new Put(rowId);
column.getColumnFamily(), put.addColumn(
column.getColumnQualifier(), column.getColumnFamily(),
column.getBuffer()); column.getColumnQualifier(),
return table.checkAndPut(rowId, family, qualifier, value, put); column.getBuffer());
} return table.checkAndPut(rowId, family, qualifier, value, put);
}
});
} }
@Override @Override
@ -510,13 +517,16 @@ public class HBase_2_ClientService extends AbstractControllerService implements
@Override @Override
public void delete(String tableName, byte[] rowId, String visibilityLabel) throws IOException { public void delete(String tableName, byte[] rowId, String visibilityLabel) throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) { SecurityUtil.callWithUgi(getUgi(), () -> {
Delete delete = new Delete(rowId); try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
if (!StringUtils.isEmpty(visibilityLabel)) { Delete delete = new Delete(rowId);
delete.setCellVisibility(new CellVisibility(visibilityLabel)); if (!StringUtils.isEmpty(visibilityLabel)) {
delete.setCellVisibility(new CellVisibility(visibilityLabel));
}
table.delete(delete);
} }
table.delete(delete); return null;
} });
} }
@Override @Override
@ -553,9 +563,12 @@ public class HBase_2_ClientService extends AbstractControllerService implements
} }
private void batchDelete(String tableName, List<Delete> deletes) throws IOException { private void batchDelete(String tableName, List<Delete> deletes) throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName))) { SecurityUtil.callWithUgi(getUgi(), () -> {
table.delete(deletes); try (final Table table = connection.getTable(TableName.valueOf(tableName))) {
} table.delete(deletes);
}
return null;
});
} }
@Override @Override
@ -566,64 +579,70 @@ public class HBase_2_ClientService extends AbstractControllerService implements
@Override @Override
public void scan(String tableName, Collection<Column> columns, String filterExpression, long minTime, List<String> visibilityLabels, ResultHandler handler) throws IOException { public void scan(String tableName, Collection<Column> columns, String filterExpression, long minTime, List<String> visibilityLabels, ResultHandler handler) throws IOException {
Filter filter = null; SecurityUtil.callWithUgi(getUgi(), () -> {
if (!StringUtils.isBlank(filterExpression)) { Filter filter = null;
ParseFilter parseFilter = new ParseFilter(); if (!StringUtils.isBlank(filterExpression)) {
filter = parseFilter.parseFilterString(filterExpression); ParseFilter parseFilter = new ParseFilter();
} filter = parseFilter.parseFilterString(filterExpression);
try (final Table table = connection.getTable(TableName.valueOf(tableName));
final ResultScanner scanner = getResults(table, columns, filter, minTime, visibilityLabels)) {
for (final Result result : scanner) {
final byte[] rowKey = result.getRow();
final Cell[] cells = result.rawCells();
if (cells == null) {
continue;
}
// convert HBase cells to NiFi cells
final ResultCell[] resultCells = new ResultCell[cells.length];
for (int i=0; i < cells.length; i++) {
final Cell cell = cells[i];
final ResultCell resultCell = getResultCell(cell);
resultCells[i] = resultCell;
}
// delegate to the handler
handler.handle(rowKey, resultCells);
} }
}
try (final Table table = connection.getTable(TableName.valueOf(tableName));
final ResultScanner scanner = getResults(table, columns, filter, minTime, visibilityLabels)) {
for (final Result result : scanner) {
final byte[] rowKey = result.getRow();
final Cell[] cells = result.rawCells();
if (cells == null) {
continue;
}
// convert HBase cells to NiFi cells
final ResultCell[] resultCells = new ResultCell[cells.length];
for (int i = 0; i < cells.length; i++) {
final Cell cell = cells[i];
final ResultCell resultCell = getResultCell(cell);
resultCells[i] = resultCell;
}
// delegate to the handler
handler.handle(rowKey, resultCells);
}
}
return null;
});
} }
@Override @Override
public void scan(final String tableName, final byte[] startRow, final byte[] endRow, final Collection<Column> columns, List<String> authorizations, final ResultHandler handler) public void scan(final String tableName, final byte[] startRow, final byte[] endRow, final Collection<Column> columns, List<String> authorizations, final ResultHandler handler)
throws IOException { throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName)); SecurityUtil.callWithUgi(getUgi(), () -> {
final ResultScanner scanner = getResults(table, startRow, endRow, columns, authorizations)) { try (final Table table = connection.getTable(TableName.valueOf(tableName));
final ResultScanner scanner = getResults(table, startRow, endRow, columns, authorizations)) {
for (final Result result : scanner) { for (final Result result : scanner) {
final byte[] rowKey = result.getRow(); final byte[] rowKey = result.getRow();
final Cell[] cells = result.rawCells(); final Cell[] cells = result.rawCells();
if (cells == null) { if (cells == null) {
continue; continue;
}
// convert HBase cells to NiFi cells
final ResultCell[] resultCells = new ResultCell[cells.length];
for (int i = 0; i < cells.length; i++) {
final Cell cell = cells[i];
final ResultCell resultCell = getResultCell(cell);
resultCells[i] = resultCell;
}
// delegate to the handler
handler.handle(rowKey, resultCells);
} }
// convert HBase cells to NiFi cells
final ResultCell[] resultCells = new ResultCell[cells.length];
for (int i=0; i < cells.length; i++) {
final Cell cell = cells[i];
final ResultCell resultCell = getResultCell(cell);
resultCells[i] = resultCell;
}
// delegate to the handler
handler.handle(rowKey, resultCells);
} }
} return null;
});
} }
@Override @Override
@ -631,37 +650,40 @@ public class HBase_2_ClientService extends AbstractControllerService implements
final Long timerangeMin, final Long timerangeMax, final Integer limitRows, final Boolean isReversed, final Long timerangeMin, final Long timerangeMax, final Integer limitRows, final Boolean isReversed,
final Boolean blockCache, final Collection<Column> columns, List<String> visibilityLabels, final ResultHandler handler) throws IOException { final Boolean blockCache, final Collection<Column> columns, List<String> visibilityLabels, final ResultHandler handler) throws IOException {
try (final Table table = connection.getTable(TableName.valueOf(tableName)); SecurityUtil.callWithUgi(getUgi(), () -> {
final ResultScanner scanner = getResults(table, startRow, endRow, filterExpression, timerangeMin, try (final Table table = connection.getTable(TableName.valueOf(tableName));
timerangeMax, limitRows, isReversed, blockCache, columns, visibilityLabels)) { final ResultScanner scanner = getResults(table, startRow, endRow, filterExpression, timerangeMin,
timerangeMax, limitRows, isReversed, blockCache, columns, visibilityLabels)) {
int cnt = 0; int cnt = 0;
final int lim = limitRows != null ? limitRows : 0; final int lim = limitRows != null ? limitRows : 0;
for (final Result result : scanner) { for (final Result result : scanner) {
if (lim > 0 && ++cnt > lim){ if (lim > 0 && ++cnt > lim) {
break; break;
}
final byte[] rowKey = result.getRow();
final Cell[] cells = result.rawCells();
if (cells == null) {
continue;
}
// convert HBase cells to NiFi cells
final ResultCell[] resultCells = new ResultCell[cells.length];
for (int i = 0; i < cells.length; i++) {
final Cell cell = cells[i];
final ResultCell resultCell = getResultCell(cell);
resultCells[i] = resultCell;
}
// delegate to the handler
handler.handle(rowKey, resultCells);
} }
final byte[] rowKey = result.getRow();
final Cell[] cells = result.rawCells();
if (cells == null) {
continue;
}
// convert HBase cells to NiFi cells
final ResultCell[] resultCells = new ResultCell[cells.length];
for (int i = 0; i < cells.length; i++) {
final Cell cell = cells[i];
final ResultCell resultCell = getResultCell(cell);
resultCells[i] = resultCell;
}
// delegate to the handler
handler.handle(rowKey, resultCells);
} }
} return null;
});
} }
// //
@ -866,22 +888,8 @@ public class HBase_2_ClientService extends AbstractControllerService implements
return Boolean.parseBoolean(System.getenv(ALLOW_EXPLICIT_KEYTAB)); return Boolean.parseBoolean(System.getenv(ALLOW_EXPLICIT_KEYTAB));
} }
UserGroupInformation getUgi() { UserGroupInformation getUgi() throws IOException {
getLogger().trace("getting UGI instance"); SecurityUtil.checkTGTAndRelogin(getLogger(), kerberosUserReference.get());
if (kerberosUserReference.get() != null) {
// if there's a KerberosUser associated with this UGI, check the TGT and relogin if it is close to expiring
KerberosUser kerberosUser = kerberosUserReference.get();
getLogger().debug("kerberosUser is " + kerberosUser);
try {
getLogger().debug("checking TGT on kerberosUser [{}]", new Object[] {kerberosUser});
kerberosUser.checkTGTAndRelogin();
} catch (LoginException e) {
throw new ProcessException("Unable to relogin with kerberos credentials for " + kerberosUser.getPrincipal(), e);
}
} else {
getLogger().debug("kerberosUser was null, will not refresh TGT with KerberosUser");
}
return ugi; return ugi;
} }
} }

View File

@ -23,6 +23,7 @@ import org.apache.hadoop.hbase.client.Result;
import org.apache.hadoop.hbase.client.ResultScanner; import org.apache.hadoop.hbase.client.ResultScanner;
import org.apache.hadoop.hbase.client.Table; import org.apache.hadoop.hbase.client.Table;
import org.apache.hadoop.hbase.filter.Filter; import org.apache.hadoop.hbase.filter.Filter;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.nifi.controller.ConfigurationContext; import org.apache.nifi.controller.ConfigurationContext;
import org.apache.nifi.hadoop.KerberosProperties; import org.apache.nifi.hadoop.KerberosProperties;
import org.apache.nifi.hbase.put.PutColumn; import org.apache.nifi.hbase.put.PutColumn;
@ -33,6 +34,7 @@ import org.mockito.Mockito;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
@ -40,6 +42,9 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
/** /**
@ -52,6 +57,19 @@ public class MockHBaseClientService extends HBase_2_ClientService {
private Map<String, Result> results = new HashMap<>(); private Map<String, Result> results = new HashMap<>();
private KerberosProperties kerberosProperties; private KerberosProperties kerberosProperties;
private boolean allowExplicitKeytab; private boolean allowExplicitKeytab;
private UserGroupInformation mockUgi;
{
mockUgi = mock(UserGroupInformation.class);
try {
doAnswer(invocation -> {
PrivilegedExceptionAction<?> action = invocation.getArgument(0);
return action.run();
}).when(mockUgi).doAs(any(PrivilegedExceptionAction.class));
} catch (Exception e) {
throw new RuntimeException(e);
}
}
public MockHBaseClientService(final Table table, final String family, final KerberosProperties kerberosProperties) { public MockHBaseClientService(final Table table, final String family, final KerberosProperties kerberosProperties) {
this(table, family, kerberosProperties, false); this(table, family, kerberosProperties, false);
@ -79,7 +97,7 @@ public class MockHBaseClientService extends HBase_2_ClientService {
final Cell[] cellArray = new Cell[cells.size()]; final Cell[] cellArray = new Cell[cells.size()];
int i = 0; int i = 0;
for (final Map.Entry<String, String> cellEntry : cells.entrySet()) { for (final Map.Entry<String, String> cellEntry : cells.entrySet()) {
final Cell cell = Mockito.mock(Cell.class); final Cell cell = mock(Cell.class);
when(cell.getRowArray()).thenReturn(rowArray); when(cell.getRowArray()).thenReturn(rowArray);
when(cell.getRowOffset()).thenReturn(0); when(cell.getRowOffset()).thenReturn(0);
when(cell.getRowLength()).thenReturn((short) rowArray.length); when(cell.getRowLength()).thenReturn((short) rowArray.length);
@ -106,7 +124,7 @@ public class MockHBaseClientService extends HBase_2_ClientService {
cellArray[i++] = cell; cellArray[i++] = cell;
} }
final Result result = Mockito.mock(Result.class); final Result result = mock(Result.class);
when(result.getRow()).thenReturn(rowArray); when(result.getRow()).thenReturn(rowArray);
when(result.rawCells()).thenReturn(cellArray); when(result.rawCells()).thenReturn(cellArray);
results.put(rowKey, result); results.put(rowKey, result);
@ -179,28 +197,28 @@ public class MockHBaseClientService extends HBase_2_ClientService {
} }
protected ResultScanner getResults(Table table, byte[] startRow, byte[] endRow, Collection<Column> columns, List<String> labels) throws IOException { protected ResultScanner getResults(Table table, byte[] startRow, byte[] endRow, Collection<Column> columns, List<String> labels) throws IOException {
final ResultScanner scanner = Mockito.mock(ResultScanner.class); final ResultScanner scanner = mock(ResultScanner.class);
Mockito.when(scanner.iterator()).thenReturn(results.values().iterator()); Mockito.when(scanner.iterator()).thenReturn(results.values().iterator());
return scanner; return scanner;
} }
@Override @Override
protected ResultScanner getResults(Table table, Collection<Column> columns, Filter filter, long minTime, List<String> labels) throws IOException { protected ResultScanner getResults(Table table, Collection<Column> columns, Filter filter, long minTime, List<String> labels) throws IOException {
final ResultScanner scanner = Mockito.mock(ResultScanner.class); final ResultScanner scanner = mock(ResultScanner.class);
Mockito.when(scanner.iterator()).thenReturn(results.values().iterator()); Mockito.when(scanner.iterator()).thenReturn(results.values().iterator());
return scanner; return scanner;
} }
protected ResultScanner getResults(final Table table, final String startRow, final String endRow, final String filterExpression, final Long timerangeMin, final Long timerangeMax, protected ResultScanner getResults(final Table table, final String startRow, final String endRow, final String filterExpression, final Long timerangeMin, final Long timerangeMax,
final Integer limitRows, final Boolean isReversed, final Collection<Column> columns) throws IOException { final Integer limitRows, final Boolean isReversed, final Collection<Column> columns) throws IOException {
final ResultScanner scanner = Mockito.mock(ResultScanner.class); final ResultScanner scanner = mock(ResultScanner.class);
Mockito.when(scanner.iterator()).thenReturn(results.values().iterator()); Mockito.when(scanner.iterator()).thenReturn(results.values().iterator());
return scanner; return scanner;
} }
@Override @Override
protected Connection createConnection(ConfigurationContext context) throws IOException { protected Connection createConnection(ConfigurationContext context) throws IOException {
Connection connection = Mockito.mock(Connection.class); Connection connection = mock(Connection.class);
Mockito.when(connection.getTable(table.getName())).thenReturn(table); Mockito.when(connection.getTable(table.getName())).thenReturn(table);
return connection; return connection;
} }
@ -209,4 +227,9 @@ public class MockHBaseClientService extends HBase_2_ClientService {
boolean isAllowExplicitKeytab() { boolean isAllowExplicitKeytab() {
return allowExplicitKeytab; return allowExplicitKeytab;
} }
@Override
UserGroupInformation getUgi() throws IOException {
return mockUgi;
}
} }