diff --git a/packer/progressbar.go b/packer/progressbar.go index f08928767..e7405b074 100644 --- a/packer/progressbar.go +++ b/packer/progressbar.go @@ -4,7 +4,6 @@ import ( "fmt" "io" "sync" - "sync/atomic" "github.com/cheggaaa/pb" ) @@ -36,16 +35,25 @@ type StackableProgressBar struct { items int32 total int64 - started bool + started bool + ConfigProgressbarFN func(*pb.ProgressBar) } var _ ProgressBar = new(StackableProgressBar) -func (spb *StackableProgressBar) start() { - spb.Bar.ProgressBar = pb.New(0) - spb.Bar.ProgressBar.SetUnits(pb.U_BYTES) +func defaultProgressbarConfigFn(bar *pb.ProgressBar) { + bar.SetUnits(pb.U_BYTES) +} - spb.Bar.ProgressBar.Start() +func (spb *StackableProgressBar) start() { + bar := pb.New(0) + if spb.ConfigProgressbarFN == nil { + spb.ConfigProgressbarFN = defaultProgressbarConfigFn + } + spb.ConfigProgressbarFN(bar) + + bar.Start() + spb.Bar.ProgressBar = bar spb.started = true } @@ -66,7 +74,9 @@ func (spb *StackableProgressBar) Start(total int64) { func (spb *StackableProgressBar) Add(total int64) { spb.mtx.Lock() defer spb.mtx.Unlock() - spb.Bar.Add(total) + if spb.Bar.ProgressBar != nil { + spb.Bar.Add(total) + } } func (spb *StackableProgressBar) NewProxyReader(r io.Reader) io.Reader { @@ -76,15 +86,17 @@ func (spb *StackableProgressBar) NewProxyReader(r io.Reader) io.Reader { } func (spb *StackableProgressBar) prefix() { - spb.Bar.ProgressBar.Prefix(fmt.Sprintf("%d items: ", atomic.LoadInt32(&spb.items))) + spb.Bar.ProgressBar.Prefix(fmt.Sprintf("%d items: ", spb.items)) } func (spb *StackableProgressBar) Finish() { spb.mtx.Lock() defer spb.mtx.Unlock() - spb.items-- - if spb.items == 0 { + if spb.items < 0 { + spb.items-- + } + if spb.items == 0 && spb.Bar.ProgressBar != nil { // slef cleanup spb.Bar.ProgressBar.Finish() spb.Bar.ProgressBar = nil @@ -92,7 +104,6 @@ func (spb *StackableProgressBar) Finish() { spb.total = 0 return } - spb.prefix() } // BasicProgressBar is packer's basic progress bar. diff --git a/packer/progressbar_test.go b/packer/progressbar_test.go new file mode 100644 index 000000000..d19827989 --- /dev/null +++ b/packer/progressbar_test.go @@ -0,0 +1,71 @@ +package packer + +import ( + "sync" + "testing" + "time" + + "github.com/cheggaaa/pb" +) + +func speedyProgressBar(bar *pb.ProgressBar) { + bar.SetUnits(pb.U_BYTES) + bar.SetRefreshRate(1 * time.Millisecond) + bar.NotPrint = true + bar.Format("[\x00=\x00>\x00-\x00]") +} + +func TestStackableProgressBar_race(t *testing.T) { + bar := &StackableProgressBar{ + ConfigProgressbarFN: speedyProgressBar, + } + + start42Fn := func() { bar.Start(42) } + finishFn := func() { bar.Finish() } + add21 := func() { bar.Add(21) } + // prefix := func() { bar.prefix() } + + type fields struct { + pre func() + calls []func() + post func() + } + tests := []struct { + name string + fields fields + iterations int + }{ + {"all public", fields{nil, []func(){start42Fn, finishFn, add21, add21}, finishFn}, 300}, + {"add", fields{start42Fn, []func(){add21}, finishFn}, 300}, + {"add start", fields{start42Fn, []func(){start42Fn, add21}, finishFn}, 300}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + for i := 0; i < tt.iterations; i++ { + if tt.fields.pre != nil { + tt.fields.pre() + } + var startWg, endWg sync.WaitGroup + startWg.Add(1) + endWg.Add(len(tt.fields.calls)) + for _, call := range tt.fields.calls { + call := call + go func() { + defer endWg.Done() + startWg.Wait() + call() + }() + } + startWg.Done() // everyone is initialized, let's unlock everyone at the same time. + // .... + endWg.Wait() // wait for all calls to return. + if tt.fields.post != nil { + tt.fields.post() + } + } + }) + } + +}