Merge pull request #6851 from hashicorp/pb_rc
[WIP] progress bar race conditions
This commit is contained in:
commit
4357714a29
|
@ -4,7 +4,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
|
|
||||||
"github.com/cheggaaa/pb"
|
"github.com/cheggaaa/pb"
|
||||||
)
|
)
|
||||||
|
@ -36,16 +35,25 @@ type StackableProgressBar struct {
|
||||||
items int32
|
items int32
|
||||||
total int64
|
total int64
|
||||||
|
|
||||||
started bool
|
started bool
|
||||||
|
ConfigProgressbarFN func(*pb.ProgressBar)
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ ProgressBar = new(StackableProgressBar)
|
var _ ProgressBar = new(StackableProgressBar)
|
||||||
|
|
||||||
func (spb *StackableProgressBar) start() {
|
func defaultProgressbarConfigFn(bar *pb.ProgressBar) {
|
||||||
spb.Bar.ProgressBar = pb.New(0)
|
bar.SetUnits(pb.U_BYTES)
|
||||||
spb.Bar.ProgressBar.SetUnits(pb.U_BYTES)
|
}
|
||||||
|
|
||||||
spb.Bar.ProgressBar.Start()
|
func (spb *StackableProgressBar) start() {
|
||||||
|
bar := pb.New(0)
|
||||||
|
if spb.ConfigProgressbarFN == nil {
|
||||||
|
spb.ConfigProgressbarFN = defaultProgressbarConfigFn
|
||||||
|
}
|
||||||
|
spb.ConfigProgressbarFN(bar)
|
||||||
|
|
||||||
|
bar.Start()
|
||||||
|
spb.Bar.ProgressBar = bar
|
||||||
spb.started = true
|
spb.started = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,7 +74,9 @@ func (spb *StackableProgressBar) Start(total int64) {
|
||||||
func (spb *StackableProgressBar) Add(total int64) {
|
func (spb *StackableProgressBar) Add(total int64) {
|
||||||
spb.mtx.Lock()
|
spb.mtx.Lock()
|
||||||
defer spb.mtx.Unlock()
|
defer spb.mtx.Unlock()
|
||||||
spb.Bar.Add(total)
|
if spb.Bar.ProgressBar != nil {
|
||||||
|
spb.Bar.Add(total)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (spb *StackableProgressBar) NewProxyReader(r io.Reader) io.Reader {
|
func (spb *StackableProgressBar) NewProxyReader(r io.Reader) io.Reader {
|
||||||
|
@ -76,15 +86,17 @@ func (spb *StackableProgressBar) NewProxyReader(r io.Reader) io.Reader {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (spb *StackableProgressBar) prefix() {
|
func (spb *StackableProgressBar) prefix() {
|
||||||
spb.Bar.ProgressBar.Prefix(fmt.Sprintf("%d items: ", atomic.LoadInt32(&spb.items)))
|
spb.Bar.ProgressBar.Prefix(fmt.Sprintf("%d items: ", spb.items))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (spb *StackableProgressBar) Finish() {
|
func (spb *StackableProgressBar) Finish() {
|
||||||
spb.mtx.Lock()
|
spb.mtx.Lock()
|
||||||
defer spb.mtx.Unlock()
|
defer spb.mtx.Unlock()
|
||||||
|
|
||||||
spb.items--
|
if spb.items < 0 {
|
||||||
if spb.items == 0 {
|
spb.items--
|
||||||
|
}
|
||||||
|
if spb.items == 0 && spb.Bar.ProgressBar != nil {
|
||||||
// slef cleanup
|
// slef cleanup
|
||||||
spb.Bar.ProgressBar.Finish()
|
spb.Bar.ProgressBar.Finish()
|
||||||
spb.Bar.ProgressBar = nil
|
spb.Bar.ProgressBar = nil
|
||||||
|
@ -92,7 +104,6 @@ func (spb *StackableProgressBar) Finish() {
|
||||||
spb.total = 0
|
spb.total = 0
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
spb.prefix()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// BasicProgressBar is packer's basic progress bar.
|
// BasicProgressBar is packer's basic progress bar.
|
||||||
|
|
|
@ -0,0 +1,71 @@
|
||||||
|
package packer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cheggaaa/pb"
|
||||||
|
)
|
||||||
|
|
||||||
|
func speedyProgressBar(bar *pb.ProgressBar) {
|
||||||
|
bar.SetUnits(pb.U_BYTES)
|
||||||
|
bar.SetRefreshRate(1 * time.Millisecond)
|
||||||
|
bar.NotPrint = true
|
||||||
|
bar.Format("[\x00=\x00>\x00-\x00]")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStackableProgressBar_race(t *testing.T) {
|
||||||
|
bar := &StackableProgressBar{
|
||||||
|
ConfigProgressbarFN: speedyProgressBar,
|
||||||
|
}
|
||||||
|
|
||||||
|
start42Fn := func() { bar.Start(42) }
|
||||||
|
finishFn := func() { bar.Finish() }
|
||||||
|
add21 := func() { bar.Add(21) }
|
||||||
|
// prefix := func() { bar.prefix() }
|
||||||
|
|
||||||
|
type fields struct {
|
||||||
|
pre func()
|
||||||
|
calls []func()
|
||||||
|
post func()
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
iterations int
|
||||||
|
}{
|
||||||
|
{"all public", fields{nil, []func(){start42Fn, finishFn, add21, add21}, finishFn}, 300},
|
||||||
|
{"add", fields{start42Fn, []func(){add21}, finishFn}, 300},
|
||||||
|
{"add start", fields{start42Fn, []func(){start42Fn, add21}, finishFn}, 300},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
for i := 0; i < tt.iterations; i++ {
|
||||||
|
if tt.fields.pre != nil {
|
||||||
|
tt.fields.pre()
|
||||||
|
}
|
||||||
|
var startWg, endWg sync.WaitGroup
|
||||||
|
startWg.Add(1)
|
||||||
|
endWg.Add(len(tt.fields.calls))
|
||||||
|
for _, call := range tt.fields.calls {
|
||||||
|
call := call
|
||||||
|
go func() {
|
||||||
|
defer endWg.Done()
|
||||||
|
startWg.Wait()
|
||||||
|
call()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
startWg.Done() // everyone is initialized, let's unlock everyone at the same time.
|
||||||
|
// ....
|
||||||
|
endWg.Wait() // wait for all calls to return.
|
||||||
|
if tt.fields.post != nil {
|
||||||
|
tt.fields.post()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue