HLRC changes for kerberos grant type (#43642) (#43822)

The TODO from last PR for kerbero grant type was missed.
This commit adds the changes for kerberos grant type in HLRC.
This commit is contained in:
Yogesh Gaikwad 2019-07-02 00:55:02 +10:00 committed by GitHub
parent 1e47ea5f18
commit 031d5e96ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 197 additions and 56 deletions

View File

@ -40,6 +40,7 @@ public final class CreateTokenRequest implements Validatable, ToXContentObject {
private final String username;
private final char[] password;
private final String refreshToken;
private final char[] kerberosTicket;
/**
* General purpose constructor. This constructor is typically not useful, and one of the following factory methods should be used
@ -48,10 +49,11 @@ public final class CreateTokenRequest implements Validatable, ToXContentObject {
* <li>{@link #passwordGrant(String, char[])}</li>
* <li>{@link #refreshTokenGrant(String)}</li>
* <li>{@link #clientCredentialsGrant()}</li>
* <li>{@link #kerberosGrant(char[])}</li>
* </ul>
*/
public CreateTokenRequest(String grantType, @Nullable String scope, @Nullable String username, @Nullable char[] password,
@Nullable String refreshToken) {
@Nullable String refreshToken, @Nullable char[] kerberosTicket) {
if (Strings.isNullOrEmpty(grantType)) {
throw new IllegalArgumentException("grant_type is required");
}
@ -60,6 +62,7 @@ public final class CreateTokenRequest implements Validatable, ToXContentObject {
this.password = password;
this.scope = scope;
this.refreshToken = refreshToken;
this.kerberosTicket = kerberosTicket;
}
public static CreateTokenRequest passwordGrant(String username, char[] password) {
@ -69,18 +72,25 @@ public final class CreateTokenRequest implements Validatable, ToXContentObject {
if (password == null || password.length == 0) {
throw new IllegalArgumentException("password is required");
}
return new CreateTokenRequest("password", null, username, password, null);
return new CreateTokenRequest("password", null, username, password, null, null);
}
public static CreateTokenRequest refreshTokenGrant(String refreshToken) {
if (Strings.isNullOrEmpty(refreshToken)) {
throw new IllegalArgumentException("refresh_token is required");
}
return new CreateTokenRequest("refresh_token", null, null, null, refreshToken);
return new CreateTokenRequest("refresh_token", null, null, null, refreshToken, null);
}
public static CreateTokenRequest clientCredentialsGrant() {
return new CreateTokenRequest("client_credentials", null, null, null, null);
return new CreateTokenRequest("client_credentials", null, null, null, null, null);
}
public static CreateTokenRequest kerberosGrant(char[] kerberosTicket) {
if (kerberosTicket == null || kerberosTicket.length == 0) {
throw new IllegalArgumentException("kerberos ticket is required");
}
return new CreateTokenRequest("_kerberos", null, null, null, null, kerberosTicket);
}
public String getGrantType() {
@ -103,6 +113,10 @@ public final class CreateTokenRequest implements Validatable, ToXContentObject {
return refreshToken;
}
public char[] getKerberosTicket() {
return kerberosTicket;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject()
@ -124,6 +138,14 @@ public final class CreateTokenRequest implements Validatable, ToXContentObject {
if (refreshToken != null) {
builder.field("refresh_token", refreshToken);
}
if (kerberosTicket != null) {
byte[] kerberosTicketBytes = CharArrays.toUtf8Bytes(kerberosTicket);
try {
builder.field("kerberos_ticket").utf8Value(kerberosTicketBytes, 0, kerberosTicketBytes.length);
} finally {
Arrays.fill(kerberosTicketBytes, (byte) 0);
}
}
return builder.endObject();
}
@ -140,13 +162,15 @@ public final class CreateTokenRequest implements Validatable, ToXContentObject {
Objects.equals(scope, that.scope) &&
Objects.equals(username, that.username) &&
Arrays.equals(password, that.password) &&
Objects.equals(refreshToken, that.refreshToken);
Objects.equals(refreshToken, that.refreshToken) &&
Arrays.equals(kerberosTicket, that.kerberosTicket);
}
@Override
public int hashCode() {
int result = Objects.hash(grantType, scope, username, refreshToken);
result = 31 * result + Arrays.hashCode(password);
result = 31 * result + Arrays.hashCode(kerberosTicket);
return result;
}
}

View File

@ -66,31 +66,54 @@ public class CreateTokenRequestTests extends ESTestCase {
assertThat(Strings.toString(request), equalTo("{\"grant_type\":\"client_credentials\"}"));
}
public void testCreateTokenFromKerberosTicket() {
final CreateTokenRequest request = CreateTokenRequest.kerberosGrant("top secret kerberos ticket".toCharArray());
assertThat(request.getGrantType(), equalTo("_kerberos"));
assertThat(request.getScope(), nullValue());
assertThat(request.getUsername(), nullValue());
assertThat(request.getPassword(), nullValue());
assertThat(request.getRefreshToken(), nullValue());
assertThat(new String(request.getKerberosTicket()), equalTo("top secret kerberos ticket"));
assertThat(Strings.toString(request), equalTo("{\"grant_type\":\"_kerberos\"," +
"\"kerberos_ticket\":\"top secret kerberos ticket\"}"));
}
public void testEqualsAndHashCode() {
final String grantType = randomAlphaOfLength(8);
final String scope = randomBoolean() ? null : randomAlphaOfLength(6);
final String username = randomBoolean() ? null : randomAlphaOfLengthBetween(4, 10);
final char[] password = randomBoolean() ? null : randomAlphaOfLengthBetween(8, 12).toCharArray();
final String refreshToken = randomBoolean() ? null : randomAlphaOfLengthBetween(12, 24);
final CreateTokenRequest request = new CreateTokenRequest(grantType, scope, username, password, refreshToken);
final char[] kerberosTicket = randomBoolean() ? null : randomAlphaOfLengthBetween(8, 12).toCharArray();
final CreateTokenRequest request = new CreateTokenRequest(grantType, scope, username, password, refreshToken, kerberosTicket);
EqualsHashCodeTestUtils.checkEqualsAndHashCode(request,
r -> new CreateTokenRequest(r.getGrantType(), r.getScope(), r.getUsername(), r.getPassword(), r.getRefreshToken()),
r -> new CreateTokenRequest(r.getGrantType(), r.getScope(), r.getUsername(), r.getPassword(),
r.getRefreshToken(), r.getKerberosTicket()),
this::mutate);
}
private CreateTokenRequest mutate(CreateTokenRequest req) {
switch (randomIntBetween(1, 5)) {
case 1:
return new CreateTokenRequest("g", req.getScope(), req.getUsername(), req.getPassword(), req.getRefreshToken());
case 2:
return new CreateTokenRequest(req.getGrantType(), "s", req.getUsername(), req.getPassword(), req.getRefreshToken());
case 3:
return new CreateTokenRequest(req.getGrantType(), req.getScope(), "u", req.getPassword(), req.getRefreshToken());
case 4:
final char[] password = {'p'};
return new CreateTokenRequest(req.getGrantType(), req.getScope(), req.getUsername(), password, req.getRefreshToken());
case 5:
return new CreateTokenRequest(req.getGrantType(), req.getScope(), req.getUsername(), req.getPassword(), "r");
switch (randomIntBetween(1, 6)) {
case 1:
return new CreateTokenRequest("g", req.getScope(), req.getUsername(), req.getPassword(), req.getRefreshToken(),
req.getKerberosTicket());
case 2:
return new CreateTokenRequest(req.getGrantType(), "s", req.getUsername(), req.getPassword(), req.getRefreshToken(),
req.getKerberosTicket());
case 3:
return new CreateTokenRequest(req.getGrantType(), req.getScope(), "u", req.getPassword(), req.getRefreshToken(),
req.getKerberosTicket());
case 4:
final char[] password = { 'p' };
return new CreateTokenRequest(req.getGrantType(), req.getScope(), req.getUsername(), password, req.getRefreshToken(),
req.getKerberosTicket());
case 5:
final char[] kerberosTicket = { 'k' };
return new CreateTokenRequest(req.getGrantType(), req.getScope(), req.getUsername(), req.getPassword(), req.getRefreshToken(),
kerberosTicket);
case 6:
return new CreateTokenRequest(req.getGrantType(), req.getScope(), req.getUsername(), req.getPassword(), "r",
req.getKerberosTicket());
}
throw new IllegalStateException("Bad random number");
}

View File

@ -30,11 +30,13 @@ import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.CharsRef;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.common.CharArrays;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.geo.GeoPoint;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
@ -59,6 +61,7 @@ import java.time.Instant;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
@ -358,6 +361,21 @@ public abstract class StreamInput extends InputStream {
return null;
}
@Nullable
public SecureString readOptionalSecureString() throws IOException {
SecureString value = null;
BytesReference bytesRef = readOptionalBytesReference();
if (bytesRef != null) {
byte[] bytes = BytesReference.toBytes(bytesRef);
try {
value = new SecureString(CharArrays.utf8BytesToChars(bytes));
} finally {
Arrays.fill(bytes, (byte) 0);
}
}
return value;
}
@Nullable
public Float readOptionalFloat() throws IOException {
if (readBoolean()) {
@ -415,6 +433,16 @@ public abstract class StreamInput extends InputStream {
return spare.toString();
}
public SecureString readSecureString() throws IOException {
BytesReference bytesRef = readBytesReference();
byte[] bytes = BytesReference.toBytes(bytesRef);
try {
return new SecureString(CharArrays.utf8BytesToChars(bytes));
} finally {
Arrays.fill(bytes, (byte) 0);
}
}
public final float readFloat() throws IOException {
return Float.intBitsToFloat(readInt());
}

View File

@ -32,10 +32,13 @@ import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.common.CharArrays;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.geo.GeoPoint;
import org.elasticsearch.common.io.stream.Writeable.Writer;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
@ -58,6 +61,7 @@ import java.nio.file.NotDirectoryException;
import java.time.ZoneId;
import java.time.Instant;
import java.time.ZonedDateTime;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
@ -324,6 +328,19 @@ public abstract class StreamOutput extends OutputStream {
}
}
public void writeOptionalSecureString(@Nullable SecureString secureStr) throws IOException {
if (secureStr == null) {
writeOptionalBytesReference(null);
} else {
final byte[] secureStrBytes = CharArrays.toUtf8Bytes(secureStr.getChars());
try {
writeOptionalBytesReference(new BytesArray(secureStrBytes));
} finally {
Arrays.fill(secureStrBytes, (byte) 0);
}
}
}
/**
* Writes an optional {@link Integer}.
*/
@ -414,6 +431,15 @@ public abstract class StreamOutput extends OutputStream {
writeBytes(buffer, offset);
}
public void writeSecureString(SecureString secureStr) throws IOException {
final byte[] secureStrBytes = CharArrays.toUtf8Bytes(secureStr.getChars());
try {
writeBytesReference(new BytesArray(secureStrBytes));
} finally {
Arrays.fill(secureStrBytes, (byte) 0);
}
}
public void writeFloat(float v) throws IOException {
writeInt(Float.floatToIntBits(v));
}

View File

@ -25,6 +25,7 @@ import org.elasticsearch.common.CheckedFunction;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.test.ESTestCase;
import java.io.ByteArrayInputStream;
@ -49,7 +50,9 @@ import java.util.stream.IntStream;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasToString;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.iterableWithSize;
import static org.hamcrest.Matchers.nullValue;
public class StreamTests extends ESTestCase {
@ -405,4 +408,30 @@ public class StreamTests extends ESTestCase {
}
}
public void testSecureStringSerialization() throws IOException {
try (BytesStreamOutput output = new BytesStreamOutput()) {
final SecureString secureString = new SecureString("super secret".toCharArray());
output.writeSecureString(secureString);
final BytesReference bytesReference = output.bytes();
final StreamInput input = bytesReference.streamInput();
assertThat(secureString, is(equalTo(input.readSecureString())));
}
try (BytesStreamOutput output = new BytesStreamOutput()) {
final SecureString secureString = randomBoolean() ? null : new SecureString("super secret".toCharArray());
output.writeOptionalSecureString(secureString);
final BytesReference bytesReference = output.bytes();
final StreamInput input = bytesReference.streamInput();
if (secureString != null) {
assertThat(input.readOptionalSecureString(), is(equalTo(secureString)));
} else {
assertThat(input.readOptionalSecureString(), is(nullValue()));
}
}
}
}

View File

@ -8,17 +8,13 @@ package org.elasticsearch.xpack.core.security.action.token;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.common.CharArrays;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.SecureString;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumSet;
import java.util.Locale;
@ -214,32 +210,23 @@ public final class CreateTokenRequest extends ActionRequest {
throw new IllegalArgumentException("a request with the client_credentials grant_type cannot be sent to version [" +
out.getVersion() + "]");
}
if (out.getVersion().before(Version.V_7_3_0) && GrantType.KERBEROS.getValue().equals(grantType)) {
throw new IllegalArgumentException("a request with the _kerberos grant_type cannot be sent to version [" +
out.getVersion() + "]");
}
out.writeString(grantType);
if (out.getVersion().onOrAfter(Version.V_6_2_0)) {
out.writeOptionalString(username);
if (password == null) {
out.writeOptionalBytesReference(null);
} else {
final byte[] passwordBytes = CharArrays.toUtf8Bytes(password.getChars());
try {
out.writeOptionalBytesReference(new BytesArray(passwordBytes));
} finally {
Arrays.fill(passwordBytes, (byte) 0);
}
}
out.writeOptionalSecureString(password);
out.writeOptionalString(refreshToken);
out.writeOptionalSecureString(kerberosTicket);
} else {
if ("refresh_token".equals(grantType)) {
throw new IllegalArgumentException("a refresh request cannot be sent to an older version");
} else {
out.writeString(username);
final byte[] passwordBytes = CharArrays.toUtf8Bytes(password.getChars());
try {
out.writeByteArray(passwordBytes);
} finally {
Arrays.fill(passwordBytes, (byte) 0);
}
out.writeSecureString(password);
}
}
out.writeOptionalString(scope);
@ -251,26 +238,12 @@ public final class CreateTokenRequest extends ActionRequest {
grantType = in.readString();
if (in.getVersion().onOrAfter(Version.V_6_2_0)) {
username = in.readOptionalString();
BytesReference bytesRef = in.readOptionalBytesReference();
if (bytesRef != null) {
byte[] bytes = BytesReference.toBytes(bytesRef);
try {
password = new SecureString(CharArrays.utf8BytesToChars(bytes));
} finally {
Arrays.fill(bytes, (byte) 0);
}
} else {
password = null;
}
password = in.readOptionalSecureString();
refreshToken = in.readOptionalString();
kerberosTicket = in.readOptionalSecureString();
} else {
username = in.readString();
final byte[] passwordBytes = in.readByteArray();
try {
password = new SecureString(CharArrays.utf8BytesToChars(passwordBytes));
} finally {
Arrays.fill(passwordBytes, (byte) 0);
}
password = in.readSecureString();
}
scope = in.readOptionalString();
}

View File

@ -6,8 +6,15 @@
package org.elasticsearch.xpack.core.security.action.token;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.security.action.token.CreateTokenRequest.GrantType;
import java.io.IOException;
import java.util.Arrays;
import java.util.stream.Collectors;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.hasItem;
@ -106,4 +113,35 @@ public class CreateTokenRequestTests extends ESTestCase {
ve = request.validate();
assertNull(ve);
}
public void testSerialization() throws IOException {
final String grantType = randomFrom(Arrays.stream(GrantType.values()).map(gt -> gt.getValue()).collect(Collectors.toList()));
final String username = randomBoolean() ? randomAlphaOfLength(5) : null;
final String scope = randomBoolean() ? randomAlphaOfLength(5) : null;
final SecureString password = randomBoolean() ? new SecureString(new char[] { 'p', 'a', 's', 's' }) : null;
final SecureString kerberosTicket = randomBoolean() ? new SecureString(new char[] { 'k', 'e', 'r', 'b' }) : null;
final String refreshToken = randomBoolean() ? randomAlphaOfLength(5) : null;
final CreateTokenRequest request = new CreateTokenRequest(grantType, username, password, kerberosTicket, scope, refreshToken);
try (BytesStreamOutput out = new BytesStreamOutput()) {
request.writeTo(out);
try (StreamInput in = out.bytes().streamInput()) {
final CreateTokenRequest serialized = new CreateTokenRequest();
serialized.readFrom(in);
assertEquals(grantType, serialized.getGrantType());
if (scope != null) {
assertEquals(scope, serialized.getScope());
}
if (password != null) {
assertEquals(password, serialized.getPassword());
}
if (kerberosTicket != null) {
assertEquals(kerberosTicket, serialized.getKerberosTicket());
}
if (refreshToken != null) {
assertEquals(refreshToken, serialized.getRefreshToken());
}
}
}
}
}