From 9f1136db77984ab8419d6c516a6261666cef3122 Mon Sep 17 00:00:00 2001 From: Adrien Delorme Date: Wed, 24 Apr 2019 14:03:36 +0200 Subject: [PATCH] retry: encapsulate & return the last seen error in a RetryExhaustedError --- common/retry/retry.go | 15 +++++++++--- common/retry/retry_test.go | 32 +++++++++++++++++++++---- common/retry/utils_test.go | 48 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+), 7 deletions(-) create mode 100644 common/retry/utils_test.go diff --git a/common/retry/retry.go b/common/retry/retry.go index eb8fc4f70..2b1a4f231 100644 --- a/common/retry/retry.go +++ b/common/retry/retry.go @@ -25,7 +25,16 @@ type Config struct { ShouldRetry func(error) bool } -var RetryExhaustedError error = fmt.Errorf("Function never succeeded in Retry") +type RetryExhaustedError struct { + Err error +} + +func (err *RetryExhaustedError) Error() string { + if err == nil || err.Err == nil { + return "" + } + return fmt.Sprintf("retry count exhausted. Last err: %s", err.Err) +} // Run fn until context is cancelled up until StartTimeout time has passed. func (cfg Config) Run(ctx context.Context, fn func(context.Context) error) error { @@ -42,10 +51,10 @@ func (cfg Config) Run(ctx context.Context, fn func(context.Context) error) error startTimeout = time.After(cfg.StartTimeout) } + var err error for try := 0; ; try++ { - var err error if cfg.Tries != 0 && try == cfg.Tries { - return RetryExhaustedError + return &RetryExhaustedError{err} } if err = fn(ctx); err == nil { return nil diff --git a/common/retry/retry_test.go b/common/retry/retry_test.go index 334707c0b..41ecc16c3 100644 --- a/common/retry/retry_test.go +++ b/common/retry/retry_test.go @@ -5,6 +5,8 @@ import ( "errors" "testing" "time" + + "github.com/google/go-cmp/cmp" ) func success(context.Context) error { return nil } @@ -18,12 +20,23 @@ var failErr = errors.New("woops !") func fail(context.Context) error { return failErr } +type failOnce bool + +func (ran *failOnce) Run(context.Context) error { + if !*ran { + *ran = true + return failErr + } + return nil +} + func TestConfig_Run(t *testing.T) { cancelledCtx, cancel := context.WithCancel(context.Background()) cancel() type fields struct { StartTimeout time.Duration RetryDelay func() time.Duration + Tries int } type args struct { ctx context.Context @@ -36,26 +49,37 @@ func TestConfig_Run(t *testing.T) { wantErr error }{ {"success", - fields{StartTimeout: time.Second, RetryDelay: nil}, + fields{StartTimeout: time.Second}, args{context.Background(), success}, nil}, {"context cancelled", - fields{StartTimeout: time.Second, RetryDelay: nil}, + fields{StartTimeout: time.Second}, args{cancelledCtx, wait}, context.Canceled}, {"timeout", fields{StartTimeout: 20 * time.Millisecond, RetryDelay: func() time.Duration { return 10 * time.Millisecond }}, args{cancelledCtx, fail}, failErr}, + {"success after one failure", + fields{Tries: 2, RetryDelay: func() time.Duration { return 0 }}, + args{context.Background(), new(failOnce).Run}, + nil}, + {"fail after one failure", + fields{Tries: 1, RetryDelay: func() time.Duration { return 0 }}, + args{context.Background(), new(failOnce).Run}, + &RetryExhaustedError{failErr}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := Config{ StartTimeout: tt.fields.StartTimeout, RetryDelay: tt.fields.RetryDelay, + Tries: tt.fields.Tries, } - if err := cfg.Run(tt.args.ctx, tt.args.fn); err != tt.wantErr { - t.Fatalf("Config.Run() error = %v, wantErr %v", err, tt.wantErr) + err := cfg.Run(tt.args.ctx, tt.args.fn) + if diff := cmp.Diff(err, tt.wantErr, DeepAllowUnexported(RetryExhaustedError{}, errors.New(""))); diff != "" { + t.Fatalf("Config.Run() unexpected error: %s", diff) } }) } diff --git a/common/retry/utils_test.go b/common/retry/utils_test.go new file mode 100644 index 000000000..a23a7d578 --- /dev/null +++ b/common/retry/utils_test.go @@ -0,0 +1,48 @@ +package retry + +import ( + "reflect" + + "github.com/google/go-cmp/cmp" +) + +func DeepAllowUnexported(vs ...interface{}) cmp.Option { + m := make(map[reflect.Type]struct{}) + for _, v := range vs { + structTypes(reflect.ValueOf(v), m) + } + var typs []interface{} + for t := range m { + typs = append(typs, reflect.New(t).Elem().Interface()) + } + return cmp.AllowUnexported(typs...) +} + +func structTypes(v reflect.Value, m map[reflect.Type]struct{}) { + if !v.IsValid() { + return + } + switch v.Kind() { + case reflect.Ptr: + if !v.IsNil() { + structTypes(v.Elem(), m) + } + case reflect.Interface: + if !v.IsNil() { + structTypes(v.Elem(), m) + } + case reflect.Slice, reflect.Array: + for i := 0; i < v.Len(); i++ { + structTypes(v.Index(i), m) + } + case reflect.Map: + for _, k := range v.MapKeys() { + structTypes(v.MapIndex(k), m) + } + case reflect.Struct: + m[v.Type()] = struct{}{} + for i := 0; i < v.NumField(); i++ { + structTypes(v.Field(i), m) + } + } +}