azure: upgrade Azure/go-autorest to v10.12.0

This commit is contained in:
Christopher Boumenot 2018-07-11 14:17:56 -07:00
parent b8de835a91
commit f60921ad4b
20 changed files with 1353 additions and 560 deletions

View File

@ -26,10 +26,10 @@ const (
// OAuthConfig represents the endpoints needed // OAuthConfig represents the endpoints needed
// in OAuth operations // in OAuth operations
type OAuthConfig struct { type OAuthConfig struct {
AuthorityEndpoint url.URL AuthorityEndpoint url.URL `json:"authorityEndpoint"`
AuthorizeEndpoint url.URL AuthorizeEndpoint url.URL `json:"authorizeEndpoint"`
TokenEndpoint url.URL TokenEndpoint url.URL `json:"tokenEndpoint"`
DeviceCodeEndpoint url.URL DeviceCodeEndpoint url.URL `json:"deviceCodeEndpoint"`
} }
// IsZero returns true if the OAuthConfig object is zero-initialized. // IsZero returns true if the OAuthConfig object is zero-initialized.

View File

@ -1,20 +0,0 @@
// +build !windows
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// msiPath is the path to the MSI Extension settings file (to discover the endpoint)
var msiPath = "/var/lib/waagent/ManagedIdentity-Settings"

View File

@ -1,25 +0,0 @@
// +build windows
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"os"
"strings"
)
// msiPath is the path to the MSI Extension settings file (to discover the endpoint)
var msiPath = strings.Join([]string{os.Getenv("SystemDrive"), "WindowsAzure/Config/ManagedIdentity-Settings"}, "/")

View File

@ -15,14 +15,18 @@ package adal
// limitations under the License. // limitations under the License.
import ( import (
"context"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/sha1" "crypto/sha1"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"math"
"net"
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
@ -54,6 +58,12 @@ const (
// metadataHeader is the header required by MSI extension // metadataHeader is the header required by MSI extension
metadataHeader = "Metadata" metadataHeader = "Metadata"
// msiEndpoint is the well known endpoint for getting MSI authentications tokens
msiEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
// the default number of attempts to refresh an MSI authentication token
defaultMaxMSIRefreshAttempts = 5
) )
// OAuthTokenProvider is an interface which should be implemented by an access token retriever // OAuthTokenProvider is an interface which should be implemented by an access token retriever
@ -74,6 +84,13 @@ type Refresher interface {
EnsureFresh() error EnsureFresh() error
} }
// RefresherWithContext is an interface for token refresh functionality
type RefresherWithContext interface {
RefreshWithContext(ctx context.Context) error
RefreshExchangeWithContext(ctx context.Context, resource string) error
EnsureFreshWithContext(ctx context.Context) error
}
// TokenRefreshCallback is the type representing callbacks that will be called after // TokenRefreshCallback is the type representing callbacks that will be called after
// a successful token refresh // a successful token refresh
type TokenRefreshCallback func(Token) error type TokenRefreshCallback func(Token) error
@ -124,6 +141,12 @@ func (t *Token) OAuthToken() string {
return t.AccessToken return t.AccessToken
} }
// ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form
// that is submitted when acquiring an oAuth token.
type ServicePrincipalSecret interface {
SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error
}
// ServicePrincipalNoSecret represents a secret type that contains no secret // ServicePrincipalNoSecret represents a secret type that contains no secret
// meaning it is not valid for fetching a fresh token. This is used by Manual // meaning it is not valid for fetching a fresh token. This is used by Manual
type ServicePrincipalNoSecret struct { type ServicePrincipalNoSecret struct {
@ -135,15 +158,19 @@ func (noSecret *ServicePrincipalNoSecret) SetAuthenticationValues(spt *ServicePr
return fmt.Errorf("Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token") return fmt.Errorf("Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token")
} }
// ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form // MarshalJSON implements the json.Marshaler interface.
// that is submitted when acquiring an oAuth token. func (noSecret ServicePrincipalNoSecret) MarshalJSON() ([]byte, error) {
type ServicePrincipalSecret interface { type tokenType struct {
SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error Type string `json:"type"`
}
return json.Marshal(tokenType{
Type: "ServicePrincipalNoSecret",
})
} }
// ServicePrincipalTokenSecret implements ServicePrincipalSecret for client_secret type authorization. // ServicePrincipalTokenSecret implements ServicePrincipalSecret for client_secret type authorization.
type ServicePrincipalTokenSecret struct { type ServicePrincipalTokenSecret struct {
ClientSecret string ClientSecret string `json:"value"`
} }
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
@ -153,49 +180,24 @@ func (tokenSecret *ServicePrincipalTokenSecret) SetAuthenticationValues(spt *Ser
return nil return nil
} }
// MarshalJSON implements the json.Marshaler interface.
func (tokenSecret ServicePrincipalTokenSecret) MarshalJSON() ([]byte, error) {
type tokenType struct {
Type string `json:"type"`
Value string `json:"value"`
}
return json.Marshal(tokenType{
Type: "ServicePrincipalTokenSecret",
Value: tokenSecret.ClientSecret,
})
}
// ServicePrincipalCertificateSecret implements ServicePrincipalSecret for generic RSA cert auth with signed JWTs. // ServicePrincipalCertificateSecret implements ServicePrincipalSecret for generic RSA cert auth with signed JWTs.
type ServicePrincipalCertificateSecret struct { type ServicePrincipalCertificateSecret struct {
Certificate *x509.Certificate Certificate *x509.Certificate
PrivateKey *rsa.PrivateKey PrivateKey *rsa.PrivateKey
} }
// ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension.
type ServicePrincipalMSISecret struct {
}
// ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth.
type ServicePrincipalUsernamePasswordSecret struct {
Username string
Password string
}
// ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth.
type ServicePrincipalAuthorizationCodeSecret struct {
ClientSecret string
AuthorizationCode string
RedirectURI string
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
v.Set("code", secret.AuthorizationCode)
v.Set("client_secret", secret.ClientSecret)
v.Set("redirect_uri", secret.RedirectURI)
return nil
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
v.Set("username", secret.Username)
v.Set("password", secret.Password)
return nil
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
return nil
}
// SignJwt returns the JWT signed with the certificate's private key. // SignJwt returns the JWT signed with the certificate's private key.
func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) { func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) {
hasher := sha1.New() hasher := sha1.New()
@ -216,9 +218,9 @@ func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalTo
token := jwt.New(jwt.SigningMethodRS256) token := jwt.New(jwt.SigningMethodRS256)
token.Header["x5t"] = thumbprint token.Header["x5t"] = thumbprint
token.Claims = jwt.MapClaims{ token.Claims = jwt.MapClaims{
"aud": spt.oauthConfig.TokenEndpoint.String(), "aud": spt.inner.OauthConfig.TokenEndpoint.String(),
"iss": spt.clientID, "iss": spt.inner.ClientID,
"sub": spt.clientID, "sub": spt.inner.ClientID,
"jti": base64.URLEncoding.EncodeToString(jti), "jti": base64.URLEncoding.EncodeToString(jti),
"nbf": time.Now().Unix(), "nbf": time.Now().Unix(),
"exp": time.Now().Add(time.Hour * 24).Unix(), "exp": time.Now().Add(time.Hour * 24).Unix(),
@ -241,19 +243,151 @@ func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *Se
return nil return nil
} }
// MarshalJSON implements the json.Marshaler interface.
func (secret ServicePrincipalCertificateSecret) MarshalJSON() ([]byte, error) {
return nil, errors.New("marshalling ServicePrincipalCertificateSecret is not supported")
}
// ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension.
type ServicePrincipalMSISecret struct {
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
return nil
}
// MarshalJSON implements the json.Marshaler interface.
func (msiSecret ServicePrincipalMSISecret) MarshalJSON() ([]byte, error) {
return nil, errors.New("marshalling ServicePrincipalMSISecret is not supported")
}
// ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth.
type ServicePrincipalUsernamePasswordSecret struct {
Username string `json:"username"`
Password string `json:"password"`
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
v.Set("username", secret.Username)
v.Set("password", secret.Password)
return nil
}
// MarshalJSON implements the json.Marshaler interface.
func (secret ServicePrincipalUsernamePasswordSecret) MarshalJSON() ([]byte, error) {
type tokenType struct {
Type string `json:"type"`
Username string `json:"username"`
Password string `json:"password"`
}
return json.Marshal(tokenType{
Type: "ServicePrincipalUsernamePasswordSecret",
Username: secret.Username,
Password: secret.Password,
})
}
// ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth.
type ServicePrincipalAuthorizationCodeSecret struct {
ClientSecret string `json:"value"`
AuthorizationCode string `json:"authCode"`
RedirectURI string `json:"redirect"`
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
v.Set("code", secret.AuthorizationCode)
v.Set("client_secret", secret.ClientSecret)
v.Set("redirect_uri", secret.RedirectURI)
return nil
}
// MarshalJSON implements the json.Marshaler interface.
func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, error) {
type tokenType struct {
Type string `json:"type"`
Value string `json:"value"`
AuthCode string `json:"authCode"`
Redirect string `json:"redirect"`
}
return json.Marshal(tokenType{
Type: "ServicePrincipalAuthorizationCodeSecret",
Value: secret.ClientSecret,
AuthCode: secret.AuthorizationCode,
Redirect: secret.RedirectURI,
})
}
// ServicePrincipalToken encapsulates a Token created for a Service Principal. // ServicePrincipalToken encapsulates a Token created for a Service Principal.
type ServicePrincipalToken struct { type ServicePrincipalToken struct {
token Token inner servicePrincipalToken
secret ServicePrincipalSecret refreshLock *sync.RWMutex
oauthConfig OAuthConfig sender Sender
clientID string
resource string
autoRefresh bool
refreshLock *sync.RWMutex
refreshWithin time.Duration
sender Sender
refreshCallbacks []TokenRefreshCallback refreshCallbacks []TokenRefreshCallback
// MaxMSIRefreshAttempts is the maximum number of attempts to refresh an MSI token.
MaxMSIRefreshAttempts int
}
// MarshalTokenJSON returns the marshalled inner token.
func (spt ServicePrincipalToken) MarshalTokenJSON() ([]byte, error) {
return json.Marshal(spt.inner.Token)
}
// SetRefreshCallbacks replaces any existing refresh callbacks with the specified callbacks.
func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCallback) {
spt.refreshCallbacks = callbacks
}
// MarshalJSON implements the json.Marshaler interface.
func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) {
return json.Marshal(spt.inner)
}
// UnmarshalJSON implements the json.Unmarshaler interface.
func (spt *ServicePrincipalToken) UnmarshalJSON(data []byte) error {
// need to determine the token type
raw := map[string]interface{}{}
err := json.Unmarshal(data, &raw)
if err != nil {
return err
}
secret := raw["secret"].(map[string]interface{})
switch secret["type"] {
case "ServicePrincipalNoSecret":
spt.inner.Secret = &ServicePrincipalNoSecret{}
case "ServicePrincipalTokenSecret":
spt.inner.Secret = &ServicePrincipalTokenSecret{}
case "ServicePrincipalCertificateSecret":
return errors.New("unmarshalling ServicePrincipalCertificateSecret is not supported")
case "ServicePrincipalMSISecret":
return errors.New("unmarshalling ServicePrincipalMSISecret is not supported")
case "ServicePrincipalUsernamePasswordSecret":
spt.inner.Secret = &ServicePrincipalUsernamePasswordSecret{}
case "ServicePrincipalAuthorizationCodeSecret":
spt.inner.Secret = &ServicePrincipalAuthorizationCodeSecret{}
default:
return fmt.Errorf("unrecognized token type '%s'", secret["type"])
}
err = json.Unmarshal(data, &spt.inner)
if err != nil {
return err
}
spt.refreshLock = &sync.RWMutex{}
spt.sender = &http.Client{}
return nil
}
// internal type used for marshalling/unmarshalling
type servicePrincipalToken struct {
Token Token `json:"token"`
Secret ServicePrincipalSecret `json:"secret"`
OauthConfig OAuthConfig `json:"oauth"`
ClientID string `json:"clientID"`
Resource string `json:"resource"`
AutoRefresh bool `json:"autoRefresh"`
RefreshWithin time.Duration `json:"refreshWithin"`
} }
func validateOAuthConfig(oac OAuthConfig) error { func validateOAuthConfig(oac OAuthConfig) error {
@ -278,13 +412,15 @@ func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, reso
return nil, fmt.Errorf("parameter 'secret' cannot be nil") return nil, fmt.Errorf("parameter 'secret' cannot be nil")
} }
spt := &ServicePrincipalToken{ spt := &ServicePrincipalToken{
oauthConfig: oauthConfig, inner: servicePrincipalToken{
secret: secret, OauthConfig: oauthConfig,
clientID: id, Secret: secret,
resource: resource, ClientID: id,
autoRefresh: true, Resource: resource,
AutoRefresh: true,
RefreshWithin: defaultRefresh,
},
refreshLock: &sync.RWMutex{}, refreshLock: &sync.RWMutex{},
refreshWithin: defaultRefresh,
sender: &http.Client{}, sender: &http.Client{},
refreshCallbacks: callbacks, refreshCallbacks: callbacks,
} }
@ -315,7 +451,39 @@ func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID s
return nil, err return nil, err
} }
spt.token = token spt.inner.Token = token
return spt, nil
}
// NewServicePrincipalTokenFromManualTokenSecret creates a ServicePrincipalToken using the supplied token and secret
func NewServicePrincipalTokenFromManualTokenSecret(oauthConfig OAuthConfig, clientID string, resource string, token Token, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
if secret == nil {
return nil, fmt.Errorf("parameter 'secret' cannot be nil")
}
if token.IsZero() {
return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
}
spt, err := NewServicePrincipalTokenWithSecret(
oauthConfig,
clientID,
resource,
secret,
callbacks...)
if err != nil {
return nil, err
}
spt.inner.Token = token
return spt, nil return spt, nil
} }
@ -441,24 +609,7 @@ func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clie
// GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines. // GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines.
func GetMSIVMEndpoint() (string, error) { func GetMSIVMEndpoint() (string, error) {
return getMSIVMEndpoint(msiPath) return msiEndpoint, nil
}
func getMSIVMEndpoint(path string) (string, error) {
// Read MSI settings
bytes, err := ioutil.ReadFile(path)
if err != nil {
return "", err
}
msiSettings := struct {
URL string `json:"url"`
}{}
err = json.Unmarshal(bytes, &msiSettings)
if err != nil {
return "", err
}
return msiSettings.URL, nil
} }
// NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension. // NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension.
@ -491,24 +642,32 @@ func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedI
return nil, err return nil, err
} }
oauthConfig, err := NewOAuthConfig(msiEndpointURL.String(), "") v := url.Values{}
if err != nil { v.Set("resource", resource)
return nil, err v.Set("api-version", "2018-02-01")
if userAssignedID != nil {
v.Set("client_id", *userAssignedID)
} }
msiEndpointURL.RawQuery = v.Encode()
spt := &ServicePrincipalToken{ spt := &ServicePrincipalToken{
oauthConfig: *oauthConfig, inner: servicePrincipalToken{
secret: &ServicePrincipalMSISecret{}, OauthConfig: OAuthConfig{
resource: resource, TokenEndpoint: *msiEndpointURL,
autoRefresh: true, },
refreshLock: &sync.RWMutex{}, Secret: &ServicePrincipalMSISecret{},
refreshWithin: defaultRefresh, Resource: resource,
sender: &http.Client{}, AutoRefresh: true,
refreshCallbacks: callbacks, RefreshWithin: defaultRefresh,
},
refreshLock: &sync.RWMutex{},
sender: &http.Client{},
refreshCallbacks: callbacks,
MaxMSIRefreshAttempts: defaultMaxMSIRefreshAttempts,
} }
if userAssignedID != nil { if userAssignedID != nil {
spt.clientID = *userAssignedID spt.inner.ClientID = *userAssignedID
} }
return spt, nil return spt, nil
@ -537,12 +696,18 @@ func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError
// EnsureFresh will refresh the token if it will expire within the refresh window (as set by // EnsureFresh will refresh the token if it will expire within the refresh window (as set by
// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use. // RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use.
func (spt *ServicePrincipalToken) EnsureFresh() error { func (spt *ServicePrincipalToken) EnsureFresh() error {
if spt.autoRefresh && spt.token.WillExpireIn(spt.refreshWithin) { return spt.EnsureFreshWithContext(context.Background())
}
// EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by
// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use.
func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error {
if spt.inner.AutoRefresh && spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) {
// take the write lock then check to see if the token was already refreshed // take the write lock then check to see if the token was already refreshed
spt.refreshLock.Lock() spt.refreshLock.Lock()
defer spt.refreshLock.Unlock() defer spt.refreshLock.Unlock()
if spt.token.WillExpireIn(spt.refreshWithin) { if spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) {
return spt.refreshInternal(spt.resource) return spt.refreshInternal(ctx, spt.inner.Resource)
} }
} }
return nil return nil
@ -552,7 +717,7 @@ func (spt *ServicePrincipalToken) EnsureFresh() error {
func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error { func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
if spt.refreshCallbacks != nil { if spt.refreshCallbacks != nil {
for _, callback := range spt.refreshCallbacks { for _, callback := range spt.refreshCallbacks {
err := callback(spt.token) err := callback(spt.inner.Token)
if err != nil { if err != nil {
return fmt.Errorf("adal: TokenRefreshCallback handler failed. Error = '%v'", err) return fmt.Errorf("adal: TokenRefreshCallback handler failed. Error = '%v'", err)
} }
@ -564,21 +729,33 @@ func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
// Refresh obtains a fresh token for the Service Principal. // Refresh obtains a fresh token for the Service Principal.
// This method is not safe for concurrent use and should be syncrhonized. // This method is not safe for concurrent use and should be syncrhonized.
func (spt *ServicePrincipalToken) Refresh() error { func (spt *ServicePrincipalToken) Refresh() error {
return spt.RefreshWithContext(context.Background())
}
// RefreshWithContext obtains a fresh token for the Service Principal.
// This method is not safe for concurrent use and should be syncrhonized.
func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error {
spt.refreshLock.Lock() spt.refreshLock.Lock()
defer spt.refreshLock.Unlock() defer spt.refreshLock.Unlock()
return spt.refreshInternal(spt.resource) return spt.refreshInternal(ctx, spt.inner.Resource)
} }
// RefreshExchange refreshes the token, but for a different resource. // RefreshExchange refreshes the token, but for a different resource.
// This method is not safe for concurrent use and should be syncrhonized. // This method is not safe for concurrent use and should be syncrhonized.
func (spt *ServicePrincipalToken) RefreshExchange(resource string) error { func (spt *ServicePrincipalToken) RefreshExchange(resource string) error {
return spt.RefreshExchangeWithContext(context.Background(), resource)
}
// RefreshExchangeWithContext refreshes the token, but for a different resource.
// This method is not safe for concurrent use and should be syncrhonized.
func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error {
spt.refreshLock.Lock() spt.refreshLock.Lock()
defer spt.refreshLock.Unlock() defer spt.refreshLock.Unlock()
return spt.refreshInternal(resource) return spt.refreshInternal(ctx, resource)
} }
func (spt *ServicePrincipalToken) getGrantType() string { func (spt *ServicePrincipalToken) getGrantType() string {
switch spt.secret.(type) { switch spt.inner.Secret.(type) {
case *ServicePrincipalUsernamePasswordSecret: case *ServicePrincipalUsernamePasswordSecret:
return OAuthGrantTypeUserPass return OAuthGrantTypeUserPass
case *ServicePrincipalAuthorizationCodeSecret: case *ServicePrincipalAuthorizationCodeSecret:
@ -588,37 +765,64 @@ func (spt *ServicePrincipalToken) getGrantType() string {
} }
} }
func (spt *ServicePrincipalToken) refreshInternal(resource string) error { func isIMDS(u url.URL) bool {
v := url.Values{} imds, err := url.Parse(msiEndpoint)
v.Set("client_id", spt.clientID) if err != nil {
v.Set("resource", resource) return false
if spt.token.RefreshToken != "" {
v.Set("grant_type", OAuthGrantTypeRefreshToken)
v.Set("refresh_token", spt.token.RefreshToken)
} else {
v.Set("grant_type", spt.getGrantType())
err := spt.secret.SetAuthenticationValues(spt, &v)
if err != nil {
return err
}
} }
return u.Host == imds.Host && u.Path == imds.Path
}
s := v.Encode() func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error {
body := ioutil.NopCloser(strings.NewReader(s)) req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil)
req, err := http.NewRequest(http.MethodPost, spt.oauthConfig.TokenEndpoint.String(), body)
if err != nil { if err != nil {
return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err) return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err)
} }
req = req.WithContext(ctx)
if !isIMDS(spt.inner.OauthConfig.TokenEndpoint) {
v := url.Values{}
v.Set("client_id", spt.inner.ClientID)
v.Set("resource", resource)
req.ContentLength = int64(len(s)) if spt.inner.Token.RefreshToken != "" {
req.Header.Set(contentType, mimeTypeFormPost) v.Set("grant_type", OAuthGrantTypeRefreshToken)
if _, ok := spt.secret.(*ServicePrincipalMSISecret); ok { v.Set("refresh_token", spt.inner.Token.RefreshToken)
// web apps must specify client_secret when refreshing tokens
// see https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-protocols-oauth-code#refreshing-the-access-tokens
if spt.getGrantType() == OAuthGrantTypeAuthorizationCode {
err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
if err != nil {
return err
}
}
} else {
v.Set("grant_type", spt.getGrantType())
err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
if err != nil {
return err
}
}
s := v.Encode()
body := ioutil.NopCloser(strings.NewReader(s))
req.ContentLength = int64(len(s))
req.Header.Set(contentType, mimeTypeFormPost)
req.Body = body
}
if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); ok {
req.Method = http.MethodGet
req.Header.Set(metadataHeader, "true") req.Header.Set(metadataHeader, "true")
} }
resp, err := spt.sender.Do(req)
var resp *http.Response
if isIMDS(spt.inner.OauthConfig.TokenEndpoint) {
resp, err = retryForIMDS(spt.sender, req, spt.MaxMSIRefreshAttempts)
} else {
resp, err = spt.sender.Do(req)
}
if err != nil { if err != nil {
return fmt.Errorf("adal: Failed to execute the refresh request. Error = '%v'", err) return newTokenRefreshError(fmt.Sprintf("adal: Failed to execute the refresh request. Error = '%v'", err), nil)
} }
defer resp.Body.Close() defer resp.Body.Close()
@ -626,11 +830,15 @@ func (spt *ServicePrincipalToken) refreshInternal(resource string) error {
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
if err != nil { if err != nil {
return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body", resp.StatusCode), resp) return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v", resp.StatusCode, err), resp)
} }
return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb)), resp) return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb)), resp)
} }
// for the following error cases don't return a TokenRefreshError. the operation succeeded
// but some transient failure happened during deserialization. by returning a generic error
// the retry logic will kick in (we don't retry on TokenRefreshError).
if err != nil { if err != nil {
return fmt.Errorf("adal: Failed to read a new service principal token during refresh. Error = '%v'", err) return fmt.Errorf("adal: Failed to read a new service principal token during refresh. Error = '%v'", err)
} }
@ -643,20 +851,99 @@ func (spt *ServicePrincipalToken) refreshInternal(resource string) error {
return fmt.Errorf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb)) return fmt.Errorf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb))
} }
spt.token = token spt.inner.Token = token
return spt.InvokeRefreshCallbacks(token) return spt.InvokeRefreshCallbacks(token)
} }
// retry logic specific to retrieving a token from the IMDS endpoint
func retryForIMDS(sender Sender, req *http.Request, maxAttempts int) (resp *http.Response, err error) {
// copied from client.go due to circular dependency
retries := []int{
http.StatusRequestTimeout, // 408
http.StatusTooManyRequests, // 429
http.StatusInternalServerError, // 500
http.StatusBadGateway, // 502
http.StatusServiceUnavailable, // 503
http.StatusGatewayTimeout, // 504
}
// extra retry status codes specific to IMDS
retries = append(retries,
http.StatusNotFound,
http.StatusGone,
// all remaining 5xx
http.StatusNotImplemented,
http.StatusHTTPVersionNotSupported,
http.StatusVariantAlsoNegotiates,
http.StatusInsufficientStorage,
http.StatusLoopDetected,
http.StatusNotExtended,
http.StatusNetworkAuthenticationRequired)
// see https://docs.microsoft.com/en-us/azure/active-directory/managed-service-identity/how-to-use-vm-token#retry-guidance
const maxDelay time.Duration = 60 * time.Second
attempt := 0
delay := time.Duration(0)
for attempt < maxAttempts {
resp, err = sender.Do(req)
// retry on temporary network errors, e.g. transient network failures.
// if we don't receive a response then assume we can't connect to the
// endpoint so we're likely not running on an Azure VM so don't retry.
if (err != nil && !isTemporaryNetworkError(err)) || resp == nil || resp.StatusCode == http.StatusOK || !containsInt(retries, resp.StatusCode) {
return
}
// perform exponential backoff with a cap.
// must increment attempt before calculating delay.
attempt++
// the base value of 2 is the "delta backoff" as specified in the guidance doc
delay += (time.Duration(math.Pow(2, float64(attempt))) * time.Second)
if delay > maxDelay {
delay = maxDelay
}
select {
case <-time.After(delay):
// intentionally left blank
case <-req.Context().Done():
err = req.Context().Err()
return
}
}
return
}
// returns true if the specified error is a temporary network error or false if it's not.
// if the error doesn't implement the net.Error interface the return value is true.
func isTemporaryNetworkError(err error) bool {
if netErr, ok := err.(net.Error); !ok || (ok && netErr.Temporary()) {
return true
}
return false
}
// returns true if slice ints contains the value n
func containsInt(ints []int, n int) bool {
for _, i := range ints {
if i == n {
return true
}
}
return false
}
// SetAutoRefresh enables or disables automatic refreshing of stale tokens. // SetAutoRefresh enables or disables automatic refreshing of stale tokens.
func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) { func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) {
spt.autoRefresh = autoRefresh spt.inner.AutoRefresh = autoRefresh
} }
// SetRefreshWithin sets the interval within which if the token will expire, EnsureFresh will // SetRefreshWithin sets the interval within which if the token will expire, EnsureFresh will
// refresh the token. // refresh the token.
func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) { func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) {
spt.refreshWithin = d spt.inner.RefreshWithin = d
return return
} }
@ -668,12 +955,12 @@ func (spt *ServicePrincipalToken) SetSender(s Sender) { spt.sender = s }
func (spt *ServicePrincipalToken) OAuthToken() string { func (spt *ServicePrincipalToken) OAuthToken() string {
spt.refreshLock.RLock() spt.refreshLock.RLock()
defer spt.refreshLock.RUnlock() defer spt.refreshLock.RUnlock()
return spt.token.OAuthToken() return spt.inner.Token.OAuthToken()
} }
// Token returns a copy of the current token. // Token returns a copy of the current token.
func (spt *ServicePrincipalToken) Token() Token { func (spt *ServicePrincipalToken) Token() Token {
spt.refreshLock.RLock() spt.refreshLock.RLock()
defer spt.refreshLock.RUnlock() defer spt.refreshLock.RUnlock()
return spt.token return spt.inner.Token
} }

View File

@ -113,17 +113,19 @@ func (ba *BearerAuthorizer) WithAuthorization() PrepareDecorator {
return PreparerFunc(func(r *http.Request) (*http.Request, error) { return PreparerFunc(func(r *http.Request) (*http.Request, error) {
r, err := p.Prepare(r) r, err := p.Prepare(r)
if err == nil { if err == nil {
refresher, ok := ba.tokenProvider.(adal.Refresher) // the ordering is important here, prefer RefresherWithContext if available
if ok { if refresher, ok := ba.tokenProvider.(adal.RefresherWithContext); ok {
err := refresher.EnsureFresh() err = refresher.EnsureFreshWithContext(r.Context())
if err != nil { } else if refresher, ok := ba.tokenProvider.(adal.Refresher); ok {
var resp *http.Response err = refresher.EnsureFresh()
if tokError, ok := err.(adal.TokenRefreshError); ok { }
resp = tokError.Response() if err != nil {
} var resp *http.Response
return r, NewErrorWithError(err, "azure.BearerAuthorizer", "WithAuthorization", resp, if tokError, ok := err.(adal.TokenRefreshError); ok {
"Failed to refresh the Token for request to %s", r.URL) resp = tokError.Response()
} }
return r, NewErrorWithError(err, "azure.BearerAuthorizer", "WithAuthorization", resp,
"Failed to refresh the Token for request to %s", r.URL)
} }
return Prepare(r, WithHeader(headerAuthorization, fmt.Sprintf("Bearer %s", ba.tokenProvider.OAuthToken()))) return Prepare(r, WithHeader(headerAuthorization, fmt.Sprintf("Bearer %s", ba.tokenProvider.OAuthToken())))
} }

View File

@ -72,6 +72,7 @@ package autorest
// limitations under the License. // limitations under the License.
import ( import (
"context"
"net/http" "net/http"
"time" "time"
) )
@ -130,3 +131,20 @@ func NewPollingRequest(resp *http.Response, cancel <-chan struct{}) (*http.Reque
return req, nil return req, nil
} }
// NewPollingRequestWithContext allocates and returns a new http.Request with the specified context to poll for the passed response.
func NewPollingRequestWithContext(ctx context.Context, resp *http.Response) (*http.Request, error) {
location := GetLocation(resp)
if location == "" {
return nil, NewErrorWithResponse("autorest", "NewPollingRequestWithContext", resp, "Location header missing from response that requires polling")
}
req, err := Prepare((&http.Request{}).WithContext(ctx),
AsGet(),
WithBaseURL(location))
if err != nil {
return nil, NewErrorWithError(err, "autorest", "NewPollingRequestWithContext", nil, "Failure creating poll request to %s", location)
}
return req, nil
}

File diff suppressed because it is too large Load Diff

View File

@ -279,16 +279,29 @@ func WithErrorUnlessStatusCode(codes ...int) autorest.RespondDecorator {
resp.Body = ioutil.NopCloser(&b) resp.Body = ioutil.NopCloser(&b)
if decodeErr != nil { if decodeErr != nil {
return fmt.Errorf("autorest/azure: error response cannot be parsed: %q error: %v", b.String(), decodeErr) return fmt.Errorf("autorest/azure: error response cannot be parsed: %q error: %v", b.String(), decodeErr)
} else if e.ServiceError == nil { }
if e.ServiceError == nil {
// Check if error is unwrapped ServiceError // Check if error is unwrapped ServiceError
if err := json.Unmarshal(b.Bytes(), &e.ServiceError); err != nil || e.ServiceError.Message == "" { if err := json.Unmarshal(b.Bytes(), &e.ServiceError); err != nil {
e.ServiceError = &ServiceError{ return err
Code: "Unknown",
Message: "Unknown service error",
}
} }
} }
if e.ServiceError.Message == "" {
// if we're here it means the returned error wasn't OData v4 compliant.
// try to unmarshal the body as raw JSON in hopes of getting something.
rawBody := map[string]interface{}{}
if err := json.Unmarshal(b.Bytes(), &rawBody); err != nil {
return err
}
e.ServiceError = &ServiceError{
Code: "Unknown",
Message: "Unknown service error",
}
if len(rawBody) > 0 {
e.ServiceError.Details = []map[string]interface{}{rawBody}
}
}
e.Response = resp
e.RequestID = ExtractRequestID(resp) e.RequestID = ExtractRequestID(resp)
if e.StatusCode == nil { if e.StatusCode == nil {
e.StatusCode = resp.StatusCode e.StatusCode = resp.StatusCode

View File

@ -64,7 +64,7 @@ func DoRetryWithRegistration(client autorest.Client) autorest.SendDecorator {
} }
} }
} }
return resp, fmt.Errorf("failed request: %s", err) return resp, err
}) })
} }
} }
@ -115,7 +115,7 @@ func register(client autorest.Client, originalReq *http.Request, re RequestError
if err != nil { if err != nil {
return err return err
} }
req.Cancel = originalReq.Cancel req = req.WithContext(originalReq.Context())
resp, err := autorest.SendWithSender(client, req, resp, err := autorest.SendWithSender(client, req,
autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...), autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...),
@ -154,7 +154,7 @@ func register(client autorest.Client, originalReq *http.Request, re RequestError
if err != nil { if err != nil {
return err return err
} }
req.Cancel = originalReq.Cancel req = req.WithContext(originalReq.Context())
resp, err := autorest.SendWithSender(client, req, resp, err := autorest.SendWithSender(client, req,
autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...), autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...),
@ -178,9 +178,9 @@ func register(client autorest.Client, originalReq *http.Request, re RequestError
break break
} }
delayed := autorest.DelayWithRetryAfter(resp, originalReq.Cancel) delayed := autorest.DelayWithRetryAfter(resp, originalReq.Context().Done())
if !delayed { if !delayed && !autorest.DelayForBackoff(client.PollingDelay, 0, originalReq.Context().Done()) {
autorest.DelayForBackoff(client.PollingDelay, 0, originalReq.Cancel) return originalReq.Context().Err()
} }
} }
if !(time.Since(now) < client.PollingDuration) { if !(time.Since(now) < client.PollingDuration) {

View File

@ -5,6 +5,20 @@ time.Time types. And both convert to time.Time through a ToTime method.
*/ */
package date package date
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import ( import (
"fmt" "fmt"
"time" "time"

View File

@ -1,5 +1,19 @@
package date package date
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import ( import (
"regexp" "regexp"
"time" "time"

View File

@ -1,5 +1,19 @@
package date package date
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import ( import (
"errors" "errors"
"time" "time"

View File

@ -1,5 +1,19 @@
package date package date
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"

View File

@ -1,5 +1,19 @@
package date package date
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import ( import (
"strings" "strings"
"time" "time"

View File

@ -86,7 +86,7 @@ func SendWithSender(s Sender, r *http.Request, decorators ...SendDecorator) (*ht
func AfterDelay(d time.Duration) SendDecorator { func AfterDelay(d time.Duration) SendDecorator {
return func(s Sender) Sender { return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) { return SenderFunc(func(r *http.Request) (*http.Response, error) {
if !DelayForBackoff(d, 0, r.Cancel) { if !DelayForBackoff(d, 0, r.Context().Done()) {
return nil, fmt.Errorf("autorest: AfterDelay canceled before full delay") return nil, fmt.Errorf("autorest: AfterDelay canceled before full delay")
} }
return s.Do(r) return s.Do(r)
@ -165,7 +165,7 @@ func DoPollForStatusCodes(duration time.Duration, delay time.Duration, codes ...
resp, err = s.Do(r) resp, err = s.Do(r)
if err == nil && ResponseHasStatusCode(resp, codes...) { if err == nil && ResponseHasStatusCode(resp, codes...) {
r, err = NewPollingRequest(resp, r.Cancel) r, err = NewPollingRequestWithContext(r.Context(), resp)
for err == nil && ResponseHasStatusCode(resp, codes...) { for err == nil && ResponseHasStatusCode(resp, codes...) {
Respond(resp, Respond(resp,
@ -198,7 +198,9 @@ func DoRetryForAttempts(attempts int, backoff time.Duration) SendDecorator {
if err == nil { if err == nil {
return resp, err return resp, err
} }
DelayForBackoff(backoff, attempt, r.Cancel) if !DelayForBackoff(backoff, attempt, r.Context().Done()) {
return nil, r.Context().Err()
}
} }
return resp, err return resp, err
}) })
@ -221,14 +223,18 @@ func DoRetryForStatusCodes(attempts int, backoff time.Duration, codes ...int) Se
return resp, err return resp, err
} }
resp, err = s.Do(rr.Request()) resp, err = s.Do(rr.Request())
// if the error isn't temporary don't bother retrying
if err != nil && !IsTemporaryNetworkError(err) {
return nil, err
}
// we want to retry if err is not nil (e.g. transient network failure). note that for failed authentication // we want to retry if err is not nil (e.g. transient network failure). note that for failed authentication
// resp and err will both have a value, so in this case we don't want to retry as it will never succeed. // resp and err will both have a value, so in this case we don't want to retry as it will never succeed.
if err == nil && !ResponseHasStatusCode(resp, codes...) || IsTokenRefreshError(err) { if err == nil && !ResponseHasStatusCode(resp, codes...) || IsTokenRefreshError(err) {
return resp, err return resp, err
} }
delayed := DelayWithRetryAfter(resp, r.Cancel) delayed := DelayWithRetryAfter(resp, r.Context().Done())
if !delayed { if !delayed && !DelayForBackoff(backoff, attempt, r.Context().Done()) {
DelayForBackoff(backoff, attempt, r.Cancel) return nil, r.Context().Err()
} }
// don't count a 429 against the number of attempts // don't count a 429 against the number of attempts
// so that we continue to retry until it succeeds // so that we continue to retry until it succeeds
@ -277,7 +283,9 @@ func DoRetryForDuration(d time.Duration, backoff time.Duration) SendDecorator {
if err == nil { if err == nil {
return resp, err return resp, err
} }
DelayForBackoff(backoff, attempt, r.Cancel) if !DelayForBackoff(backoff, attempt, r.Context().Done()) {
return nil, r.Context().Err()
}
} }
return resp, err return resp, err
}) })

View File

@ -3,6 +3,20 @@ Package to provides helpers to ease working with pointer values of marshalled st
*/ */
package to package to
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// String returns a string value for the passed string pointer. It returns the empty string if the // String returns a string value for the passed string pointer. It returns the empty string if the
// pointer is nil. // pointer is nil.
func String(s *string) string { func String(s *string) string {

View File

@ -20,6 +20,7 @@ import (
"encoding/xml" "encoding/xml"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/url" "net/url"
"reflect" "reflect"
@ -216,3 +217,12 @@ func IsTokenRefreshError(err error) bool {
} }
return false return false
} }
// IsTemporaryNetworkError returns true if the specified error is a temporary network error or false
// if it's not. If the error doesn't implement the net.Error interface the return value is true.
func IsTemporaryNetworkError(err error) bool {
if netErr, ok := err.(net.Error); !ok || (ok && netErr.Temporary()) {
return true
}
return false
}

View File

@ -136,29 +136,29 @@ func validatePtr(x reflect.Value, v Constraint) error {
func validateInt(x reflect.Value, v Constraint) error { func validateInt(x reflect.Value, v Constraint) error {
i := x.Int() i := x.Int()
r, ok := v.Rule.(int) r, ok := toInt64(v.Rule)
if !ok { if !ok {
return createError(x, v, fmt.Sprintf("rule must be integer value for %v constraint; got: %v", v.Name, v.Rule)) return createError(x, v, fmt.Sprintf("rule must be integer value for %v constraint; got: %v", v.Name, v.Rule))
} }
switch v.Name { switch v.Name {
case MultipleOf: case MultipleOf:
if i%int64(r) != 0 { if i%r != 0 {
return createError(x, v, fmt.Sprintf("value must be a multiple of %v", r)) return createError(x, v, fmt.Sprintf("value must be a multiple of %v", r))
} }
case ExclusiveMinimum: case ExclusiveMinimum:
if i <= int64(r) { if i <= r {
return createError(x, v, fmt.Sprintf("value must be greater than %v", r)) return createError(x, v, fmt.Sprintf("value must be greater than %v", r))
} }
case ExclusiveMaximum: case ExclusiveMaximum:
if i >= int64(r) { if i >= r {
return createError(x, v, fmt.Sprintf("value must be less than %v", r)) return createError(x, v, fmt.Sprintf("value must be less than %v", r))
} }
case InclusiveMinimum: case InclusiveMinimum:
if i < int64(r) { if i < r {
return createError(x, v, fmt.Sprintf("value must be greater than or equal to %v", r)) return createError(x, v, fmt.Sprintf("value must be greater than or equal to %v", r))
} }
case InclusiveMaximum: case InclusiveMaximum:
if i > int64(r) { if i > r {
return createError(x, v, fmt.Sprintf("value must be less than or equal to %v", r)) return createError(x, v, fmt.Sprintf("value must be less than or equal to %v", r))
} }
default: default:
@ -388,6 +388,17 @@ func createError(x reflect.Value, v Constraint, err string) error {
v.Target, v.Name, getInterfaceValue(x), err) v.Target, v.Name, getInterfaceValue(x), err)
} }
func toInt64(v interface{}) (int64, bool) {
if i64, ok := v.(int64); ok {
return i64, true
}
// older generators emit max constants as int, so if int64 fails fall back to int
if i32, ok := v.(int); ok {
return int64(i32), true
}
return 0, false
}
// NewErrorWithValidationError appends package type and method name in // NewErrorWithValidationError appends package type and method name in
// validation error. // validation error.
// //

View File

@ -16,5 +16,5 @@ package autorest
// Version returns the semantic version (see http://semver.org). // Version returns the semantic version (see http://semver.org).
func Version() string { func Version() string {
return "v10.3.0" return "v10.12.0"
} }

60
vendor/vendor.json vendored
View File

@ -65,56 +65,56 @@
"versionExact": "v17.3.1" "versionExact": "v17.3.1"
}, },
{ {
"checksumSHA1": "+P6HOINDh/n2z4GqEkluzuGP5p0=", "checksumSHA1": "q3bhLdUVAz5dmDFvfKJJSTd/BaE=",
"comment": "v7.0.7", "comment": "v7.0.7",
"path": "github.com/Azure/go-autorest/autorest", "path": "github.com/Azure/go-autorest/autorest",
"revision": "ed4b7f5bf1ec0c9ede1fda2681d96771282f2862", "revision": "1f7cd6cfe0adea687ad44a512dfe76140f804318",
"revisionTime": "2018-03-26T17:06:54Z", "revisionTime": "2018-06-28T21:22:21Z",
"version": "v10.4.0", "version": "v10.12.0",
"versionExact": "v10.4.0" "versionExact": "v10.12.0"
}, },
{ {
"checksumSHA1": "HzA52MbMWnsR31CFrub5biN90/Q=", "checksumSHA1": "Ygkay0Aq7zeg8A9DSRWI+flRcFk=",
"path": "github.com/Azure/go-autorest/autorest/adal", "path": "github.com/Azure/go-autorest/autorest/adal",
"revision": "ed4b7f5bf1ec0c9ede1fda2681d96771282f2862", "revision": "1f7cd6cfe0adea687ad44a512dfe76140f804318",
"revisionTime": "2018-03-26T17:06:54Z", "revisionTime": "2018-06-28T21:22:21Z",
"version": "v10.4.0", "version": "v10.12.0",
"versionExact": "v10.4.0" "versionExact": "v10.12.0"
}, },
{ {
"checksumSHA1": "bDFbLGwpCT8TRmqEKtPY/U1DAY8=", "checksumSHA1": "O8hvZ5SZ6o/mGOTivIagc1uPdl4=",
"comment": "v7.0.7", "comment": "v7.0.7",
"path": "github.com/Azure/go-autorest/autorest/azure", "path": "github.com/Azure/go-autorest/autorest/azure",
"revision": "ed4b7f5bf1ec0c9ede1fda2681d96771282f2862", "revision": "1f7cd6cfe0adea687ad44a512dfe76140f804318",
"revisionTime": "2018-03-26T17:06:54Z", "revisionTime": "2018-06-28T21:22:21Z",
"version": "v10.4.0", "version": "v10.12.0",
"versionExact": "v10.4.0" "versionExact": "v10.12.0"
}, },
{ {
"checksumSHA1": "LSF/pNrjhIxl6jiS6bKooBFCOxI=", "checksumSHA1": "KcxXWvXhVIkfqdGR+TQqWyCbgZk=",
"comment": "v7.0.7", "comment": "v7.0.7",
"path": "github.com/Azure/go-autorest/autorest/date", "path": "github.com/Azure/go-autorest/autorest/date",
"revision": "58f6f26e200fa5dfb40c9cd1c83f3e2c860d779d", "revision": "1f7cd6cfe0adea687ad44a512dfe76140f804318",
"revisionTime": "2017-04-28T17:52:31Z", "revisionTime": "2018-06-28T21:22:21Z",
"version": "=v8.0.0", "version": "v10.12.0",
"versionExact": "v8.0.0" "versionExact": "v10.12.0"
}, },
{ {
"checksumSHA1": "Ev8qCsbFjDlMlX0N2tYAhYQFpUc=", "checksumSHA1": "8XiRKZWE6XGsb0eJfG4P98JwaXw=",
"comment": "v7.0.7", "comment": "v7.0.7",
"path": "github.com/Azure/go-autorest/autorest/to", "path": "github.com/Azure/go-autorest/autorest/to",
"revision": "58f6f26e200fa5dfb40c9cd1c83f3e2c860d779d", "revision": "1f7cd6cfe0adea687ad44a512dfe76140f804318",
"revisionTime": "2017-04-28T17:52:31Z", "revisionTime": "2018-06-28T21:22:21Z",
"version": "=v8.0.0", "version": "v10.12.0",
"versionExact": "v8.0.0" "versionExact": "v10.12.0"
}, },
{ {
"checksumSHA1": "CdDkG+J8wqXQVQ0f0xal+eolB1w=", "checksumSHA1": "SVU/fTZ9osBm+WPdd5y71RabAzE=",
"path": "github.com/Azure/go-autorest/autorest/validation", "path": "github.com/Azure/go-autorest/autorest/validation",
"revision": "ed4b7f5bf1ec0c9ede1fda2681d96771282f2862", "revision": "1f7cd6cfe0adea687ad44a512dfe76140f804318",
"revisionTime": "2018-03-26T17:06:54Z", "revisionTime": "2018-06-28T21:22:21Z",
"version": "v10.4.0", "version": "v10.12.0",
"versionExact": "v10.4.0" "versionExact": "v10.12.0"
}, },
{ {
"checksumSHA1": "TgrN0l/E16deTlLYNt8wf66urSU=", "checksumSHA1": "TgrN0l/E16deTlLYNt8wf66urSU=",