packer-cn/vendor/github.com/yandex-cloud/go-sdk/pkg/requestid/interceptor.go

109 lines
2.9 KiB
Go

package requestid
import (
"context"
"fmt"
"github.com/google/uuid"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// Metadata keys attached by the client to every outgoing call.
const (
	clientRequestIDHeader = "x-client-request-id"
	clientTraceIDHeader   = "x-client-trace-id"
)

// Metadata keys looked up in the server's response headers.
const (
	serverRequestIDHeader = "x-request-id"
	serverTraceIDHeader   = "x-server-trace-id"
)
// Interceptor returns a gRPC unary client interceptor.
//
// For every call it:
//   - generates a fresh client trace ID and client request ID (random UUIDs);
//   - sends them in the outgoing metadata;
//   - captures the response headers.
//
// If the call fails, the returned error is annotated with those client IDs
// and with any server request/trace IDs found in the response headers. The
// IDs can be retrieved with RequestIDsFromError.
func Interceptor() func(ctx context.Context, method string, req interface{}, reply interface{}, conn *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
	return func(ctx context.Context, method string, req interface{}, reply interface{}, conn *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		clientTraceID := uuid.New().String()
		clientRequestID := uuid.New().String()

		var responseHeader metadata.MD
		// opts may be the caller's own variadic slice. Capping its capacity
		// forces append to allocate a new backing array instead of writing our
		// Header option into the caller's spare capacity. That capacity might be
		// reused by the caller, possibly from concurrent calls.
		opts = append(opts[:len(opts):len(opts)], grpc.Header(&responseHeader))

		ctx = withClientRequestIDs(ctx, clientTraceID, clientRequestID)
		err := invoker(ctx, method, req, reply, conn, opts...)
		// responseHeader is filled in by grpc.Header once the call returns.
		// It may be empty if the server sent no headers.
		return wrapError(err, clientTraceID, clientRequestID, responseHeader)
	}
}
// RequestIDs holds the identifiers associated with a single gRPC call made
// through Interceptor.
type RequestIDs struct {
	// Generated on the client and sent as outgoing metadata.
	ClientTraceID, ClientRequestID string
	// Read from the server's response headers; empty when the header was absent.
	ServerRequestID, ServerTraceID string
}

// errorWithRequestIDs decorates an error returned by a gRPC call with the
// request IDs of that call.
type errorWithRequestIDs struct {
	origErr error      // the error returned by the invoker
	ids     RequestIDs // IDs of the call that produced origErr
}

// Error prefixes the original message with the most useful ID available.
// The server request ID is preferred, then the client request ID. If
// neither is set, the original message is returned unchanged.
func (e *errorWithRequestIDs) Error() string {
	msg := e.origErr.Error()
	if id := e.ids.ServerRequestID; id != "" {
		return fmt.Sprintf("request-id = %s %s", id, msg)
	}
	if id := e.ids.ClientRequestID; id != "" {
		return fmt.Sprintf("client-request-id = %s %s", id, msg)
	}
	return msg
}
// GRPCStatus returns the gRPC status of the wrapped error. grpc's status
// helpers look for a GRPCStatus method, so they still report the original
// code and message for annotated errors.
//
// It uses a pointer receiver, like Error, because the wrapper is only ever
// handled as *errorWithRequestIDs.
func (e *errorWithRequestIDs) GRPCStatus() *status.Status {
	return status.Convert(e.origErr)
}

// Unwrap returns the original error. This lets errors.Is and errors.As look
// beneath the request-ID annotation, for example for sentinel errors returned
// by another interceptor further down the chain.
func (e *errorWithRequestIDs) Unwrap() error {
	return e.origErr
}
// RequestIDsFromError returns the request IDs attached to err by Interceptor.
// Only err itself is inspected; wrapping layers are not unwrapped. The second
// result is false, and the pointer nil, when err was not produced by the
// interceptor.
func RequestIDsFromError(err error) (*RequestIDs, bool) {
	annotated, isAnnotated := err.(*errorWithRequestIDs)
	if !isAnnotated {
		return nil, false
	}
	return &annotated.ids, true
}
// wrapError attaches the call's request IDs to err.
//
// It returns:
//   - nil when err is nil;
//   - err unchanged when err already carries request IDs;
//   - otherwise, err wrapped together with the client IDs and the server IDs
//     read from responseHeader. A missing server ID is left empty.
func wrapError(err error, clientTraceID, clientRequestID string, responseHeader metadata.MD) error {
	if err == nil {
		return nil
	}
	if _, alreadyWrapped := err.(*errorWithRequestIDs); alreadyWrapped {
		return err
	}
	ids := RequestIDs{
		ClientTraceID:   clientTraceID,
		ClientRequestID: clientRequestID,
		ServerRequestID: getServerHeader(responseHeader, serverRequestIDHeader),
		ServerTraceID:   getServerHeader(responseHeader, serverTraceIDHeader),
	}
	return &errorWithRequestIDs{origErr: err, ids: ids}
}
// getServerHeader returns the first value stored under key in responseHeader.
// It returns "" when the key has no values.
func getServerHeader(responseHeader metadata.MD, key string) string {
	if values := responseHeader.Get(key); len(values) > 0 {
		return values[0]
	}
	return ""
}
// withClientRequestIDs returns a child of ctx whose outgoing metadata carries
// the given client request and trace IDs.
//
// Any outgoing metadata already on ctx is copied first, so the parent's
// metadata is never mutated. Set replaces existing values for the two ID keys.
func withClientRequestIDs(ctx context.Context, clientTraceID, clientRequestID string) context.Context {
	md := metadata.MD{}
	if existing, found := metadata.FromOutgoingContext(ctx); found {
		md = existing.Copy()
	}
	md.Set(clientRequestIDHeader, clientRequestID)
	md.Set(clientTraceIDHeader, clientTraceID)
	return metadata.NewOutgoingContext(ctx, md)
}