packer/rpc: MuxConn implements three-way handshake
This commit is contained in:
parent
311fb2064d
commit
a2f46a989f
|
@ -253,9 +253,10 @@ func (c *CommunicatorServer) Download(args *CommunicatorDownloadArgs, reply *int
|
|||
}
|
||||
|
||||
func serveSingleCopy(name string, mux *MuxConn, id uint32, dst io.Writer, src io.Reader) {
|
||||
log.Printf("[DEBUG] %s: Connecting to stream %d", name, id)
|
||||
conn, err := mux.Accept(id)
|
||||
if err != nil {
|
||||
log.Printf("'%s' accept error: %s", name, err)
|
||||
log.Printf("[ERR] '%s' accept error: %s", name, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -271,8 +272,8 @@ func serveSingleCopy(name string, mux *MuxConn, id uint32, dst io.Writer, src io
|
|||
}
|
||||
|
||||
written, err := io.Copy(dst, src)
|
||||
log.Printf("%d bytes written for '%s'", written, name)
|
||||
log.Printf("[INFO] %d bytes written for '%s'", written, name)
|
||||
if err != nil {
|
||||
log.Printf("'%s' copy error: %s", name, err)
|
||||
log.Printf("[ERR] '%s' copy error: %s", name, err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,6 +33,7 @@ type muxPacketType byte
|
|||
|
||||
const (
|
||||
muxPacketSyn muxPacketType = iota
|
||||
muxPacketSynAck
|
||||
muxPacketAck
|
||||
muxPacketFin
|
||||
muxPacketData
|
||||
|
@ -77,49 +78,27 @@ func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) {
|
|||
|
||||
// If the stream isn't closed, then it is already open somehow
|
||||
stream.mu.Lock()
|
||||
defer stream.mu.Unlock()
|
||||
if stream.state != streamStateSynRecv && stream.state != streamStateClosed {
|
||||
stream.mu.Unlock()
|
||||
return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state)
|
||||
}
|
||||
|
||||
if stream.state == streamStateSynRecv {
|
||||
// Fast track establishing since we already got the syn
|
||||
stream.setState(streamStateEstablished)
|
||||
stream.mu.Unlock()
|
||||
}
|
||||
|
||||
if stream.state != streamStateEstablished {
|
||||
// Go into the listening state
|
||||
if stream.state == streamStateClosed {
|
||||
// Go into the listening state and wait for a syn
|
||||
stream.setState(streamStateListen)
|
||||
|
||||
// Register a state change listener to wait for changes
|
||||
stateCh := make(chan streamState, 10)
|
||||
stream.registerStateListener(stateCh)
|
||||
defer func() {
|
||||
stream.mu.Lock()
|
||||
defer stream.mu.Unlock()
|
||||
stream.deregisterStateListener(stateCh)
|
||||
}()
|
||||
|
||||
stream.mu.Unlock()
|
||||
|
||||
// Wait for the connection to establish
|
||||
ACCEPT_ESTABLISH_LOOP:
|
||||
for {
|
||||
state := <-stateCh
|
||||
switch state {
|
||||
case streamStateListen:
|
||||
case streamStateEstablished:
|
||||
break ACCEPT_ESTABLISH_LOOP
|
||||
default:
|
||||
defer stream.mu.Unlock()
|
||||
return nil, fmt.Errorf("Stream %d went to bad state: %d", id, stream.state)
|
||||
}
|
||||
if err := stream.waitState(streamStateSynRecv); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Send the ack down
|
||||
if _, err := m.write(stream.id, muxPacketAck, nil); err != nil {
|
||||
if stream.state == streamStateSynRecv {
|
||||
// Send a syn-ack
|
||||
if _, err := m.write(stream.id, muxPacketSynAck, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := stream.waitState(streamStateEstablished); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -136,8 +115,8 @@ func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) {
|
|||
|
||||
// If the stream isn't closed, then it is already open somehow
|
||||
stream.mu.Lock()
|
||||
defer stream.mu.Unlock()
|
||||
if stream.state != streamStateClosed {
|
||||
stream.mu.Unlock()
|
||||
return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state)
|
||||
}
|
||||
|
||||
|
@ -147,28 +126,12 @@ func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) {
|
|||
}
|
||||
stream.setState(streamStateSynSent)
|
||||
|
||||
// Register a state change listener to wait for changes
|
||||
stateCh := make(chan streamState, 10)
|
||||
stream.registerStateListener(stateCh)
|
||||
defer func() {
|
||||
stream.mu.Lock()
|
||||
defer stream.mu.Unlock()
|
||||
stream.deregisterStateListener(stateCh)
|
||||
}()
|
||||
|
||||
stream.mu.Unlock()
|
||||
|
||||
for {
|
||||
state := <-stateCh
|
||||
switch state {
|
||||
case streamStateSynSent:
|
||||
case streamStateEstablished:
|
||||
return stream, nil
|
||||
default:
|
||||
defer stream.mu.Unlock()
|
||||
return nil, fmt.Errorf("Stream %d went to bad state: %d", id, stream.state)
|
||||
}
|
||||
if err := stream.waitState(streamStateEstablished); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.write(id, muxPacketAck, nil)
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
// NextId returns the next available stream ID that isn't currently
|
||||
|
@ -247,6 +210,7 @@ func (m *MuxConn) loop() {
|
|||
// Force close every stream that we know about when we exit so
|
||||
// that they all read EOF and don't block forever.
|
||||
defer func() {
|
||||
log.Printf("[INFO] Mux connection loop exiting")
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
for _, w := range m.streams {
|
||||
|
@ -288,12 +252,23 @@ func (m *MuxConn) loop() {
|
|||
return
|
||||
}
|
||||
|
||||
//log.Printf("[DEBUG] Stream %d received packet %d", id, packetType)
|
||||
log.Printf("[TRACE] Stream %d received packet %d", id, packetType)
|
||||
switch packetType {
|
||||
case muxPacketSyn:
|
||||
stream.mu.Lock()
|
||||
switch stream.state {
|
||||
case streamStateClosed:
|
||||
fallthrough
|
||||
case streamStateListen:
|
||||
stream.setState(streamStateSynRecv)
|
||||
default:
|
||||
log.Printf("[ERR] Syn received for stream in state: %d", stream.state)
|
||||
}
|
||||
stream.mu.Unlock()
|
||||
case muxPacketAck:
|
||||
stream.mu.Lock()
|
||||
switch stream.state {
|
||||
case streamStateSynSent:
|
||||
case streamStateSynRecv:
|
||||
stream.setState(streamStateEstablished)
|
||||
case streamStateFinWait1:
|
||||
stream.setState(streamStateFinWait2)
|
||||
|
@ -301,15 +276,13 @@ func (m *MuxConn) loop() {
|
|||
log.Printf("[ERR] Ack received for stream in state: %d", stream.state)
|
||||
}
|
||||
stream.mu.Unlock()
|
||||
case muxPacketSyn:
|
||||
case muxPacketSynAck:
|
||||
stream.mu.Lock()
|
||||
switch stream.state {
|
||||
case streamStateClosed:
|
||||
stream.setState(streamStateSynRecv)
|
||||
case streamStateListen:
|
||||
case streamStateSynSent:
|
||||
stream.setState(streamStateEstablished)
|
||||
default:
|
||||
log.Printf("[ERR] Syn received for stream in state: %d", stream.state)
|
||||
log.Printf("[ERR] SynAck received for stream in state: %d", stream.state)
|
||||
}
|
||||
stream.mu.Unlock()
|
||||
case muxPacketFin:
|
||||
|
@ -451,6 +424,7 @@ func (s *Stream) deregisterStateListener(ch chan<- streamState) {
|
|||
}
|
||||
|
||||
func (s *Stream) setState(state streamState) {
|
||||
log.Printf("[TRACE] Stream %d went to state %d", s.id, state)
|
||||
s.state = state
|
||||
s.stateUpdated = time.Now().UTC()
|
||||
for ch, _ := range s.stateChange {
|
||||
|
@ -460,3 +434,22 @@ func (s *Stream) setState(state streamState) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Stream) waitState(target streamState) error {
|
||||
// Register a state change listener to wait for changes
|
||||
stateCh := make(chan streamState, 10)
|
||||
s.registerStateListener(stateCh)
|
||||
s.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
s.mu.Lock()
|
||||
s.deregisterStateListener(stateCh)
|
||||
}()
|
||||
|
||||
state := <-stateCh
|
||||
if state == target {
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("Stream %d went to bad state: %d", s.id, state)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue