2019-09-23 21:09:10 +03:00

171 lines
4.2 KiB
Go

package retry
import (
"context"
"strconv"
"time"
"github.com/google/uuid"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// Interceptor returns retry interceptor, that can be parametrized by specific call options.
// Without any option, it uses default options, that basically retries nothing.
// Default retry quantity is 0, backoff function is nil, retry codes are DefaultRetriableCodes, AttemptHeader is false, and perCallTimeout is 0.
func Interceptor(callOpts ...grpc.CallOption) grpc.UnaryClientInterceptor {
i := unaryInterceptor{opts: *defaultOptions.applyOptions(callOpts)}
return i.InterceptUnary
}
type unaryInterceptor struct {
opts options
}
func (i *unaryInterceptor) InterceptUnary(ctx context.Context, method string, req, reply interface{}, conn *grpc.ClientConn, invoker grpc.UnaryInvoker, callOpts ...grpc.CallOption) error {
opts := i.opts.applyOptions(callOpts)
ctx = addIdempotencyToken(ctx)
caller := grpcCaller{ctx, method, req, reply, conn, invoker, callOpts}
// TODO(seukyaso): consider adding some configurable callbacks for notifying/logging purpose
callContextIsDone, err := caller.Call(0, opts)
for r := 0; opts.maxRetryCount < 0 || r < opts.maxRetryCount; r++ {
if err == nil {
return nil
}
// check for parent context errors, return if context is Cancelled or Deadline exceeded
select {
case <-ctx.Done():
return contextErrorToGRPCError(ctx.Err())
default:
}
// Always retry if call context is Done (cancelled or Deadline exceeded).
// Thus, we ignore call errors in this case.
if !callContextIsDone && !opts.isRetriable(err) {
return err
}
err = opts.waitBackoff(ctx, r)
if err != nil {
return err
}
callContextIsDone, err = caller.Call(r+1, opts)
}
return err
}
type grpcCaller struct {
ctx context.Context
method string
req interface{}
reply interface{}
conn *grpc.ClientConn
invoker grpc.UnaryInvoker
callOpts []grpc.CallOption
}
func (c *grpcCaller) Call(attempt int, opts *options) (callContextIsDone bool, err error) {
callCtx := c.ctx
var cancel context.CancelFunc
if attempt > 0 {
callCtx = opts.addRetryAttemptToHeader(callCtx, attempt)
if opts.perCallTimeout > 0 {
callCtx, cancel = context.WithTimeout(callCtx, opts.perCallTimeout)
defer cancel()
}
}
err = c.invoker(callCtx, c.method, c.req, c.reply, c.conn, c.callOpts...)
select {
case <-callCtx.Done():
callContextIsDone = true
default:
}
return
}
func (opts *options) applyOptions(callOpts []grpc.CallOption) *options {
ret := *opts
for _, opt := range callOpts {
if do, ok := opt.(interceptorOption); ok {
do.applyFunc(&ret)
}
}
return &ret
}
func contextErrorToGRPCError(err error) error {
switch err {
case context.DeadlineExceeded:
return status.Error(codes.DeadlineExceeded, err.Error())
case context.Canceled:
return status.Error(codes.Canceled, err.Error())
default:
return status.Error(codes.Unknown, err.Error())
}
}
func (opts *options) waitBackoff(ctx context.Context, attempt int) error {
if opts.backoffFunc == nil {
return nil
}
waitTime := opts.backoffFunc(attempt)
if waitTime > 0 {
select {
case <-ctx.Done():
return contextErrorToGRPCError(ctx.Err())
case <-time.After(waitTime):
}
}
return nil
}
func (opts *options) isRetriable(err error) bool {
errCode := status.Code(err)
for _, code := range opts.retriableCodes {
if errCode == code {
return true
}
}
return false
}
func addIdempotencyToken(ctx context.Context) context.Context {
const idempotencyTokenMetadataKey = "idempotency-key"
idempotencyTokenPresent := false
md, ok := metadata.FromOutgoingContext(ctx)
if ok {
_, idempotencyTokenPresent = md[idempotencyTokenMetadataKey]
}
if !idempotencyTokenPresent {
ctx = metadata.AppendToOutgoingContext(ctx, idempotencyTokenMetadataKey, uuid.New().String())
}
return ctx
}
func (opts *options) addRetryAttemptToHeader(ctx context.Context, attempt int) context.Context {
const AttemptMetadataKey = "x-retry-attempt"
if opts.addCountToHeader {
ctx = metadata.AppendToOutgoingContext(ctx, AttemptMetadataKey, strconv.Itoa(attempt))
}
return ctx
}