405 lines
10 KiB
Go
405 lines
10 KiB
Go
package arm
|
|
|
|
import (
|
|
crand "crypto/rand"
|
|
"crypto/rsa"
|
|
"encoding/base64"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
mrand "math/rand"
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/Azure/go-autorest/autorest/azure"
|
|
jwt "github.com/dgrijalva/jwt-go"
|
|
"github.com/hashicorp/packer/packer"
|
|
)
|
|
|
|
func Test_ClientConfig_RequiredParametersSet(t *testing.T) {
|
|
|
|
tests := []struct {
|
|
name string
|
|
config ClientConfig
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "no client_id, client_secret or subscription_id should enable MSI auth",
|
|
config: ClientConfig{},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "subscription_id is set will trigger device flow",
|
|
config: ClientConfig{
|
|
SubscriptionID: "error",
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "client_id without client_secret, client_cert_path or client_jwt should error",
|
|
config: ClientConfig{
|
|
ClientID: "error",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "client_secret without client_id should error",
|
|
config: ClientConfig{
|
|
ClientSecret: "error",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "client_cert_path without client_id should error",
|
|
config: ClientConfig{
|
|
ClientCertPath: "/dev/null",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "client_jwt without client_id should error",
|
|
config: ClientConfig{
|
|
ClientJWT: "error",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "missing subscription_id when using secret",
|
|
config: ClientConfig{
|
|
ClientID: "ok",
|
|
ClientSecret: "ok",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "missing subscription_id when using certificate",
|
|
config: ClientConfig{
|
|
ClientID: "ok",
|
|
ClientCertPath: "ok",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "missing subscription_id when using JWT",
|
|
config: ClientConfig{
|
|
ClientID: "ok",
|
|
ClientJWT: "ok",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "too many client_* values",
|
|
config: ClientConfig{
|
|
SubscriptionID: "ok",
|
|
ClientID: "ok",
|
|
ClientSecret: "ok",
|
|
ClientCertPath: "error",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "too many client_* values (2)",
|
|
config: ClientConfig{
|
|
SubscriptionID: "ok",
|
|
ClientID: "ok",
|
|
ClientSecret: "ok",
|
|
ClientJWT: "error",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "tenant_id alone should fail",
|
|
config: ClientConfig{
|
|
TenantID: "ok",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
errs := &packer.MultiError{}
|
|
tt.config.assertRequiredParametersSet(errs)
|
|
if (len(errs.Errors) != 0) != tt.wantErr {
|
|
t.Errorf("newConfig() error = %v, wantErr %v", errs, tt.wantErr)
|
|
return
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_ClientConfig_DeviceLogin(t *testing.T) {
|
|
getEnvOrSkip(t, "AZURE_DEVICE_LOGIN")
|
|
cfg := ClientConfig{
|
|
SubscriptionID: getEnvOrSkip(t, "AZURE_SUBSCRIPTION"),
|
|
cloudEnvironment: getCloud(),
|
|
}
|
|
assertValid(t, cfg)
|
|
|
|
spt, sptkv, err := cfg.getServicePrincipalTokens(
|
|
func(s string) { fmt.Printf("SAY: %s\n", s) })
|
|
if err != nil {
|
|
t.Fatalf("Expected nil err, but got: %v", err)
|
|
}
|
|
token := spt.Token()
|
|
if token.AccessToken == "" {
|
|
t.Fatal("Expected management token to have non-nil access token")
|
|
}
|
|
if token.RefreshToken == "" {
|
|
t.Fatal("Expected management token to have non-nil refresh token")
|
|
}
|
|
kvtoken := sptkv.Token()
|
|
if kvtoken.AccessToken == "" {
|
|
t.Fatal("Expected keyvault token to have non-nil access token")
|
|
}
|
|
if kvtoken.RefreshToken == "" {
|
|
t.Fatal("Expected keyvault token to have non-nil refresh token")
|
|
}
|
|
}
|
|
|
|
func Test_ClientConfig_ClientPassword(t *testing.T) {
|
|
cfg := ClientConfig{
|
|
SubscriptionID: getEnvOrSkip(t, "AZURE_SUBSCRIPTION"),
|
|
ClientID: getEnvOrSkip(t, "AZURE_CLIENTID"),
|
|
ClientSecret: getEnvOrSkip(t, "AZURE_CLIENTSECRET"),
|
|
TenantID: getEnvOrSkip(t, "AZURE_TENANTID"),
|
|
cloudEnvironment: getCloud(),
|
|
}
|
|
assertValid(t, cfg)
|
|
|
|
spt, sptkv, err := cfg.getServicePrincipalTokens(func(s string) { fmt.Printf("SAY: %s\n", s) })
|
|
if err != nil {
|
|
t.Fatalf("Expected nil err, but got: %v", err)
|
|
}
|
|
token := spt.Token()
|
|
if token.AccessToken == "" {
|
|
t.Fatal("Expected management token to have non-nil access token")
|
|
}
|
|
if token.RefreshToken != "" {
|
|
t.Fatal("Expected management token to have no refresh token")
|
|
}
|
|
kvtoken := sptkv.Token()
|
|
if kvtoken.AccessToken == "" {
|
|
t.Fatal("Expected keyvault token to have non-nil access token")
|
|
}
|
|
if kvtoken.RefreshToken != "" {
|
|
t.Fatal("Expected keyvault token to have no refresh token")
|
|
}
|
|
}
|
|
|
|
func Test_ClientConfig_ClientCert(t *testing.T) {
|
|
cfg := ClientConfig{
|
|
SubscriptionID: getEnvOrSkip(t, "AZURE_SUBSCRIPTION"),
|
|
ClientID: getEnvOrSkip(t, "AZURE_CLIENTID"),
|
|
ClientCertPath: getEnvOrSkip(t, "AZURE_CLIENTCERT"),
|
|
TenantID: getEnvOrSkip(t, "AZURE_TENANTID"),
|
|
cloudEnvironment: getCloud(),
|
|
}
|
|
assertValid(t, cfg)
|
|
|
|
spt, sptkv, err := cfg.getServicePrincipalTokens(func(s string) { fmt.Printf("SAY: %s\n", s) })
|
|
if err != nil {
|
|
t.Fatalf("Expected nil err, but got: %v", err)
|
|
}
|
|
token := spt.Token()
|
|
if token.AccessToken == "" {
|
|
t.Fatal("Expected management token to have non-nil access token")
|
|
}
|
|
if token.RefreshToken != "" {
|
|
t.Fatal("Expected management token to have no refresh token")
|
|
}
|
|
kvtoken := sptkv.Token()
|
|
if kvtoken.AccessToken == "" {
|
|
t.Fatal("Expected keyvault token to have non-nil access token")
|
|
}
|
|
if kvtoken.RefreshToken != "" {
|
|
t.Fatal("Expected keyvault token to have no refresh token")
|
|
}
|
|
}
|
|
|
|
func Test_ClientConfig_ClientJWT(t *testing.T) {
|
|
cfg := ClientConfig{
|
|
SubscriptionID: getEnvOrSkip(t, "AZURE_SUBSCRIPTION"),
|
|
ClientID: getEnvOrSkip(t, "AZURE_CLIENTID"),
|
|
ClientJWT: getEnvOrSkip(t, "AZURE_CLIENTJWT"),
|
|
TenantID: getEnvOrSkip(t, "AZURE_TENANTID"),
|
|
cloudEnvironment: getCloud(),
|
|
}
|
|
assertValid(t, cfg)
|
|
|
|
spt, sptkv, err := cfg.getServicePrincipalTokens(func(s string) { fmt.Printf("SAY: %s\n", s) })
|
|
if err != nil {
|
|
t.Fatalf("Expected nil err, but got: %v", err)
|
|
}
|
|
token := spt.Token()
|
|
if token.AccessToken == "" {
|
|
t.Fatal("Expected management token to have non-nil access token")
|
|
}
|
|
if token.RefreshToken != "" {
|
|
t.Fatal("Expected management token to have no refresh token")
|
|
}
|
|
kvtoken := sptkv.Token()
|
|
if kvtoken.AccessToken == "" {
|
|
t.Fatal("Expected keyvault token to have non-nil access token")
|
|
}
|
|
if kvtoken.RefreshToken != "" {
|
|
t.Fatal("Expected keyvault token to have no refresh token")
|
|
}
|
|
}
|
|
|
|
func getEnvOrSkip(t *testing.T, envVar string) string {
|
|
v := os.Getenv(envVar)
|
|
if v == "" {
|
|
t.Skipf("%s is empty, skipping", envVar)
|
|
}
|
|
return v
|
|
}
|
|
|
|
func getCloud() *azure.Environment {
|
|
cloudName := os.Getenv("AZURE_CLOUD")
|
|
if cloudName == "" {
|
|
cloudName = "AZUREPUBLICCLOUD"
|
|
}
|
|
c, _ := azure.EnvironmentFromName(cloudName)
|
|
return &c
|
|
}
|
|
|
|
// tests for assertRequiredParametersSet
|
|
|
|
func Test_ClientConfig_CanUseDeviceCode(t *testing.T) {
|
|
// TenantID is optional, but Builder will look up tenant ID before requesting
|
|
t.Run("without TenantID", func(t *testing.T) {
|
|
cfg := emptyClientConfig()
|
|
cfg.SubscriptionID = "12345"
|
|
assertValid(t, cfg)
|
|
})
|
|
t.Run("with TenantID", func(t *testing.T) {
|
|
cfg := emptyClientConfig()
|
|
cfg.SubscriptionID = "12345"
|
|
cfg.TenantID = "12345"
|
|
assertValid(t, cfg)
|
|
})
|
|
}
|
|
|
|
func assertValid(t *testing.T, cfg ClientConfig) {
|
|
errs := &packer.MultiError{}
|
|
cfg.assertRequiredParametersSet(errs)
|
|
if len(errs.Errors) != 0 {
|
|
t.Fatal("Expected errs to be empty: ", errs)
|
|
}
|
|
}
|
|
|
|
func assertInvalid(t *testing.T, cfg ClientConfig) {
|
|
errs := &packer.MultiError{}
|
|
cfg.assertRequiredParametersSet(errs)
|
|
if len(errs.Errors) == 0 {
|
|
t.Fatal("Expected errs to be non-empty")
|
|
}
|
|
}
|
|
|
|
func Test_ClientConfig_CanUseClientSecret(t *testing.T) {
|
|
cfg := emptyClientConfig()
|
|
cfg.SubscriptionID = "12345"
|
|
cfg.ClientID = "12345"
|
|
cfg.ClientSecret = "12345"
|
|
|
|
assertValid(t, cfg)
|
|
}
|
|
|
|
func Test_ClientConfig_CanUseClientSecretWithTenantID(t *testing.T) {
|
|
cfg := emptyClientConfig()
|
|
cfg.SubscriptionID = "12345"
|
|
cfg.ClientID = "12345"
|
|
cfg.ClientSecret = "12345"
|
|
cfg.TenantID = "12345"
|
|
|
|
assertValid(t, cfg)
|
|
}
|
|
|
|
func Test_ClientConfig_CanUseClientJWT(t *testing.T) {
|
|
cfg := emptyClientConfig()
|
|
cfg.SubscriptionID = "12345"
|
|
cfg.ClientID = "12345"
|
|
cfg.ClientJWT = getJWT(10*time.Minute, true)
|
|
|
|
assertValid(t, cfg)
|
|
}
|
|
|
|
func Test_ClientConfig_CanUseClientJWTWithTenantID(t *testing.T) {
|
|
cfg := emptyClientConfig()
|
|
cfg.SubscriptionID = "12345"
|
|
cfg.ClientID = "12345"
|
|
cfg.ClientJWT = getJWT(10*time.Minute, true)
|
|
cfg.TenantID = "12345"
|
|
|
|
assertValid(t, cfg)
|
|
}
|
|
|
|
func Test_ClientConfig_CannotUseBothClientJWTAndSecret(t *testing.T) {
|
|
cfg := emptyClientConfig()
|
|
cfg.SubscriptionID = "12345"
|
|
cfg.ClientID = "12345"
|
|
cfg.ClientSecret = "12345"
|
|
cfg.ClientJWT = getJWT(10*time.Minute, true)
|
|
|
|
assertInvalid(t, cfg)
|
|
}
|
|
|
|
func Test_ClientConfig_ClientJWTShouldBeValidForAtLeast5Minutes(t *testing.T) {
|
|
cfg := emptyClientConfig()
|
|
cfg.SubscriptionID = "12345"
|
|
cfg.ClientID = "12345"
|
|
cfg.ClientJWT = getJWT(time.Minute, true)
|
|
|
|
assertInvalid(t, cfg)
|
|
}
|
|
|
|
func Test_ClientConfig_ClientJWTShouldHaveThumbprint(t *testing.T) {
|
|
cfg := emptyClientConfig()
|
|
cfg.SubscriptionID = "12345"
|
|
cfg.ClientID = "12345"
|
|
cfg.ClientJWT = getJWT(10*time.Minute, false)
|
|
|
|
assertInvalid(t, cfg)
|
|
}
|
|
|
|
func emptyClientConfig() ClientConfig {
|
|
cfg := ClientConfig{}
|
|
_ = cfg.setCloudEnvironment()
|
|
return cfg
|
|
}
|
|
|
|
func Test_getJWT(t *testing.T) {
|
|
if getJWT(time.Minute, true) == "" {
|
|
t.Fatalf("getJWT is broken")
|
|
}
|
|
}
|
|
|
|
func newRandReader() io.Reader {
|
|
var seed int64
|
|
binary.Read(crand.Reader, binary.LittleEndian, &seed)
|
|
|
|
return mrand.New(mrand.NewSource(seed))
|
|
}
|
|
|
|
func getJWT(validFor time.Duration, withX5tHeader bool) string {
|
|
token := jwt.New(jwt.SigningMethodRS256)
|
|
key, _ := rsa.GenerateKey(newRandReader(), 2048)
|
|
|
|
token.Claims = jwt.MapClaims{
|
|
"aud": "https://login.microsoftonline.com/tenant.onmicrosoft.com/oauth2/token?api-version=1.0",
|
|
"iss": "355dff10-cd78-11e8-89fe-000d3afd16e3",
|
|
"sub": "355dff10-cd78-11e8-89fe-000d3afd16e3",
|
|
"jti": base64.URLEncoding.EncodeToString([]byte{0}),
|
|
"nbf": time.Now().Unix(),
|
|
"exp": time.Now().Add(validFor).Unix(),
|
|
}
|
|
if withX5tHeader {
|
|
token.Header["x5t"] = base64.URLEncoding.EncodeToString([]byte("thumbprint"))
|
|
}
|
|
|
|
jwt, _ := token.SignedString(key)
|
|
return jwt
|
|
}
|