// NOTE(skipor): code took from google.golang.org/grpc/proxy.go and google.golang.org/grpc/rpc_util.go // Modifications: // * rename newProxyDialer to NewProxyDialer // * log when proxy using proxy /* * * Copyright 2017 gRPC authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package dial import ( "bufio" "errors" "fmt" "io" "net" "net/http" "net/http/httputil" "net/url" "strings" "golang.org/x/net/context" "google.golang.org/grpc/grpclog" ) // NewProxyDialer returns a dialer that connects to proxy first if necessary. // The returned dialer checks if a proxy is necessary, dial to the proxy with the // provided dialer, does HTTP CONNECT handshake and returns the connection. func NewProxyDialer(dialer func(context.Context, string) (net.Conn, error)) func(context.Context, string) (net.Conn, error) { return func(ctx context.Context, addr string) (conn net.Conn, err error) { var skipHandshake bool newAddr, err := mapAddress(ctx, addr) if err != nil { if err != errDisabled { return nil, err } skipHandshake = true newAddr = addr } if addr != newAddr { grpclog.Infof("Using proxy %s for dialing %s", newAddr, addr) } conn, err = dialer(ctx, newAddr) if err != nil { return } if !skipHandshake { conn, err = doHTTPConnectHandshake(ctx, conn, addr) } return } } var ( // errDisabled indicates that proxy is disabled for the address. errDisabled = errors.New("proxy is disabled for the address") // The following variable will be overwritten in the tests. httpProxyFromEnvironment = http.ProxyFromEnvironment ) func mapAddress(ctx context.Context, address string) (string, error) { req := &http.Request{ URL: &url.URL{ Scheme: "https", Host: address, }, } url, err := httpProxyFromEnvironment(req) if err != nil { return "", err } if url == nil { return "", errDisabled } return url.Host, nil } // To read a response from a net.Conn, http.ReadResponse() takes a bufio.Reader. // It's possible that this reader reads more than what's need for the response and stores // those bytes in the buffer. // bufConn wraps the original net.Conn and the bufio.Reader to make sure we don't lose the // bytes in the buffer. type bufConn struct { net.Conn r io.Reader } func (c *bufConn) Read(b []byte) (int, error) { return c.r.Read(b) } func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, addr string) (_ net.Conn, err error) { defer func() { if err != nil { conn.Close() } }() req := (&http.Request{ Method: http.MethodConnect, URL: &url.URL{Host: addr}, Header: map[string][]string{"User-Agent": {grpcUA}}, }) if err := sendHTTPRequest(ctx, req, conn); err != nil { return nil, fmt.Errorf("failed to write the HTTP request: %v", err) } r := bufio.NewReader(conn) resp, err := http.ReadResponse(r, req) if err != nil { return nil, fmt.Errorf("reading server HTTP response: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { dump, err := httputil.DumpResponse(resp, true) if err != nil { return nil, fmt.Errorf("failed to do connect handshake, status code: %s", resp.Status) } return nil, fmt.Errorf("failed to do connect handshake, response: %q", dump) } return &bufConn{Conn: conn, r: r}, nil } func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) error { req = req.WithContext(ctx) if err := req.Write(conn); err != nil { return fmt.Errorf("failed to write the HTTP request: %v", err) } return nil } // parseDialTarget returns the network and address to pass to dialer func parseDialTarget(target string) (net string, addr string) { net = "tcp" m1 := strings.Index(target, ":") m2 := strings.Index(target, ":/") // handle unix:addr which will fail with url.Parse if m1 >= 0 && m2 < 0 { if n := target[0:m1]; n == "unix" { net = n addr = target[m1+1:] return net, addr } } if m2 >= 0 { t, err := url.Parse(target) if err != nil { return net, target } scheme := t.Scheme addr = t.Path if scheme == "unix" { net = scheme if addr == "" { addr = t.Host } return net, addr } } return net, target }