diff --git a/packer/rpc/artifact_test.go b/packer/rpc/artifact_test.go index 1cefe082f..d5405ab6e 100644 --- a/packer/rpc/artifact_test.go +++ b/packer/rpc/artifact_test.go @@ -11,9 +11,11 @@ func TestArtifactRPC(t *testing.T) { a := new(packer.MockArtifact) // Start the server - server := NewServer() + client, server := testClientServer(t) + defer client.Close() + defer server.Close() server.RegisterArtifact(a) - client := testClient(t, server) + aClient := client.Artifact() // Test diff --git a/packer/rpc/cache_test.go b/packer/rpc/cache_test.go index 3fce1ee4a..a9a1f110f 100644 --- a/packer/rpc/cache_test.go +++ b/packer/rpc/cache_test.go @@ -51,10 +51,11 @@ func TestCacheRPC(t *testing.T) { c := new(testCache) // Start the server - server := NewServer() - server.RegisterCache(c) - client := testClient(t, server) + client, server := testClientServer(t) defer client.Close() + defer server.Close() + server.RegisterCache(c) + cacheClient := client.Cache() // Test Lock diff --git a/packer/rpc/client.go b/packer/rpc/client.go index 08ddeee12..99f0b296d 100644 --- a/packer/rpc/client.go +++ b/packer/rpc/client.go @@ -12,7 +12,6 @@ import ( type Client struct { mux *MuxConn client *rpc.Client - server *rpc.Server } func NewClient(rwc io.ReadWriteCloser) (*Client, error) { @@ -27,22 +26,9 @@ func NewClient(rwc io.ReadWriteCloser) (*Client, error) { return nil, err } - // Accept connection ID 1 which is what the remote end uses to - // be an RPC client back to us so we can even serve some objects. - serverConn, err := mux.Accept(1) - if err != nil { - mux.Close() - return nil, err - } - - // Start our RPC server on this end - server := rpc.NewServer() - go server.ServeConn(serverConn) - return &Client{ mux: mux, client: rpc.NewClient(clientConn), - server: server, }, nil } @@ -70,6 +56,12 @@ func (c *Client) Cache() packer.Cache { func (c *Client) PostProcessor() packer.PostProcessor { return &postProcessor{ client: c.client, - server: c.server, + } +} + +func (c *Client) Ui() packer.Ui { + return &Ui{ + client: c.client, + endpoint: DefaultUiEndpoint, } } diff --git a/packer/rpc/client_test.go b/packer/rpc/client_test.go index 09976aabb..c0595cbe4 100644 --- a/packer/rpc/client_test.go +++ b/packer/rpc/client_test.go @@ -5,29 +5,44 @@ import ( "testing" ) -func testClient(t *testing.T, server *Server) *Client { +func testConn(t *testing.T) (net.Conn, net.Conn) { l, err := net.Listen("tcp", ":0") if err != nil { t.Fatalf("err: %s", err) } + var serverConn net.Conn + doneCh := make(chan struct{}) go func() { - conn, err := l.Accept() + defer close(doneCh) + defer l.Close() + var err error + serverConn, err = l.Accept() if err != nil { t.Fatalf("err: %s", err) } - server.ServeConn(conn) }() clientConn, err := net.Dial("tcp", l.Addr().String()) if err != nil { t.Fatalf("err: %s", err) } + <-doneCh + + return clientConn, serverConn +} + +func testClientServer(t *testing.T) (*Client, *Server) { + clientConn, serverConn := testConn(t) + + server := NewServer(serverConn) + go server.Serve() client, err := NewClient(clientConn) if err != nil { + server.Close() t.Fatalf("err: %s", err) } - return client + return client, server } diff --git a/packer/rpc/post_processor_test.go b/packer/rpc/post_processor_test.go index e22136ed6..418889ac5 100644 --- a/packer/rpc/post_processor_test.go +++ b/packer/rpc/post_processor_test.go @@ -34,10 +34,11 @@ func TestPostProcessorRPC(t *testing.T) { p := new(TestPostProcessor) // Start the server - server := NewServer() - server.RegisterPostProcessor(p) - client := testClient(t, server) + client, server := testClientServer(t) defer client.Close() + defer server.Close() + server.RegisterPostProcessor(p) + ppClient := client.PostProcessor() // Test Configure diff --git a/packer/rpc/server_new.go b/packer/rpc/server_new.go index a4782a0d1..afed8fe58 100644 --- a/packer/rpc/server_new.go +++ b/packer/rpc/server_new.go @@ -12,84 +12,68 @@ import ( var endpointId uint64 const ( - DefaultArtifactEndpoint string = "Artifact" + DefaultArtifactEndpoint string = "Artifact" + DefaultCacheEndpoint = "Cache" + DefaultPostProcessorEndpoint = "PostProcessor" + DefaultUiEndpoint = "Ui" ) // Server represents an RPC server for Packer. This must be paired on // the other side with a Client. type Server struct { - components map[string]interface{} + mux *MuxConn + server *rpc.Server } // NewServer returns a new Packer RPC server. -func NewServer() *Server { +func NewServer(conn io.ReadWriteCloser) *Server { return &Server{ - components: make(map[string]interface{}), + mux: NewMuxConn(conn), + server: rpc.NewServer(), } } +func (s *Server) Close() error { + return s.mux.Close() +} + func (s *Server) RegisterArtifact(a packer.Artifact) { - s.components[DefaultArtifactEndpoint] = a + s.server.RegisterName(DefaultArtifactEndpoint, &ArtifactServer{ + artifact: a, + }) } func (s *Server) RegisterCache(c packer.Cache) { - s.components["Cache"] = c + s.server.RegisterName(DefaultCacheEndpoint, &CacheServer{ + cache: c, + }) } func (s *Server) RegisterPostProcessor(p packer.PostProcessor) { - s.components["PostProcessor"] = p + s.server.RegisterName(DefaultPostProcessorEndpoint, &PostProcessorServer{ + p: p, + }) +} + +func (s *Server) RegisterUi(ui packer.Ui) { + s.server.RegisterName(DefaultUiEndpoint, &UiServer{ + ui: ui, + }) } // ServeConn serves a single connection over the RPC server. It is up // to the caller to obtain a proper io.ReadWriteCloser. -func (s *Server) ServeConn(conn io.ReadWriteCloser) { - mux := NewMuxConn(conn) - defer mux.Close() - +func (s *Server) Serve() { // Accept a connection on stream ID 0, which is always used for // normal client to server connections. - stream, err := mux.Accept(0) + stream, err := s.mux.Accept(0) + defer stream.Close() if err != nil { log.Printf("[ERR] Error retrieving stream for serving: %s", err) return } - clientConn, err := mux.Dial(1) - if err != nil { - log.Printf("[ERR] Error connecting to client stream: %s", err) - return - } - client := rpc.NewClient(clientConn) - - // Create the RPC server - server := rpc.NewServer() - for endpoint, iface := range s.components { - var endpointVal interface{} - - switch v := iface.(type) { - case packer.Artifact: - endpointVal = &ArtifactServer{ - artifact: v, - } - case packer.Cache: - endpointVal = &CacheServer{ - cache: v, - } - case packer.PostProcessor: - endpointVal = &PostProcessorServer{ - client: client, - server: server, - p: v, - } - default: - log.Printf("[ERR] Unknown component for endpoint: %s", endpoint) - return - } - - registerComponent(server, endpoint, endpointVal, false) - } - - server.ServeConn(stream) + s.server.ServeConn(stream) } // registerComponent registers a single Packer RPC component onto diff --git a/packer/rpc/ui_test.go b/packer/rpc/ui_test.go index 5241f7989..4dd2d7036 100644 --- a/packer/rpc/ui_test.go +++ b/packer/rpc/ui_test.go @@ -1,7 +1,6 @@ package rpc import ( - "net/rpc" "reflect" "testing" ) @@ -52,17 +51,12 @@ func TestUiRPC(t *testing.T) { ui := new(testUi) // Start the RPC server - server := rpc.NewServer() - RegisterUi(server, ui) - address := serveSingleConn(server) + client, server := testClientServer(t) + defer client.Close() + defer server.Close() + server.RegisterUi(ui) - // Create the client over RPC and run some methods to verify it works - client, err := rpc.Dial("tcp", address) - if err != nil { - panic(err) - } - - uiClient := &Ui{client: client} + uiClient := client.Ui() // Basic error and say tests result, err := uiClient.Ask("query")