use proxy reader for download progress & stop storing total/current in downloaders

This commit is contained in:
Adrien Delorme 2018-09-05 12:50:53 +02:00
parent ddd96c513b
commit fd7cb47adc
4 changed files with 90 additions and 95 deletions

View File

@ -85,10 +85,10 @@ func NewDownloadClient(c *DownloadConfig, bar packer.ProgressBar) *DownloadClien
// Create downloader map if it hasn't been specified already.
if c.DownloaderMap == nil {
c.DownloaderMap = map[string]Downloader{
"file": &FileDownloader{progress: bar, bufferSize: nil},
"http": &HTTPDownloader{progress: bar, userAgent: c.UserAgent},
"https": &HTTPDownloader{progress: bar, userAgent: c.UserAgent},
"smb": &SMBDownloader{progress: bar, bufferSize: nil},
"file": &FileDownloader{progressBar: bar, bufferSize: nil},
"http": &HTTPDownloader{progressBar: bar, userAgent: c.UserAgent},
"https": &HTTPDownloader{progressBar: bar, userAgent: c.UserAgent},
"smb": &SMBDownloader{progressBar: bar, bufferSize: nil},
}
}
return &DownloadClient{config: c}
@ -99,8 +99,7 @@ func NewDownloadClient(c *DownloadConfig, bar packer.ProgressBar) *DownloadClien
type Downloader interface {
Resume()
Cancel()
Progress() uint64
Total() uint64
ProgressBar() packer.ProgressBar
}
// A LocalDownloader is responsible for converting a uri to a local path
@ -226,11 +225,9 @@ func (d *DownloadClient) VerifyChecksum(path string) (bool, error) {
// HTTPDownloader is an implementation of Downloader that downloads
// files over HTTP.
type HTTPDownloader struct {
current uint64
total uint64
userAgent string
progress packer.ProgressBar
progressBar packer.ProgressBar
}
func (d *HTTPDownloader) Cancel() {
@ -249,8 +246,7 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
return err
}
// Reset our progress
d.current = 0
var current uint64
// Make the request. We first make a HEAD request so we can check
// if the server supports range queries. If the server/URL doesn't
@ -294,7 +290,7 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
if _, err = dst.Seek(0, os.SEEK_END); err == nil {
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", fi.Size()))
d.current = uint64(fi.Size())
current = uint64(fi.Size())
}
}
}
@ -321,24 +317,21 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
return fmt.Errorf("HTTP error: %s", err.Error())
}
d.total = d.current + uint64(resp.ContentLength)
total := current + uint64(resp.ContentLength)
bar := d.progress
log.Printf("this %#v", bar)
log.Printf("that")
bar.Start(d.total)
bar.Set(d.current)
bar := d.ProgressBar()
bar.Start(total)
bar.Add(current)
body := bar.NewProxyReader(resp.Body)
var buffer [4096]byte
for {
n, err := resp.Body.Read(buffer[:])
n, err := body.Read(buffer[:])
if err != nil && err != io.EOF {
return err
}
d.current += uint64(n)
bar.Set(d.current)
if _, werr := dst.Write(buffer[:n]); werr != nil {
return werr
}
@ -351,32 +344,13 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
return nil
}
func (d *HTTPDownloader) Progress() uint64 {
return d.current
}
func (d *HTTPDownloader) Total() uint64 {
return d.total
}
// FileDownloader is an implementation of Downloader that downloads
// files using the regular filesystem.
type FileDownloader struct {
bufferSize *uint
active bool
current uint64
total uint64
progress packer.ProgressBar
}
func (d *FileDownloader) Progress() uint64 {
return d.current
}
func (d *FileDownloader) Total() uint64 {
return d.total
active bool
progressBar packer.ProgressBar
}
func (d *FileDownloader) Cancel() {
@ -443,7 +417,6 @@ func (d *FileDownloader) Download(dst *os.File, src *url.URL) error {
}
/* download the file using the operating system's facilities */
d.current = 0
d.active = true
f, err := os.Open(realpath)
@ -457,38 +430,31 @@ func (d *FileDownloader) Download(dst *os.File, src *url.URL) error {
if err != nil {
return err
}
d.total = uint64(fi.Size())
bar := d.progress
bar := d.ProgressBar()
bar.Start(d.total)
bar.Set(d.current)
bar.Start(uint64(fi.Size()))
fProxy := bar.NewProxyReader(f)
// no bufferSize specified, so copy synchronously.
if d.bufferSize == nil {
var n int64
n, err = io.Copy(dst, f)
_, err = io.Copy(dst, fProxy)
d.active = false
d.current += uint64(n)
bar.Set(d.current)
// use a goro in case someone else wants to enable cancel/resume
} else {
errch := make(chan error)
go func(d *FileDownloader, r io.Reader, w io.Writer, e chan error) {
for d.active {
n, err := io.CopyN(w, r, int64(*d.bufferSize))
_, err := io.CopyN(w, r, int64(*d.bufferSize))
if err != nil {
break
}
d.current += uint64(n)
bar.Set(d.current)
}
d.active = false
e <- err
}(d, f, dst, errch)
}(d, fProxy, dst, errch)
// ...and we spin until it's done
err = <-errch
@ -503,19 +469,8 @@ func (d *FileDownloader) Download(dst *os.File, src *url.URL) error {
type SMBDownloader struct {
bufferSize *uint
active bool
current uint64
total uint64
progress packer.ProgressBar
}
func (d *SMBDownloader) Progress() uint64 {
return d.current
}
func (d *SMBDownloader) Total() uint64 {
return d.total
active bool
progressBar packer.ProgressBar
}
func (d *SMBDownloader) Cancel() {
@ -564,7 +519,6 @@ func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error {
}
/* Open up the "\\"-prefixed path using the Windows filesystem */
d.current = 0
d.active = true
f, err := os.Open(realpath)
@ -578,37 +532,31 @@ func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error {
if err != nil {
return err
}
d.total = uint64(fi.Size())
bar := d.progress
bar := d.ProgressBar()
bar.Start(d.current)
bar.Start(uint64(fi.Size()))
fProxy := bar.NewProxyReader(f)
// no bufferSize specified, so copy synchronously.
if d.bufferSize == nil {
var n int64
n, err = io.Copy(dst, f)
_, err = io.Copy(dst, fProxy)
d.active = false
d.current += uint64(n)
bar.Set(d.current)
// use a goro in case someone else wants to enable cancel/resume
} else {
errch := make(chan error)
go func(d *SMBDownloader, r io.Reader, w io.Writer, e chan error) {
for d.active {
n, err := io.CopyN(w, r, int64(*d.bufferSize))
_, err := io.CopyN(w, r, int64(*d.bufferSize))
if err != nil {
break
}
d.current += uint64(n)
bar.Set(d.current)
}
d.active = false
e <- err
}(d, f, dst, errch)
}(d, fProxy, dst, errch)
// ...and as usual we spin until it's done
err = <-errch
@ -617,3 +565,7 @@ func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error {
f.Close()
return err
}
func (d *HTTPDownloader) ProgressBar() packer.ProgressBar { return d.progressBar }
func (d *FileDownloader) ProgressBar() packer.ProgressBar { return d.progressBar }
func (d *SMBDownloader) ProgressBar() packer.ProgressBar { return d.progressBar }

View File

@ -1,6 +1,8 @@
package packer
import (
"io"
"github.com/cheggaaa/pb"
)
@ -9,7 +11,8 @@ import (
// No-op When in machine readable mode.
type ProgressBar interface {
Start(total uint64)
Set(current uint64)
Add(current uint64)
NewProxyReader(r io.Reader) (proxy io.Reader)
Finish()
}
@ -22,8 +25,20 @@ func (bpb *BasicProgressBar) Start(total uint64) {
bpb.ProgressBar.Start()
}
func (bpb *BasicProgressBar) Set(current uint64) {
bpb.ProgressBar.Set64(int64(current))
func (bpb *BasicProgressBar) Add(current uint64) {
bpb.ProgressBar.Add64(int64(current))
}
func (bpb *BasicProgressBar) NewProxyReader(r io.Reader) io.Reader {
return &ProxyReader{
Reader: r,
ProgressBar: bpb,
}
}
func (bpb *BasicProgressBar) NewProxyReadCloser(r io.ReadCloser) io.ReadCloser {
return &ProxyReader{
Reader: r,
ProgressBar: bpb,
}
}
var _ ProgressBar = new(BasicProgressBar)
@ -32,8 +47,31 @@ var _ ProgressBar = new(BasicProgressBar)
type NoopProgressBar struct {
}
func (bpb *NoopProgressBar) Start(_ uint64) {}
func (bpb *NoopProgressBar) Set(_ uint64) {}
func (bpb *NoopProgressBar) Finish() {}
func (npb *NoopProgressBar) Start(uint64) {}
func (npb *NoopProgressBar) Add(uint64) {}
func (npb *NoopProgressBar) Finish() {}
func (npb *NoopProgressBar) NewProxyReader(r io.Reader) io.Reader { return r }
func (npb *NoopProgressBar) NewProxyReadCloser(r io.ReadCloser) io.ReadCloser { return r }
var _ ProgressBar = new(NoopProgressBar)
// ProxyReader implements io.ReadCloser but sends
// count of read bytes to progress bar
type ProxyReader struct {
io.Reader
ProgressBar
}
func (r *ProxyReader) Read(p []byte) (n int, err error) {
n, err = r.Reader.Read(p)
r.ProgressBar.Add(uint64(n))
return
}
// Close the reader if it implements io.Closer
func (r *ProxyReader) Close() (err error) {
if closer, ok := r.Reader.(io.Closer); ok {
return closer.Close()
}
return
}

View File

@ -1,6 +1,7 @@
package rpc
import (
"io"
"log"
"math/rand"
"net/rpc"
@ -88,14 +89,18 @@ func (pb *RemoteProgressBarClient) Start(total uint64) {
pb.client.Call(pb.id+".Start", total, new(interface{}))
}
func (pb *RemoteProgressBarClient) Set(current uint64) {
pb.client.Call(pb.id+".Set", current, new(interface{}))
func (pb *RemoteProgressBarClient) Add(current uint64) {
pb.client.Call(pb.id+".Add", current, new(interface{}))
}
func (pb *RemoteProgressBarClient) Finish() {
pb.client.Call(pb.id+".Finish", nil, new(interface{}))
}
func (pb *RemoteProgressBarClient) NewProxyReader(r io.Reader) io.Reader {
return &packer.ProxyReader{Reader: r, ProgressBar: pb}
}
func (u *UiServer) Ask(query string, reply *string) (err error) {
*reply, err = u.ui.Ask(query)
return
@ -128,7 +133,7 @@ func (u *UiServer) Say(message *string, reply *interface{}) error {
return nil
}
func RandStringBytes(n int) string {
func RandStringBytes(n int) string { // TODO(azr): remove before merging
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
b := make([]byte, n)
@ -167,7 +172,7 @@ func (pb *RemoteProgressBarServer) Start(total uint64, _ *interface{}) error {
return nil
}
func (pb *RemoteProgressBarServer) Set(current uint64, _ *interface{}) error {
pb.pb.Set(current)
func (pb *RemoteProgressBarServer) Add(current uint64, _ *interface{}) error {
pb.pb.Add(current)
return nil
}

View File

@ -177,7 +177,7 @@ func (p *Provisioner) ProvisionUpload(ui packer.Ui, comm packer.Communicator) er
// Get a default progress bar
bar := ui.ProgressBar()
bar.Start(0)
bar.Start(uint64(info.Size()))
defer bar.Finish()
// Create ProxyReader for the current progress