packer/rpc: accept/dial stream IDs are unique [GH-727]

This commit is contained in:
Mitchell Hashimoto 2013-12-20 09:49:44 -08:00
parent 629f3eee21
commit edbdee5dee
5 changed files with 207 additions and 130 deletions

View File

@ -17,7 +17,7 @@ type Client struct {
} }
func NewClient(rwc io.ReadWriteCloser) (*Client, error) { func NewClient(rwc io.ReadWriteCloser) (*Client, error) {
result, err := newClientWithMux(NewMuxConn(rwc, 0), 0) result, err := newClientWithMux(NewMuxConn(rwc), 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -164,7 +164,7 @@ func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface
} }
}() }()
if args.StdinStreamId > 0 { if args.StdinStreamId >= 0 {
conn, err := c.mux.Dial(args.StdinStreamId) conn, err := c.mux.Dial(args.StdinStreamId)
if err != nil { if err != nil {
close(doneCh) close(doneCh)
@ -175,7 +175,7 @@ func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface
cmd.Stdin = conn cmd.Stdin = conn
} }
if args.StdoutStreamId > 0 { if args.StdoutStreamId >= 0 {
conn, err := c.mux.Dial(args.StdoutStreamId) conn, err := c.mux.Dial(args.StdoutStreamId)
if err != nil { if err != nil {
close(doneCh) close(doneCh)
@ -186,7 +186,7 @@ func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface
cmd.Stdout = conn cmd.Stdout = conn
} }
if args.StderrStreamId > 0 { if args.StderrStreamId >= 0 {
conn, err := c.mux.Dial(args.StderrStreamId) conn, err := c.mux.Dial(args.StderrStreamId)
if err != nil { if err != nil {
close(doneCh) close(doneCh)
@ -253,7 +253,6 @@ func (c *CommunicatorServer) Download(args *CommunicatorDownloadArgs, reply *int
} }
func serveSingleCopy(name string, mux *MuxConn, id uint32, dst io.Writer, src io.Reader) { func serveSingleCopy(name string, mux *MuxConn, id uint32, dst io.Writer, src io.Reader) {
log.Printf("[DEBUG] %s: Connecting to stream %d", name, id)
conn, err := mux.Accept(id) conn, err := mux.Accept(id)
if err != nil { if err != nil {
log.Printf("[ERR] '%s' accept error: %s", name, err) log.Printf("[ERR] '%s' accept error: %s", name, err)

View File

@ -24,14 +24,23 @@ import (
type MuxConn struct { type MuxConn struct {
curId uint32 curId uint32
rwc io.ReadWriteCloser rwc io.ReadWriteCloser
streams map[uint32]*Stream streamsAccept map[uint32]*Stream
streamsDial map[uint32]*Stream
mu sync.RWMutex mu sync.RWMutex
muAccept sync.RWMutex
muDial sync.RWMutex
wlock sync.Mutex wlock sync.Mutex
doneCh chan struct{} doneCh chan struct{}
} }
type muxPacketFrom byte
type muxPacketType byte type muxPacketType byte
const (
muxPacketFromAccept muxPacketFrom = iota
muxPacketFromDial
)
const ( const (
muxPacketSyn muxPacketType = iota muxPacketSyn muxPacketType = iota
muxPacketSynAck muxPacketSynAck
@ -40,13 +49,24 @@ const (
muxPacketData muxPacketData
) )
func (f muxPacketFrom) String() string {
switch f {
case muxPacketFromAccept:
return "accept"
case muxPacketFromDial:
return "dial"
default:
panic("unknown from type")
}
}
// Create a new MuxConn around any io.ReadWriteCloser. // Create a new MuxConn around any io.ReadWriteCloser.
func NewMuxConn(rwc io.ReadWriteCloser, startId uint32) *MuxConn { func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn {
m := &MuxConn{ m := &MuxConn{
rwc: rwc, rwc: rwc,
streams: make(map[uint32]*Stream), streamsAccept: make(map[uint32]*Stream),
streamsDial: make(map[uint32]*Stream),
doneCh: make(chan struct{}), doneCh: make(chan struct{}),
curId: startId,
} }
go m.cleaner() go m.cleaner()
@ -62,10 +82,14 @@ func (m *MuxConn) Close() error {
defer m.mu.Unlock() defer m.mu.Unlock()
// Close all the streams // Close all the streams
for _, w := range m.streams { for _, w := range m.streamsAccept {
w.Close() w.Close()
} }
m.streams = make(map[uint32]*Stream) for _, w := range m.streamsDial {
w.Close()
}
m.streamsAccept = make(map[uint32]*Stream)
m.streamsDial = make(map[uint32]*Stream)
// Close the actual connection. This will also force the loop // Close the actual connection. This will also force the loop
// to end since it'll read EOF or closed connection. // to end since it'll read EOF or closed connection.
@ -75,14 +99,22 @@ func (m *MuxConn) Close() error {
// Accept accepts a multiplexed connection with the given ID. This // Accept accepts a multiplexed connection with the given ID. This
// will block until a request is made to connect. // will block until a request is made to connect.
func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) { func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) {
stream, err := m.openStream(id) //log.Printf("[TRACE] %p: Accept on stream ID: %d", m, id)
if err != nil {
return nil, err // Get the stream. It is okay if it is already in the list of streams
} // because we may have prematurely received a syn for it.
m.muAccept.Lock()
stream, ok := m.streamsAccept[id]
if !ok {
stream = newStream(muxPacketFromAccept, id, m)
m.streamsAccept[id] = stream
}
m.muAccept.Unlock()
// If the stream isn't closed, then it is already open somehow
stream.mu.Lock() stream.mu.Lock()
defer stream.mu.Unlock() defer stream.mu.Unlock()
// If the stream isn't closed, then it is already open somehow
if stream.state != streamStateSynRecv && stream.state != streamStateClosed { if stream.state != streamStateSynRecv && stream.state != streamStateClosed {
return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state) return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state)
} }
@ -97,7 +129,7 @@ func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) {
if stream.state == streamStateSynRecv { if stream.state == streamStateSynRecv {
// Send a syn-ack // Send a syn-ack
if _, err := m.write(stream.id, muxPacketSynAck, nil); err != nil { if _, err := stream.write(muxPacketSynAck, nil); err != nil {
return nil, err return nil, err
} }
} }
@ -112,110 +144,68 @@ func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) {
// Dial opens a connection to the remote end using the given stream ID. // Dial opens a connection to the remote end using the given stream ID.
// An Accept on the remote end will only work with if the IDs match. // An Accept on the remote end will only work with if the IDs match.
func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) { func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) {
stream, err := m.openStream(id) m.muDial.Lock()
if err != nil {
return nil, err // If we have any streams with this ID, then it is a failure. The
// reaper should clear out old streams once in awhile.
if stream, ok := m.streamsDial[id]; ok {
m.muDial.Unlock()
return nil, fmt.Errorf(
"Stream %d already open for dial. State: %d", stream.state)
} }
// If the stream isn't closed, then it is already open somehow // Create the new stream and put it in our list. We can then
// unlock because dialing will no longer be allowed on that ID.
stream := newStream(muxPacketFromDial, id, m)
m.streamsDial[id] = stream
// Don't let anyone else mess with this stream
stream.mu.Lock() stream.mu.Lock()
defer stream.mu.Unlock() defer stream.mu.Unlock()
if stream.state != streamStateClosed {
return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state) m.muDial.Unlock()
}
// Open a connection // Open a connection
if _, err := m.write(stream.id, muxPacketSyn, nil); err != nil { if _, err := stream.write(muxPacketSyn, nil); err != nil {
return nil, err return nil, err
} }
// It is safe to set the state after the write above because
// we hold the stream lock.
stream.setState(streamStateSynSent) stream.setState(streamStateSynSent)
if err := stream.waitState(streamStateEstablished); err != nil { if err := stream.waitState(streamStateEstablished); err != nil {
return nil, err return nil, err
} }
m.write(id, muxPacketAck, nil) stream.write(muxPacketAck, nil)
return stream, nil return stream, nil
} }
// NextId returns the next available stream ID that isn't currently // NextId returns the next available listen stream ID that isn't currently
// taken. // taken.
func (m *MuxConn) NextId() uint32 { func (m *MuxConn) NextId() uint32 {
m.mu.Lock() m.muAccept.Lock()
defer m.mu.Unlock() defer m.muAccept.Unlock()
for { for {
result := m.curId result := m.curId
m.curId += 2 m.curId += 1
if _, ok := m.streams[result]; !ok { if _, ok := m.streamsAccept[result]; !ok {
return result return result
} }
} }
} }
func (m *MuxConn) openStream(id uint32) (*Stream, error) {
// First grab a read-lock if we have the stream already we can
// cheaply return it.
m.mu.RLock()
if stream, ok := m.streams[id]; ok {
m.mu.RUnlock()
return stream, nil
}
// Now acquire a full blown write lock so we can create the stream
m.mu.RUnlock()
m.mu.Lock()
defer m.mu.Unlock()
// Make sure we attempt to use the next biggest stream ID
if id >= m.curId {
m.curId = id + 1
}
// We have to check this again because there is a time period
// above where we couldn't lost this lock.
if stream, ok := m.streams[id]; ok {
return stream, nil
}
// Create the stream object and channel where data will be sent to
dataR, dataW := io.Pipe()
writeCh := make(chan []byte, 256)
// Set the data channel so we can write to it.
stream := &Stream{
id: id,
mux: m,
reader: dataR,
writeCh: writeCh,
stateChange: make(map[chan<- streamState]struct{}),
}
stream.setState(streamStateClosed)
// Start the goroutine that will read from the queue and write
// data out.
go func() {
defer dataW.Close()
for {
data := <-writeCh
if data == nil {
// A nil is a tombstone letting us know we're done
// accepting data.
return
}
if _, err := dataW.Write(data); err != nil {
return
}
}
}()
m.streams[id] = stream
return m.streams[id], nil
}
func (m *MuxConn) cleaner() { func (m *MuxConn) cleaner() {
checks := []struct {
Map map[uint32]*Stream
Lock *sync.RWMutex
}{
{m.streamsAccept, &m.muAccept},
{m.streamsDial, &m.muDial},
}
for { for {
done := false done := false
select { select {
@ -224,23 +214,28 @@ func (m *MuxConn) cleaner() {
done = true done = true
} }
m.mu.Lock() for _, check := range checks {
for id, s := range m.streams { check.Lock.Lock()
for id, s := range check.Map {
s.mu.Lock() s.mu.Lock()
if s.state == streamStateClosed {
delete(m.streams, id) if done && s.state != streamStateClosed {
} s.closeWriter()
s.mu.Unlock() }
if s.state == streamStateClosed {
// Only clean up the streams that have been closed
// for a certain amount of time.
since := time.Now().UTC().Sub(s.stateUpdated)
if since > 2*time.Second {
delete(check.Map, id)
}
} }
if done {
for _, s := range m.streams {
s.mu.Lock()
s.closeWriter()
s.mu.Unlock() s.mu.Unlock()
} }
check.Lock.Unlock()
} }
m.mu.Unlock()
if done { if done {
return return
@ -256,10 +251,15 @@ func (m *MuxConn) loop() {
close(m.doneCh) close(m.doneCh)
}() }()
var from muxPacketFrom
var id uint32 var id uint32
var packetType muxPacketType var packetType muxPacketType
var length int32 var length int32
for { for {
if err := binary.Read(m.rwc, binary.BigEndian, &from); err != nil {
log.Printf("[ERR] Error reading stream direction: %s", err)
return
}
if err := binary.Read(m.rwc, binary.BigEndian, &id); err != nil { if err := binary.Read(m.rwc, binary.BigEndian, &id); err != nil {
log.Printf("[ERR] Error reading stream ID: %s", err) log.Printf("[ERR] Error reading stream ID: %s", err)
return return
@ -282,15 +282,47 @@ func (m *MuxConn) loop() {
} }
} }
stream, err := m.openStream(id) // Get the proper stream. Note that the map we look into is
if err != nil { // opposite the "from" because if the dial side is talking to
log.Printf("[ERR] Error opening stream %d: %s", id, err) // us, we need to look into the accept map, and so on.
return //
// Note: we also switch the "from" value so that logging
// below is correct.
var stream *Stream
switch from {
case muxPacketFromDial:
m.muAccept.Lock()
stream = m.streamsAccept[id]
m.muAccept.Unlock()
from = muxPacketFromAccept
case muxPacketFromAccept:
m.muDial.Lock()
stream = m.streamsDial[id]
m.muDial.Unlock()
from = muxPacketFromDial
default:
panic(fmt.Sprintf("Unknown stream direction: %d", from))
} }
//log.Printf("[TRACE] Stream %d received packet %d", id, packetType) //log.Printf("[TRACE] %p: Stream %d (%s) received packet %d", m, id, from, packetType)
switch packetType { switch packetType {
case muxPacketSyn: case muxPacketSyn:
// If the stream is nil, this is the only case where we'll
// automatically create the stream struct.
if stream == nil {
var ok bool
m.muAccept.Lock()
stream, ok = m.streamsAccept[id]
if !ok {
stream = newStream(muxPacketFromAccept, id, m)
m.streamsAccept[id] = stream
}
m.muAccept.Unlock()
}
stream.mu.Lock() stream.mu.Lock()
switch stream.state { switch stream.state {
case streamStateClosed: case streamStateClosed:
@ -332,15 +364,15 @@ func (m *MuxConn) loop() {
case streamStateEstablished: case streamStateEstablished:
stream.closeWriter() stream.closeWriter()
stream.setState(streamStateCloseWait) stream.setState(streamStateCloseWait)
m.write(id, muxPacketAck, nil) stream.write(muxPacketAck, nil)
case streamStateFinWait2: case streamStateFinWait2:
stream.closeWriter() stream.closeWriter()
stream.setState(streamStateClosed) stream.setState(streamStateClosed)
m.write(id, muxPacketAck, nil) stream.write(muxPacketAck, nil)
case streamStateFinWait1: case streamStateFinWait1:
stream.closeWriter() stream.closeWriter()
stream.setState(streamStateClosing) stream.setState(streamStateClosing)
m.write(id, muxPacketAck, nil) stream.write(muxPacketAck, nil)
default: default:
log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state) log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state)
} }
@ -367,10 +399,13 @@ func (m *MuxConn) loop() {
} }
} }
func (m *MuxConn) write(id uint32, dataType muxPacketType, p []byte) (int, error) { func (m *MuxConn) write(from muxPacketFrom, id uint32, dataType muxPacketType, p []byte) (int, error) {
m.wlock.Lock() m.wlock.Lock()
defer m.wlock.Unlock() defer m.wlock.Unlock()
if err := binary.Write(m.rwc, binary.BigEndian, from); err != nil {
return 0, err
}
if err := binary.Write(m.rwc, binary.BigEndian, id); err != nil { if err := binary.Write(m.rwc, binary.BigEndian, id); err != nil {
return 0, err return 0, err
} }
@ -389,6 +424,7 @@ func (m *MuxConn) write(id uint32, dataType muxPacketType, p []byte) (int, error
// Stream is a single stream of data and implements io.ReadWriteCloser. // Stream is a single stream of data and implements io.ReadWriteCloser.
// A Stream is full-duplex so you can write data as well as read data. // A Stream is full-duplex so you can write data as well as read data.
type Stream struct { type Stream struct {
from muxPacketFrom
id uint32 id uint32
mux *MuxConn mux *MuxConn
reader io.Reader reader io.Reader
@ -414,6 +450,44 @@ const (
streamStateLastAck streamStateLastAck
) )
func newStream(from muxPacketFrom, id uint32, m *MuxConn) *Stream {
// Create the stream object and channel where data will be sent to
dataR, dataW := io.Pipe()
writeCh := make(chan []byte, 256)
// Set the data channel so we can write to it.
stream := &Stream{
from: from,
id: id,
mux: m,
reader: dataR,
writeCh: writeCh,
stateChange: make(map[chan<- streamState]struct{}),
}
stream.setState(streamStateClosed)
// Start the goroutine that will read from the queue and write
// data out.
go func() {
defer dataW.Close()
for {
data := <-writeCh
if data == nil {
// A nil is a tombstone letting us know we're done
// accepting data.
return
}
if _, err := dataW.Write(data); err != nil {
return
}
}
}()
return stream
}
func (s *Stream) Close() error { func (s *Stream) Close() error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@ -428,7 +502,7 @@ func (s *Stream) Close() error {
s.setState(streamStateLastAck) s.setState(streamStateLastAck)
} }
s.mux.write(s.id, muxPacketFin, nil) s.write(muxPacketFin, nil)
return nil return nil
} }
@ -445,7 +519,7 @@ func (s *Stream) Write(p []byte) (int, error) {
return 0, fmt.Errorf("Stream %d in bad state to send: %d", s.id, state) return 0, fmt.Errorf("Stream %d in bad state to send: %d", s.id, state)
} }
return s.mux.write(s.id, muxPacketData, p) return s.write(muxPacketData, p)
} }
func (s *Stream) closeWriter() { func (s *Stream) closeWriter() {
@ -453,7 +527,7 @@ func (s *Stream) closeWriter() {
} }
func (s *Stream) setState(state streamState) { func (s *Stream) setState(state streamState) {
//log.Printf("[TRACE] Stream %d went to state %d", s.id, state) //log.Printf("[TRACE] %p: Stream %d (%s) went to state %d", s.mux, s.id, s.from, state)
s.state = state s.state = state
s.stateUpdated = time.Now().UTC() s.stateUpdated = time.Now().UTC()
for ch, _ := range s.stateChange { for ch, _ := range s.stateChange {
@ -482,3 +556,7 @@ func (s *Stream) waitState(target streamState) error {
return fmt.Errorf("Stream %d went to bad state: %d", s.id, state) return fmt.Errorf("Stream %d went to bad state: %d", s.id, state)
} }
} }
func (s *Stream) write(dataType muxPacketType, p []byte) (int, error) {
return s.mux.write(s.from, s.id, dataType, p)
}

View File

@ -33,7 +33,7 @@ func testMux(t *testing.T) (client *MuxConn, server *MuxConn) {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
server = NewMuxConn(conn, 1) server = NewMuxConn(conn)
}() }()
// Client side // Client side
@ -41,7 +41,7 @@ func testMux(t *testing.T) (client *MuxConn, server *MuxConn) {
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
client = NewMuxConn(conn, 0) client = NewMuxConn(conn)
// Wait for the server // Wait for the server
<-doneCh <-doneCh
@ -241,14 +241,14 @@ func TestMuxConnNextId(t *testing.T) {
a := client.NextId() a := client.NextId()
b := client.NextId() b := client.NextId()
if a != 0 || b != 2 { if a != 0 || b != 1 {
t.Fatalf("IDs should increment") t.Fatalf("IDs should increment")
} }
a = server.NextId() a = server.NextId()
b = server.NextId() b = server.NextId()
if a != 1 || b != 3 { if a != 0 || b != 1 {
t.Fatalf("IDs should increment: %d %d", a, b) t.Fatalf("IDs should increment: %d %d", a, b)
} }
} }

View File

@ -36,7 +36,7 @@ type Server struct {
// NewServer returns a new Packer RPC server. // NewServer returns a new Packer RPC server.
func NewServer(conn io.ReadWriteCloser) *Server { func NewServer(conn io.ReadWriteCloser) *Server {
result := newServerWithMux(NewMuxConn(conn, 1), 0) result := newServerWithMux(NewMuxConn(conn), 0)
result.closeMux = true result.closeMux = true
return result return result
} }