109 lines
2.9 KiB
Go
109 lines
2.9 KiB
Go
package requestid
|
|
|
|
import (
	"context"
	"errors"
	"fmt"

	"github.com/google/uuid"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
)
|
|
|
|
// Metadata keys used to correlate a single call between client and server.
const (
	// Attached by the client to every outgoing call (see withClientRequestIDs).
	clientRequestIDHeader = "x-client-request-id"
	clientTraceIDHeader   = "x-client-trace-id"

	// Looked up in the server's response headers when a call fails (see wrapError).
	serverRequestIDHeader = "x-request-id"
	serverTraceIDHeader   = "x-server-trace-id"
)
|
|
|
|
// Interceptor returns a gRPC unary client interceptor. On each call it:
//   - generates a fresh client trace ID and client request ID;
//   - sends them as the x-client-trace-id and x-client-request-id outgoing
//     metadata headers;
//   - when the call fails, wraps the error so the client IDs can be recovered
//     with RequestIDsFromError, together with any x-request-id /
//     x-server-trace-id values the server returned in its response headers.
//
// The returned function has the same underlying type as
// grpc.UnaryClientInterceptor, so it can be passed to grpc.WithUnaryInterceptor.
func Interceptor() func(ctx context.Context, method string, req interface{}, reply interface{}, conn *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
	return func(ctx context.Context, method string, req interface{}, reply interface{}, conn *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		clientTraceID := uuid.New().String()
		clientRequestID := uuid.New().String()

		// Filled in by grpc when the server's response headers arrive; it may
		// stay empty if the call fails before any headers are received.
		var responseHeader metadata.MD

		// Clip capacity before appending. opts can alias a slice shared by
		// every call on the connection (grpc passes the ClientConn's default
		// call options through unchanged when the caller supplies none).
		// Appending in place would then write into that shared array and race
		// with concurrent calls.
		callOpts := append(opts[:len(opts):len(opts)], grpc.Header(&responseHeader))

		ctx = withClientRequestIDs(ctx, clientTraceID, clientRequestID)
		err := invoker(ctx, method, req, reply, conn, callOpts...)
		return wrapError(err, clientTraceID, clientRequestID, responseHeader)
	}
}
|
|
|
|
// RequestIDs holds the identifiers associated with one unary call made through
// Interceptor, as returned by RequestIDsFromError.
type RequestIDs struct {
	// Generated on the client for each call and sent as outgoing metadata.
	ClientTraceID, ClientRequestID string
	// Taken from the server's response headers; empty when the server did
	// not send the corresponding header.
	ServerRequestID, ServerTraceID string
}
|
|
|
|
type errorWithRequestIDs struct {
|
|
origErr error
|
|
ids RequestIDs
|
|
}
|
|
|
|
// Error renders the wrapped error's message prefixed with the most useful
// identifier available:
//   - the server-assigned request ID, if the server returned one;
//   - otherwise the client-generated request ID.
//
// If neither is set, the original message is returned unchanged.
func (e *errorWithRequestIDs) Error() string {
	msg := e.origErr.Error()
	switch {
	case e.ids.ServerRequestID != "":
		return fmt.Sprintf("request-id = %s %s", e.ids.ServerRequestID, msg)
	case e.ids.ClientRequestID != "":
		return fmt.Sprintf("client-request-id = %s %s", e.ids.ClientRequestID, msg)
	default:
		return msg
	}
}

// Unwrap returns the original invoker error. Without it, errors.Is and
// errors.As stop at this wrapper and cannot match the underlying error.
func (e *errorWithRequestIDs) Unwrap() error {
	return e.origErr
}
|
|
|
|
func (e errorWithRequestIDs) GRPCStatus() *status.Status {
|
|
return status.Convert(e.origErr)
|
|
}
|
|
|
|
func RequestIDsFromError(err error) (*RequestIDs, bool) {
|
|
if withID, ok := err.(*errorWithRequestIDs); ok {
|
|
return &withID.ids, ok
|
|
}
|
|
return nil, false
|
|
}
|
|
|
|
func wrapError(err error, clientTraceID, clientRequestID string, responseHeader metadata.MD) error {
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
if _, ok := err.(*errorWithRequestIDs); ok {
|
|
return err
|
|
}
|
|
|
|
serverRequestID := getServerHeader(responseHeader, serverRequestIDHeader)
|
|
serverTraceID := getServerHeader(responseHeader, serverTraceIDHeader)
|
|
|
|
return &errorWithRequestIDs{
|
|
err,
|
|
RequestIDs{
|
|
ClientTraceID: clientTraceID,
|
|
ClientRequestID: clientRequestID,
|
|
ServerRequestID: serverRequestID,
|
|
ServerTraceID: serverTraceID,
|
|
},
|
|
}
|
|
}
|
|
|
|
func getServerHeader(responseHeader metadata.MD, key string) string {
|
|
serverHeaderIDRaw := responseHeader.Get(key)
|
|
if len(serverHeaderIDRaw) == 0 {
|
|
return ""
|
|
}
|
|
|
|
return serverHeaderIDRaw[0]
|
|
}
|
|
|
|
// withClientRequestIDs returns a child of ctx whose outgoing gRPC metadata
// carries clientRequestID under x-client-request-id and clientTraceID under
// x-client-trace-id. Any values already present for those two keys are
// replaced. Other outgoing metadata on ctx is copied rather than modified in
// place.
func withClientRequestIDs(ctx context.Context, clientTraceID, clientRequestID string) context.Context {
	existing, found := metadata.FromOutgoingContext(ctx)
	outgoing := metadata.MD{}
	if found {
		outgoing = existing.Copy()
	}

	outgoing.Set(clientTraceIDHeader, clientTraceID)
	outgoing.Set(clientRequestIDHeader, clientRequestID)
	return metadata.NewOutgoingContext(ctx, outgoing)
}
|