packer/rpc: accept/dial stream IDs are unique [GH-727]
This commit is contained in:
parent
629f3eee21
commit
edbdee5dee
@ -17,7 +17,7 @@ type Client struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(rwc io.ReadWriteCloser) (*Client, error) {
|
func NewClient(rwc io.ReadWriteCloser) (*Client, error) {
|
||||||
result, err := newClientWithMux(NewMuxConn(rwc, 0), 0)
|
result, err := newClientWithMux(NewMuxConn(rwc), 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -164,7 +164,7 @@ func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if args.StdinStreamId > 0 {
|
if args.StdinStreamId >= 0 {
|
||||||
conn, err := c.mux.Dial(args.StdinStreamId)
|
conn, err := c.mux.Dial(args.StdinStreamId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
close(doneCh)
|
close(doneCh)
|
||||||
@ -175,7 +175,7 @@ func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface
|
|||||||
cmd.Stdin = conn
|
cmd.Stdin = conn
|
||||||
}
|
}
|
||||||
|
|
||||||
if args.StdoutStreamId > 0 {
|
if args.StdoutStreamId >= 0 {
|
||||||
conn, err := c.mux.Dial(args.StdoutStreamId)
|
conn, err := c.mux.Dial(args.StdoutStreamId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
close(doneCh)
|
close(doneCh)
|
||||||
@ -186,7 +186,7 @@ func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface
|
|||||||
cmd.Stdout = conn
|
cmd.Stdout = conn
|
||||||
}
|
}
|
||||||
|
|
||||||
if args.StderrStreamId > 0 {
|
if args.StderrStreamId >= 0 {
|
||||||
conn, err := c.mux.Dial(args.StderrStreamId)
|
conn, err := c.mux.Dial(args.StderrStreamId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
close(doneCh)
|
close(doneCh)
|
||||||
@ -253,7 +253,6 @@ func (c *CommunicatorServer) Download(args *CommunicatorDownloadArgs, reply *int
|
|||||||
}
|
}
|
||||||
|
|
||||||
func serveSingleCopy(name string, mux *MuxConn, id uint32, dst io.Writer, src io.Reader) {
|
func serveSingleCopy(name string, mux *MuxConn, id uint32, dst io.Writer, src io.Reader) {
|
||||||
log.Printf("[DEBUG] %s: Connecting to stream %d", name, id)
|
|
||||||
conn, err := mux.Accept(id)
|
conn, err := mux.Accept(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[ERR] '%s' accept error: %s", name, err)
|
log.Printf("[ERR] '%s' accept error: %s", name, err)
|
||||||
|
@ -22,16 +22,25 @@ import (
|
|||||||
// are established using a subset of the TCP protocol. Only a subset is
|
// are established using a subset of the TCP protocol. Only a subset is
|
||||||
// necessary since we assume ordering on the underlying RWC.
|
// necessary since we assume ordering on the underlying RWC.
|
||||||
type MuxConn struct {
|
type MuxConn struct {
|
||||||
curId uint32
|
curId uint32
|
||||||
rwc io.ReadWriteCloser
|
rwc io.ReadWriteCloser
|
||||||
streams map[uint32]*Stream
|
streamsAccept map[uint32]*Stream
|
||||||
mu sync.RWMutex
|
streamsDial map[uint32]*Stream
|
||||||
wlock sync.Mutex
|
mu sync.RWMutex
|
||||||
doneCh chan struct{}
|
muAccept sync.RWMutex
|
||||||
|
muDial sync.RWMutex
|
||||||
|
wlock sync.Mutex
|
||||||
|
doneCh chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type muxPacketFrom byte
|
||||||
type muxPacketType byte
|
type muxPacketType byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
muxPacketFromAccept muxPacketFrom = iota
|
||||||
|
muxPacketFromDial
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
muxPacketSyn muxPacketType = iota
|
muxPacketSyn muxPacketType = iota
|
||||||
muxPacketSynAck
|
muxPacketSynAck
|
||||||
@ -40,13 +49,24 @@ const (
|
|||||||
muxPacketData
|
muxPacketData
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (f muxPacketFrom) String() string {
|
||||||
|
switch f {
|
||||||
|
case muxPacketFromAccept:
|
||||||
|
return "accept"
|
||||||
|
case muxPacketFromDial:
|
||||||
|
return "dial"
|
||||||
|
default:
|
||||||
|
panic("unknown from type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Create a new MuxConn around any io.ReadWriteCloser.
|
// Create a new MuxConn around any io.ReadWriteCloser.
|
||||||
func NewMuxConn(rwc io.ReadWriteCloser, startId uint32) *MuxConn {
|
func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn {
|
||||||
m := &MuxConn{
|
m := &MuxConn{
|
||||||
rwc: rwc,
|
rwc: rwc,
|
||||||
streams: make(map[uint32]*Stream),
|
streamsAccept: make(map[uint32]*Stream),
|
||||||
doneCh: make(chan struct{}),
|
streamsDial: make(map[uint32]*Stream),
|
||||||
curId: startId,
|
doneCh: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
go m.cleaner()
|
go m.cleaner()
|
||||||
@ -62,10 +82,14 @@ func (m *MuxConn) Close() error {
|
|||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
// Close all the streams
|
// Close all the streams
|
||||||
for _, w := range m.streams {
|
for _, w := range m.streamsAccept {
|
||||||
w.Close()
|
w.Close()
|
||||||
}
|
}
|
||||||
m.streams = make(map[uint32]*Stream)
|
for _, w := range m.streamsDial {
|
||||||
|
w.Close()
|
||||||
|
}
|
||||||
|
m.streamsAccept = make(map[uint32]*Stream)
|
||||||
|
m.streamsDial = make(map[uint32]*Stream)
|
||||||
|
|
||||||
// Close the actual connection. This will also force the loop
|
// Close the actual connection. This will also force the loop
|
||||||
// to end since it'll read EOF or closed connection.
|
// to end since it'll read EOF or closed connection.
|
||||||
@ -75,14 +99,22 @@ func (m *MuxConn) Close() error {
|
|||||||
// Accept accepts a multiplexed connection with the given ID. This
|
// Accept accepts a multiplexed connection with the given ID. This
|
||||||
// will block until a request is made to connect.
|
// will block until a request is made to connect.
|
||||||
func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) {
|
func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) {
|
||||||
stream, err := m.openStream(id)
|
//log.Printf("[TRACE] %p: Accept on stream ID: %d", m, id)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
// Get the stream. It is okay if it is already in the list of streams
|
||||||
}
|
// because we may have prematurely received a syn for it.
|
||||||
|
m.muAccept.Lock()
|
||||||
|
stream, ok := m.streamsAccept[id]
|
||||||
|
if !ok {
|
||||||
|
stream = newStream(muxPacketFromAccept, id, m)
|
||||||
|
m.streamsAccept[id] = stream
|
||||||
|
}
|
||||||
|
m.muAccept.Unlock()
|
||||||
|
|
||||||
// If the stream isn't closed, then it is already open somehow
|
|
||||||
stream.mu.Lock()
|
stream.mu.Lock()
|
||||||
defer stream.mu.Unlock()
|
defer stream.mu.Unlock()
|
||||||
|
|
||||||
|
// If the stream isn't closed, then it is already open somehow
|
||||||
if stream.state != streamStateSynRecv && stream.state != streamStateClosed {
|
if stream.state != streamStateSynRecv && stream.state != streamStateClosed {
|
||||||
return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state)
|
return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state)
|
||||||
}
|
}
|
||||||
@ -97,7 +129,7 @@ func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) {
|
|||||||
|
|
||||||
if stream.state == streamStateSynRecv {
|
if stream.state == streamStateSynRecv {
|
||||||
// Send a syn-ack
|
// Send a syn-ack
|
||||||
if _, err := m.write(stream.id, muxPacketSynAck, nil); err != nil {
|
if _, err := stream.write(muxPacketSynAck, nil); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -112,110 +144,68 @@ func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) {
|
|||||||
// Dial opens a connection to the remote end using the given stream ID.
|
// Dial opens a connection to the remote end using the given stream ID.
|
||||||
// An Accept on the remote end will only work with if the IDs match.
|
// An Accept on the remote end will only work with if the IDs match.
|
||||||
func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) {
|
func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) {
|
||||||
stream, err := m.openStream(id)
|
m.muDial.Lock()
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
// If we have any streams with this ID, then it is a failure. The
|
||||||
|
// reaper should clear out old streams once in awhile.
|
||||||
|
if stream, ok := m.streamsDial[id]; ok {
|
||||||
|
m.muDial.Unlock()
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"Stream %d already open for dial. State: %d", stream.state)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the stream isn't closed, then it is already open somehow
|
// Create the new stream and put it in our list. We can then
|
||||||
|
// unlock because dialing will no longer be allowed on that ID.
|
||||||
|
stream := newStream(muxPacketFromDial, id, m)
|
||||||
|
m.streamsDial[id] = stream
|
||||||
|
|
||||||
|
// Don't let anyone else mess with this stream
|
||||||
stream.mu.Lock()
|
stream.mu.Lock()
|
||||||
defer stream.mu.Unlock()
|
defer stream.mu.Unlock()
|
||||||
if stream.state != streamStateClosed {
|
|
||||||
return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state)
|
m.muDial.Unlock()
|
||||||
}
|
|
||||||
|
|
||||||
// Open a connection
|
// Open a connection
|
||||||
if _, err := m.write(stream.id, muxPacketSyn, nil); err != nil {
|
if _, err := stream.write(muxPacketSyn, nil); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// It is safe to set the state after the write above because
|
||||||
|
// we hold the stream lock.
|
||||||
stream.setState(streamStateSynSent)
|
stream.setState(streamStateSynSent)
|
||||||
|
|
||||||
if err := stream.waitState(streamStateEstablished); err != nil {
|
if err := stream.waitState(streamStateEstablished); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.write(id, muxPacketAck, nil)
|
stream.write(muxPacketAck, nil)
|
||||||
return stream, nil
|
return stream, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NextId returns the next available stream ID that isn't currently
|
// NextId returns the next available listen stream ID that isn't currently
|
||||||
// taken.
|
// taken.
|
||||||
func (m *MuxConn) NextId() uint32 {
|
func (m *MuxConn) NextId() uint32 {
|
||||||
m.mu.Lock()
|
m.muAccept.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.muAccept.Unlock()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
result := m.curId
|
result := m.curId
|
||||||
m.curId += 2
|
m.curId += 1
|
||||||
if _, ok := m.streams[result]; !ok {
|
if _, ok := m.streamsAccept[result]; !ok {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MuxConn) openStream(id uint32) (*Stream, error) {
|
|
||||||
// First grab a read-lock if we have the stream already we can
|
|
||||||
// cheaply return it.
|
|
||||||
m.mu.RLock()
|
|
||||||
if stream, ok := m.streams[id]; ok {
|
|
||||||
m.mu.RUnlock()
|
|
||||||
return stream, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now acquire a full blown write lock so we can create the stream
|
|
||||||
m.mu.RUnlock()
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
|
|
||||||
// Make sure we attempt to use the next biggest stream ID
|
|
||||||
if id >= m.curId {
|
|
||||||
m.curId = id + 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// We have to check this again because there is a time period
|
|
||||||
// above where we couldn't lost this lock.
|
|
||||||
if stream, ok := m.streams[id]; ok {
|
|
||||||
return stream, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create the stream object and channel where data will be sent to
|
|
||||||
dataR, dataW := io.Pipe()
|
|
||||||
writeCh := make(chan []byte, 256)
|
|
||||||
|
|
||||||
// Set the data channel so we can write to it.
|
|
||||||
stream := &Stream{
|
|
||||||
id: id,
|
|
||||||
mux: m,
|
|
||||||
reader: dataR,
|
|
||||||
writeCh: writeCh,
|
|
||||||
stateChange: make(map[chan<- streamState]struct{}),
|
|
||||||
}
|
|
||||||
stream.setState(streamStateClosed)
|
|
||||||
|
|
||||||
// Start the goroutine that will read from the queue and write
|
|
||||||
// data out.
|
|
||||||
go func() {
|
|
||||||
defer dataW.Close()
|
|
||||||
|
|
||||||
for {
|
|
||||||
data := <-writeCh
|
|
||||||
if data == nil {
|
|
||||||
// A nil is a tombstone letting us know we're done
|
|
||||||
// accepting data.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := dataW.Write(data); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
m.streams[id] = stream
|
|
||||||
return m.streams[id], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MuxConn) cleaner() {
|
func (m *MuxConn) cleaner() {
|
||||||
|
checks := []struct {
|
||||||
|
Map map[uint32]*Stream
|
||||||
|
Lock *sync.RWMutex
|
||||||
|
}{
|
||||||
|
{m.streamsAccept, &m.muAccept},
|
||||||
|
{m.streamsDial, &m.muDial},
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
done := false
|
done := false
|
||||||
select {
|
select {
|
||||||
@ -224,23 +214,28 @@ func (m *MuxConn) cleaner() {
|
|||||||
done = true
|
done = true
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mu.Lock()
|
for _, check := range checks {
|
||||||
for id, s := range m.streams {
|
check.Lock.Lock()
|
||||||
s.mu.Lock()
|
for id, s := range check.Map {
|
||||||
if s.state == streamStateClosed {
|
|
||||||
delete(m.streams, id)
|
|
||||||
}
|
|
||||||
s.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
if done {
|
|
||||||
for _, s := range m.streams {
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.closeWriter()
|
|
||||||
|
if done && s.state != streamStateClosed {
|
||||||
|
s.closeWriter()
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.state == streamStateClosed {
|
||||||
|
// Only clean up the streams that have been closed
|
||||||
|
// for a certain amount of time.
|
||||||
|
since := time.Now().UTC().Sub(s.stateUpdated)
|
||||||
|
if since > 2*time.Second {
|
||||||
|
delete(check.Map, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
check.Lock.Unlock()
|
||||||
}
|
}
|
||||||
m.mu.Unlock()
|
|
||||||
|
|
||||||
if done {
|
if done {
|
||||||
return
|
return
|
||||||
@ -256,10 +251,15 @@ func (m *MuxConn) loop() {
|
|||||||
close(m.doneCh)
|
close(m.doneCh)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
var from muxPacketFrom
|
||||||
var id uint32
|
var id uint32
|
||||||
var packetType muxPacketType
|
var packetType muxPacketType
|
||||||
var length int32
|
var length int32
|
||||||
for {
|
for {
|
||||||
|
if err := binary.Read(m.rwc, binary.BigEndian, &from); err != nil {
|
||||||
|
log.Printf("[ERR] Error reading stream direction: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
if err := binary.Read(m.rwc, binary.BigEndian, &id); err != nil {
|
if err := binary.Read(m.rwc, binary.BigEndian, &id); err != nil {
|
||||||
log.Printf("[ERR] Error reading stream ID: %s", err)
|
log.Printf("[ERR] Error reading stream ID: %s", err)
|
||||||
return
|
return
|
||||||
@ -282,15 +282,47 @@ func (m *MuxConn) loop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
stream, err := m.openStream(id)
|
// Get the proper stream. Note that the map we look into is
|
||||||
if err != nil {
|
// opposite the "from" because if the dial side is talking to
|
||||||
log.Printf("[ERR] Error opening stream %d: %s", id, err)
|
// us, we need to look into the accept map, and so on.
|
||||||
return
|
//
|
||||||
|
// Note: we also switch the "from" value so that logging
|
||||||
|
// below is correct.
|
||||||
|
var stream *Stream
|
||||||
|
switch from {
|
||||||
|
case muxPacketFromDial:
|
||||||
|
m.muAccept.Lock()
|
||||||
|
stream = m.streamsAccept[id]
|
||||||
|
m.muAccept.Unlock()
|
||||||
|
|
||||||
|
from = muxPacketFromAccept
|
||||||
|
case muxPacketFromAccept:
|
||||||
|
m.muDial.Lock()
|
||||||
|
stream = m.streamsDial[id]
|
||||||
|
m.muDial.Unlock()
|
||||||
|
|
||||||
|
from = muxPacketFromDial
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("Unknown stream direction: %d", from))
|
||||||
}
|
}
|
||||||
|
|
||||||
//log.Printf("[TRACE] Stream %d received packet %d", id, packetType)
|
//log.Printf("[TRACE] %p: Stream %d (%s) received packet %d", m, id, from, packetType)
|
||||||
switch packetType {
|
switch packetType {
|
||||||
case muxPacketSyn:
|
case muxPacketSyn:
|
||||||
|
// If the stream is nil, this is the only case where we'll
|
||||||
|
// automatically create the stream struct.
|
||||||
|
if stream == nil {
|
||||||
|
var ok bool
|
||||||
|
|
||||||
|
m.muAccept.Lock()
|
||||||
|
stream, ok = m.streamsAccept[id]
|
||||||
|
if !ok {
|
||||||
|
stream = newStream(muxPacketFromAccept, id, m)
|
||||||
|
m.streamsAccept[id] = stream
|
||||||
|
}
|
||||||
|
m.muAccept.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
stream.mu.Lock()
|
stream.mu.Lock()
|
||||||
switch stream.state {
|
switch stream.state {
|
||||||
case streamStateClosed:
|
case streamStateClosed:
|
||||||
@ -332,15 +364,15 @@ func (m *MuxConn) loop() {
|
|||||||
case streamStateEstablished:
|
case streamStateEstablished:
|
||||||
stream.closeWriter()
|
stream.closeWriter()
|
||||||
stream.setState(streamStateCloseWait)
|
stream.setState(streamStateCloseWait)
|
||||||
m.write(id, muxPacketAck, nil)
|
stream.write(muxPacketAck, nil)
|
||||||
case streamStateFinWait2:
|
case streamStateFinWait2:
|
||||||
stream.closeWriter()
|
stream.closeWriter()
|
||||||
stream.setState(streamStateClosed)
|
stream.setState(streamStateClosed)
|
||||||
m.write(id, muxPacketAck, nil)
|
stream.write(muxPacketAck, nil)
|
||||||
case streamStateFinWait1:
|
case streamStateFinWait1:
|
||||||
stream.closeWriter()
|
stream.closeWriter()
|
||||||
stream.setState(streamStateClosing)
|
stream.setState(streamStateClosing)
|
||||||
m.write(id, muxPacketAck, nil)
|
stream.write(muxPacketAck, nil)
|
||||||
default:
|
default:
|
||||||
log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state)
|
log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state)
|
||||||
}
|
}
|
||||||
@ -367,10 +399,13 @@ func (m *MuxConn) loop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MuxConn) write(id uint32, dataType muxPacketType, p []byte) (int, error) {
|
func (m *MuxConn) write(from muxPacketFrom, id uint32, dataType muxPacketType, p []byte) (int, error) {
|
||||||
m.wlock.Lock()
|
m.wlock.Lock()
|
||||||
defer m.wlock.Unlock()
|
defer m.wlock.Unlock()
|
||||||
|
|
||||||
|
if err := binary.Write(m.rwc, binary.BigEndian, from); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
if err := binary.Write(m.rwc, binary.BigEndian, id); err != nil {
|
if err := binary.Write(m.rwc, binary.BigEndian, id); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -389,6 +424,7 @@ func (m *MuxConn) write(id uint32, dataType muxPacketType, p []byte) (int, error
|
|||||||
// Stream is a single stream of data and implements io.ReadWriteCloser.
|
// Stream is a single stream of data and implements io.ReadWriteCloser.
|
||||||
// A Stream is full-duplex so you can write data as well as read data.
|
// A Stream is full-duplex so you can write data as well as read data.
|
||||||
type Stream struct {
|
type Stream struct {
|
||||||
|
from muxPacketFrom
|
||||||
id uint32
|
id uint32
|
||||||
mux *MuxConn
|
mux *MuxConn
|
||||||
reader io.Reader
|
reader io.Reader
|
||||||
@ -414,6 +450,44 @@ const (
|
|||||||
streamStateLastAck
|
streamStateLastAck
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func newStream(from muxPacketFrom, id uint32, m *MuxConn) *Stream {
|
||||||
|
// Create the stream object and channel where data will be sent to
|
||||||
|
dataR, dataW := io.Pipe()
|
||||||
|
writeCh := make(chan []byte, 256)
|
||||||
|
|
||||||
|
// Set the data channel so we can write to it.
|
||||||
|
stream := &Stream{
|
||||||
|
from: from,
|
||||||
|
id: id,
|
||||||
|
mux: m,
|
||||||
|
reader: dataR,
|
||||||
|
writeCh: writeCh,
|
||||||
|
stateChange: make(map[chan<- streamState]struct{}),
|
||||||
|
}
|
||||||
|
stream.setState(streamStateClosed)
|
||||||
|
|
||||||
|
// Start the goroutine that will read from the queue and write
|
||||||
|
// data out.
|
||||||
|
go func() {
|
||||||
|
defer dataW.Close()
|
||||||
|
|
||||||
|
for {
|
||||||
|
data := <-writeCh
|
||||||
|
if data == nil {
|
||||||
|
// A nil is a tombstone letting us know we're done
|
||||||
|
// accepting data.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := dataW.Write(data); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return stream
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Stream) Close() error {
|
func (s *Stream) Close() error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
@ -428,7 +502,7 @@ func (s *Stream) Close() error {
|
|||||||
s.setState(streamStateLastAck)
|
s.setState(streamStateLastAck)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.mux.write(s.id, muxPacketFin, nil)
|
s.write(muxPacketFin, nil)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -445,7 +519,7 @@ func (s *Stream) Write(p []byte) (int, error) {
|
|||||||
return 0, fmt.Errorf("Stream %d in bad state to send: %d", s.id, state)
|
return 0, fmt.Errorf("Stream %d in bad state to send: %d", s.id, state)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.mux.write(s.id, muxPacketData, p)
|
return s.write(muxPacketData, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Stream) closeWriter() {
|
func (s *Stream) closeWriter() {
|
||||||
@ -453,7 +527,7 @@ func (s *Stream) closeWriter() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Stream) setState(state streamState) {
|
func (s *Stream) setState(state streamState) {
|
||||||
//log.Printf("[TRACE] Stream %d went to state %d", s.id, state)
|
//log.Printf("[TRACE] %p: Stream %d (%s) went to state %d", s.mux, s.id, s.from, state)
|
||||||
s.state = state
|
s.state = state
|
||||||
s.stateUpdated = time.Now().UTC()
|
s.stateUpdated = time.Now().UTC()
|
||||||
for ch, _ := range s.stateChange {
|
for ch, _ := range s.stateChange {
|
||||||
@ -482,3 +556,7 @@ func (s *Stream) waitState(target streamState) error {
|
|||||||
return fmt.Errorf("Stream %d went to bad state: %d", s.id, state)
|
return fmt.Errorf("Stream %d went to bad state: %d", s.id, state)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Stream) write(dataType muxPacketType, p []byte) (int, error) {
|
||||||
|
return s.mux.write(s.from, s.id, dataType, p)
|
||||||
|
}
|
||||||
|
@ -33,7 +33,7 @@ func testMux(t *testing.T) (client *MuxConn, server *MuxConn) {
|
|||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
server = NewMuxConn(conn, 1)
|
server = NewMuxConn(conn)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Client side
|
// Client side
|
||||||
@ -41,7 +41,7 @@ func testMux(t *testing.T) (client *MuxConn, server *MuxConn) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
client = NewMuxConn(conn, 0)
|
client = NewMuxConn(conn)
|
||||||
|
|
||||||
// Wait for the server
|
// Wait for the server
|
||||||
<-doneCh
|
<-doneCh
|
||||||
@ -241,14 +241,14 @@ func TestMuxConnNextId(t *testing.T) {
|
|||||||
a := client.NextId()
|
a := client.NextId()
|
||||||
b := client.NextId()
|
b := client.NextId()
|
||||||
|
|
||||||
if a != 0 || b != 2 {
|
if a != 0 || b != 1 {
|
||||||
t.Fatalf("IDs should increment")
|
t.Fatalf("IDs should increment")
|
||||||
}
|
}
|
||||||
|
|
||||||
a = server.NextId()
|
a = server.NextId()
|
||||||
b = server.NextId()
|
b = server.NextId()
|
||||||
|
|
||||||
if a != 1 || b != 3 {
|
if a != 0 || b != 1 {
|
||||||
t.Fatalf("IDs should increment: %d %d", a, b)
|
t.Fatalf("IDs should increment: %d %d", a, b)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -36,7 +36,7 @@ type Server struct {
|
|||||||
|
|
||||||
// NewServer returns a new Packer RPC server.
|
// NewServer returns a new Packer RPC server.
|
||||||
func NewServer(conn io.ReadWriteCloser) *Server {
|
func NewServer(conn io.ReadWriteCloser) *Server {
|
||||||
result := newServerWithMux(NewMuxConn(conn, 1), 0)
|
result := newServerWithMux(NewMuxConn(conn), 0)
|
||||||
result.closeMux = true
|
result.closeMux = true
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user