68516fc05c
Builder looks up tenant ID before asking for token. Client config did not allow that. Also found that token provider was not properly initialized. Fixes 7267
405 lines
10 KiB
Go
405 lines
10 KiB
Go
package arm
|
|
|
|
import (
|
|
crand "crypto/rand"
|
|
"crypto/rsa"
|
|
"encoding/base64"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
mrand "math/rand"
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/Azure/go-autorest/autorest/azure"
|
|
jwt "github.com/dgrijalva/jwt-go"
|
|
"github.com/hashicorp/packer/packer"
|
|
)
|
|
|
|
func Test_ClientConfig_RequiredParametersSet(t *testing.T) {
|
|
|
|
tests := []struct {
|
|
name string
|
|
config ClientConfig
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "no client_id, client_secret or subscription_id should enable MSI auth",
|
|
config: ClientConfig{},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "subscription_id is set will trigger device flow",
|
|
config: ClientConfig{
|
|
SubscriptionID: "error",
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "client_id without client_secret, client_cert_path or client_jwt should error",
|
|
config: ClientConfig{
|
|
ClientID: "error",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "client_secret without client_id should error",
|
|
config: ClientConfig{
|
|
ClientSecret: "error",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "client_cert_path without client_id should error",
|
|
config: ClientConfig{
|
|
ClientCertPath: "/dev/null",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "client_jwt without client_id should error",
|
|
config: ClientConfig{
|
|
ClientJWT: "error",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "missing subscription_id when using secret",
|
|
config: ClientConfig{
|
|
ClientID: "ok",
|
|
ClientSecret: "ok",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "missing subscription_id when using certificate",
|
|
config: ClientConfig{
|
|
ClientID: "ok",
|
|
ClientCertPath: "ok",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "missing subscription_id when using JWT",
|
|
config: ClientConfig{
|
|
ClientID: "ok",
|
|
ClientJWT: "ok",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "too many client_* values",
|
|
config: ClientConfig{
|
|
SubscriptionID: "ok",
|
|
ClientID: "ok",
|
|
ClientSecret: "ok",
|
|
ClientCertPath: "error",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "too many client_* values (2)",
|
|
config: ClientConfig{
|
|
SubscriptionID: "ok",
|
|
ClientID: "ok",
|
|
ClientSecret: "ok",
|
|
ClientJWT: "error",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "tenant_id alone should fail",
|
|
config: ClientConfig{
|
|
TenantID: "ok",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
errs := &packer.MultiError{}
|
|
tt.config.assertRequiredParametersSet(errs)
|
|
if (len(errs.Errors) != 0) != tt.wantErr {
|
|
t.Errorf("newConfig() error = %v, wantErr %v", errs, tt.wantErr)
|
|
return
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_ClientConfig_DeviceLogin(t *testing.T) {
|
|
getEnvOrSkip(t, "AZURE_DEVICE_LOGIN")
|
|
cfg := ClientConfig{
|
|
SubscriptionID: getEnvOrSkip(t, "AZURE_SUBSCRIPTION"),
|
|
cloudEnvironment: getCloud(),
|
|
}
|
|
assertValid(t, cfg)
|
|
|
|
spt, sptkv, err := cfg.getServicePrincipalTokens(
|
|
func(s string) { fmt.Printf("SAY: %s\n", s) })
|
|
if err != nil {
|
|
t.Fatalf("Expected nil err, but got: %v", err)
|
|
}
|
|
token := spt.Token()
|
|
if token.AccessToken == "" {
|
|
t.Fatal("Expected management token to have non-nil access token")
|
|
}
|
|
if token.RefreshToken == "" {
|
|
t.Fatal("Expected management token to have non-nil refresh token")
|
|
}
|
|
kvtoken := sptkv.Token()
|
|
if kvtoken.AccessToken == "" {
|
|
t.Fatal("Expected keyvault token to have non-nil access token")
|
|
}
|
|
if kvtoken.RefreshToken == "" {
|
|
t.Fatal("Expected keyvault token to have non-nil refresh token")
|
|
}
|
|
}
|
|
|
|
func Test_ClientConfig_ClientPassword(t *testing.T) {
|
|
cfg := ClientConfig{
|
|
SubscriptionID: getEnvOrSkip(t, "AZURE_SUBSCRIPTION"),
|
|
ClientID: getEnvOrSkip(t, "AZURE_CLIENTID"),
|
|
ClientSecret: getEnvOrSkip(t, "AZURE_CLIENTSECRET"),
|
|
TenantID: getEnvOrSkip(t, "AZURE_TENANTID"),
|
|
cloudEnvironment: getCloud(),
|
|
}
|
|
assertValid(t, cfg)
|
|
|
|
spt, sptkv, err := cfg.getServicePrincipalTokens(func(s string) { fmt.Printf("SAY: %s\n", s) })
|
|
if err != nil {
|
|
t.Fatalf("Expected nil err, but got: %v", err)
|
|
}
|
|
token := spt.Token()
|
|
if token.AccessToken == "" {
|
|
t.Fatal("Expected management token to have non-nil access token")
|
|
}
|
|
if token.RefreshToken != "" {
|
|
t.Fatal("Expected management token to have no refresh token")
|
|
}
|
|
kvtoken := sptkv.Token()
|
|
if kvtoken.AccessToken == "" {
|
|
t.Fatal("Expected keyvault token to have non-nil access token")
|
|
}
|
|
if kvtoken.RefreshToken != "" {
|
|
t.Fatal("Expected keyvault token to have no refresh token")
|
|
}
|
|
}
|
|
|
|
func Test_ClientConfig_ClientCert(t *testing.T) {
|
|
cfg := ClientConfig{
|
|
SubscriptionID: getEnvOrSkip(t, "AZURE_SUBSCRIPTION"),
|
|
ClientID: getEnvOrSkip(t, "AZURE_CLIENTID"),
|
|
ClientCertPath: getEnvOrSkip(t, "AZURE_CLIENTCERT"),
|
|
TenantID: getEnvOrSkip(t, "AZURE_TENANTID"),
|
|
cloudEnvironment: getCloud(),
|
|
}
|
|
assertValid(t, cfg)
|
|
|
|
spt, sptkv, err := cfg.getServicePrincipalTokens(func(s string) { fmt.Printf("SAY: %s\n", s) })
|
|
if err != nil {
|
|
t.Fatalf("Expected nil err, but got: %v", err)
|
|
}
|
|
token := spt.Token()
|
|
if token.AccessToken == "" {
|
|
t.Fatal("Expected management token to have non-nil access token")
|
|
}
|
|
if token.RefreshToken != "" {
|
|
t.Fatal("Expected management token to have no refresh token")
|
|
}
|
|
kvtoken := sptkv.Token()
|
|
if kvtoken.AccessToken == "" {
|
|
t.Fatal("Expected keyvault token to have non-nil access token")
|
|
}
|
|
if kvtoken.RefreshToken != "" {
|
|
t.Fatal("Expected keyvault token to have no refresh token")
|
|
}
|
|
}
|
|
|
|
func Test_ClientConfig_ClientJWT(t *testing.T) {
|
|
cfg := ClientConfig{
|
|
SubscriptionID: getEnvOrSkip(t, "AZURE_SUBSCRIPTION"),
|
|
ClientID: getEnvOrSkip(t, "AZURE_CLIENTID"),
|
|
ClientJWT: getEnvOrSkip(t, "AZURE_CLIENTJWT"),
|
|
TenantID: getEnvOrSkip(t, "AZURE_TENANTID"),
|
|
cloudEnvironment: getCloud(),
|
|
}
|
|
assertValid(t, cfg)
|
|
|
|
spt, sptkv, err := cfg.getServicePrincipalTokens(func(s string) { fmt.Printf("SAY: %s\n", s) })
|
|
if err != nil {
|
|
t.Fatalf("Expected nil err, but got: %v", err)
|
|
}
|
|
token := spt.Token()
|
|
if token.AccessToken == "" {
|
|
t.Fatal("Expected management token to have non-nil access token")
|
|
}
|
|
if token.RefreshToken != "" {
|
|
t.Fatal("Expected management token to have no refresh token")
|
|
}
|
|
kvtoken := sptkv.Token()
|
|
if kvtoken.AccessToken == "" {
|
|
t.Fatal("Expected keyvault token to have non-nil access token")
|
|
}
|
|
if kvtoken.RefreshToken != "" {
|
|
t.Fatal("Expected keyvault token to have no refresh token")
|
|
}
|
|
}
|
|
|
|
// getEnvOrSkip returns the value of the environment variable envVar. When the
// variable is unset or empty, the calling test is skipped with a message
// naming the variable.
func getEnvOrSkip(t *testing.T, envVar string) string {
	if value := os.Getenv(envVar); value != "" {
		return value
	}
	t.Skipf("%s is empty, skipping", envVar)
	// Skipf stops the test goroutine; this return only satisfies the compiler.
	return ""
}
|
|
|
|
func getCloud() *azure.Environment {
|
|
cloudName := os.Getenv("AZURE_CLOUD")
|
|
if cloudName == "" {
|
|
cloudName = "AZUREPUBLICCLOUD"
|
|
}
|
|
c, _ := azure.EnvironmentFromName(cloudName)
|
|
return &c
|
|
}
|
|
|
|
// tests for assertRequiredParametersSet
|
|
|
|
func Test_ClientConfig_CanUseDeviceCode(t *testing.T) {
|
|
// TenantID is optional, but Builder will look up tenant ID before requesting
|
|
t.Run("without TenantID", func(t *testing.T) {
|
|
cfg := emptyClientConfig()
|
|
cfg.SubscriptionID = "12345"
|
|
assertValid(t, cfg)
|
|
})
|
|
t.Run("with TenantID", func(t *testing.T) {
|
|
cfg := emptyClientConfig()
|
|
cfg.SubscriptionID = "12345"
|
|
cfg.TenantID = "12345"
|
|
assertValid(t, cfg)
|
|
})
|
|
}
|
|
|
|
func assertValid(t *testing.T, cfg ClientConfig) {
|
|
errs := &packer.MultiError{}
|
|
cfg.assertRequiredParametersSet(errs)
|
|
if len(errs.Errors) != 0 {
|
|
t.Fatal("Expected errs to be empty: ", errs)
|
|
}
|
|
}
|
|
|
|
func assertInvalid(t *testing.T, cfg ClientConfig) {
|
|
errs := &packer.MultiError{}
|
|
cfg.assertRequiredParametersSet(errs)
|
|
if len(errs.Errors) == 0 {
|
|
t.Fatal("Expected errs to be non-empty")
|
|
}
|
|
}
|
|
|
|
func Test_ClientConfig_CanUseClientSecret(t *testing.T) {
|
|
cfg := emptyClientConfig()
|
|
cfg.SubscriptionID = "12345"
|
|
cfg.ClientID = "12345"
|
|
cfg.ClientSecret = "12345"
|
|
|
|
assertValid(t, cfg)
|
|
}
|
|
|
|
func Test_ClientConfig_CanUseClientSecretWithTenantID(t *testing.T) {
|
|
cfg := emptyClientConfig()
|
|
cfg.SubscriptionID = "12345"
|
|
cfg.ClientID = "12345"
|
|
cfg.ClientSecret = "12345"
|
|
cfg.TenantID = "12345"
|
|
|
|
assertValid(t, cfg)
|
|
}
|
|
|
|
func Test_ClientConfig_CanUseClientJWT(t *testing.T) {
|
|
cfg := emptyClientConfig()
|
|
cfg.SubscriptionID = "12345"
|
|
cfg.ClientID = "12345"
|
|
cfg.ClientJWT = getJWT(10*time.Minute, true)
|
|
|
|
assertValid(t, cfg)
|
|
}
|
|
|
|
func Test_ClientConfig_CanUseClientJWTWithTenantID(t *testing.T) {
|
|
cfg := emptyClientConfig()
|
|
cfg.SubscriptionID = "12345"
|
|
cfg.ClientID = "12345"
|
|
cfg.ClientJWT = getJWT(10*time.Minute, true)
|
|
cfg.TenantID = "12345"
|
|
|
|
assertValid(t, cfg)
|
|
}
|
|
|
|
func Test_ClientConfig_CannotUseBothClientJWTAndSecret(t *testing.T) {
|
|
cfg := emptyClientConfig()
|
|
cfg.SubscriptionID = "12345"
|
|
cfg.ClientID = "12345"
|
|
cfg.ClientSecret = "12345"
|
|
cfg.ClientJWT = getJWT(10*time.Minute, true)
|
|
|
|
assertInvalid(t, cfg)
|
|
}
|
|
|
|
func Test_ClientConfig_ClientJWTShouldBeValidForAtLeast5Minutes(t *testing.T) {
|
|
cfg := emptyClientConfig()
|
|
cfg.SubscriptionID = "12345"
|
|
cfg.ClientID = "12345"
|
|
cfg.ClientJWT = getJWT(time.Minute, true)
|
|
|
|
assertInvalid(t, cfg)
|
|
}
|
|
|
|
func Test_ClientConfig_ClientJWTShouldHaveThumbprint(t *testing.T) {
|
|
cfg := emptyClientConfig()
|
|
cfg.SubscriptionID = "12345"
|
|
cfg.ClientID = "12345"
|
|
cfg.ClientJWT = getJWT(10*time.Minute, false)
|
|
|
|
assertInvalid(t, cfg)
|
|
}
|
|
|
|
func emptyClientConfig() ClientConfig {
|
|
cfg := ClientConfig{}
|
|
_ = cfg.setCloudEnvironment()
|
|
return cfg
|
|
}
|
|
|
|
func Test_getJWT(t *testing.T) {
|
|
if getJWT(time.Minute, true) == "" {
|
|
t.Fatalf("getJWT is broken")
|
|
}
|
|
}
|
|
|
|
// newRandReader returns an io.Reader backed by a math/rand generator whose
// seed is drawn from crypto/rand.
//
// If the seed cannot be read from crypto/rand, the current time in nanoseconds
// is used instead. Previously that error was dropped, which left the seed at 0
// and made every reader produce the same stream.
func newRandReader() io.Reader {
	var seed int64
	if err := binary.Read(crand.Reader, binary.LittleEndian, &seed); err != nil {
		seed = time.Now().UnixNano()
	}

	return mrand.New(mrand.NewSource(seed))
}
|
|
|
|
func getJWT(validFor time.Duration, withX5tHeader bool) string {
|
|
token := jwt.New(jwt.SigningMethodRS256)
|
|
key, _ := rsa.GenerateKey(newRandReader(), 2048)
|
|
|
|
token.Claims = jwt.MapClaims{
|
|
"aud": "https://login.microsoftonline.com/tenant.onmicrosoft.com/oauth2/token?api-version=1.0",
|
|
"iss": "355dff10-cd78-11e8-89fe-000d3afd16e3",
|
|
"sub": "355dff10-cd78-11e8-89fe-000d3afd16e3",
|
|
"jti": base64.URLEncoding.EncodeToString([]byte{0}),
|
|
"nbf": time.Now().Unix(),
|
|
"exp": time.Now().Add(validFor).Unix(),
|
|
}
|
|
if withX5tHeader {
|
|
token.Header["x5t"] = base64.URLEncoding.EncodeToString([]byte("thumbprint"))
|
|
}
|
|
|
|
jwt, _ := token.SignedString(key)
|
|
return jwt
|
|
}
|