From ddd96c513b9a3f86e2bd16b61dbda5d610624c38 Mon Sep 17 00:00:00 2001 From: Adrien Delorme Date: Tue, 4 Sep 2018 17:57:21 +0200 Subject: [PATCH] first draft at self refreshing loading bar centralized/controlled by Ui --- common/config.go | 6 ++- common/download.go | 57 +++++++++------------ common/download_test.go | 24 +++++---- common/progress.go | 7 +-- common/step_download.go | 5 +- packer/progressbar.go | 39 ++++++++++++++ packer/rpc/server.go | 3 +- packer/rpc/ui.go | 82 +++++++++++++++++++++++++++++- packer/ui.go | 29 +++++++++++ provisioner/ansible/provisioner.go | 5 ++ provisioner/file/provisioner.go | 12 ++--- 11 files changed, 211 insertions(+), 58 deletions(-) create mode 100644 packer/progressbar.go diff --git a/common/config.go b/common/config.go index 24c6d4233..9527af55b 100644 --- a/common/config.go +++ b/common/config.go @@ -9,6 +9,8 @@ import ( "runtime" "strings" "time" + + "github.com/hashicorp/packer/packer" ) // PackerKeyEnv is used to specify the key interval (delay) between keystrokes @@ -41,7 +43,7 @@ func SupportedProtocol(u *url.URL) bool { // build a dummy NewDownloadClient since this is the only place that valid // protocols are actually exposed. - cli := NewDownloadClient(&DownloadConfig{}, nil) + cli := NewDownloadClient(&DownloadConfig{}, new(packer.NoopProgressBar)) // Iterate through each downloader to see if a protocol was found. ok := false @@ -173,7 +175,7 @@ func FileExistsLocally(original string) bool { // First create a dummy downloader so we can figure out which // protocol to use. - cli := NewDownloadClient(&DownloadConfig{}, nil) + cli := NewDownloadClient(&DownloadConfig{}, new(packer.NoopProgressBar)) d, ok := cli.config.DownloaderMap[u.Scheme] if !ok { return false diff --git a/common/download.go b/common/download.go index f0e88420b..21ff25165 100644 --- a/common/download.go +++ b/common/download.go @@ -10,19 +10,17 @@ import ( "errors" "fmt" "hash" + "io" "log" + "net/http" "net/url" "os" "path" + "path/filepath" "runtime" "strings" -) -// imports related to each Downloader implementation -import ( - "io" - "net/http" - "path/filepath" + "github.com/hashicorp/packer/packer" // imports related to each Downloader implementation ) // DownloadConfig is the configuration given to instantiate a new @@ -81,14 +79,9 @@ func HashForType(t string) hash.Hash { // NewDownloadClient returns a new DownloadClient for the given // configuration. -func NewDownloadClient(c *DownloadConfig, bar ProgressBar) *DownloadClient { +func NewDownloadClient(c *DownloadConfig, bar packer.ProgressBar) *DownloadClient { const mtu = 1500 /* ethernet */ - 20 /* ipv4 */ - 20 /* tcp */ - // If bar is nil, then use a dummy progress bar that doesn't do anything - if bar == nil { - bar = GetDummyProgressBar() - } - // Create downloader map if it hasn't been specified already. if c.DownloaderMap == nil { c.DownloaderMap = map[string]Downloader{ @@ -237,7 +230,7 @@ type HTTPDownloader struct { total uint64 userAgent string - progress ProgressBar + progress packer.ProgressBar } func (d *HTTPDownloader) Cancel() { @@ -331,9 +324,10 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { d.total = d.current + uint64(resp.ContentLength) bar := d.progress - bar.SetTotal64(int64(d.total)) - progressBar := bar.Start() - progressBar.Set64(int64(d.current)) + log.Printf("this %#v", bar) + log.Printf("that") + bar.Start(d.total) + bar.Set(d.current) var buffer [4096]byte for { @@ -343,7 +337,7 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { } d.current += uint64(n) - progressBar.Set64(int64(d.current)) + bar.Set(d.current) if _, werr := dst.Write(buffer[:n]); werr != nil { return werr @@ -353,7 +347,7 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { break } } - progressBar.Finish() + bar.Finish() return nil } @@ -374,7 +368,7 @@ type FileDownloader struct { current uint64 total uint64 - progress ProgressBar + progress packer.ProgressBar } func (d *FileDownloader) Progress() uint64 { @@ -466,9 +460,9 @@ func (d *FileDownloader) Download(dst *os.File, src *url.URL) error { d.total = uint64(fi.Size()) bar := d.progress - bar.SetTotal64(int64(d.total)) - progressBar := bar.Start() - progressBar.Set64(int64(d.current)) + + bar.Start(d.total) + bar.Set(d.current) // no bufferSize specified, so copy synchronously. if d.bufferSize == nil { @@ -477,7 +471,7 @@ func (d *FileDownloader) Download(dst *os.File, src *url.URL) error { d.active = false d.current += uint64(n) - progressBar.Set64(int64(d.current)) + bar.Set(d.current) // use a goro in case someone else wants to enable cancel/resume } else { @@ -490,7 +484,7 @@ func (d *FileDownloader) Download(dst *os.File, src *url.URL) error { } d.current += uint64(n) - progressBar.Set64(int64(d.current)) + bar.Set(d.current) } d.active = false e <- err @@ -499,7 +493,7 @@ func (d *FileDownloader) Download(dst *os.File, src *url.URL) error { // ...and we spin until it's done err = <-errch } - progressBar.Finish() + bar.Finish() f.Close() return err } @@ -513,7 +507,7 @@ type SMBDownloader struct { current uint64 total uint64 - progress ProgressBar + progress packer.ProgressBar } func (d *SMBDownloader) Progress() uint64 { @@ -587,9 +581,8 @@ func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error { d.total = uint64(fi.Size()) bar := d.progress - bar.SetTotal64(int64(d.total)) - progressBar := bar.Start() - progressBar.Set64(int64(d.current)) + + bar.Start(d.current) // no bufferSize specified, so copy synchronously. if d.bufferSize == nil { @@ -598,7 +591,7 @@ func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error { d.active = false d.current += uint64(n) - progressBar.Set64(int64(d.current)) + bar.Set(d.current) // use a goro in case someone else wants to enable cancel/resume } else { @@ -611,7 +604,7 @@ func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error { } d.current += uint64(n) - progressBar.Set64(int64(d.current)) + bar.Set(d.current) } d.active = false e <- err @@ -620,7 +613,7 @@ func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error { // ...and as usual we spin until it's done err = <-errch } - progressBar.Finish() + bar.Finish() f.Close() return err } diff --git a/common/download_test.go b/common/download_test.go index 153e8daf2..ea6d7585e 100644 --- a/common/download_test.go +++ b/common/download_test.go @@ -12,6 +12,8 @@ import ( "runtime" "strings" "testing" + + "github.com/hashicorp/packer/packer" ) func TestDownloadClientVerifyChecksum(t *testing.T) { @@ -36,7 +38,7 @@ func TestDownloadClientVerifyChecksum(t *testing.T) { Checksum: checksum, } - d := NewDownloadClient(config, nil) + d := NewDownloadClient(config, new(packer.NoopProgressBar)) result, err := d.VerifyChecksum(tf.Name()) if err != nil { t.Fatalf("Verify err: %s", err) @@ -59,7 +61,7 @@ func TestDownloadClient_basic(t *testing.T) { Url: ts.URL + "/basic.txt", TargetPath: tf.Name(), CopyFile: true, - }, nil) + }, new(packer.NoopProgressBar)) path, err := client.Get() if err != nil { @@ -95,7 +97,7 @@ func TestDownloadClient_checksumBad(t *testing.T) { Hash: HashForType("md5"), Checksum: checksum, CopyFile: true, - }, nil) + }, new(packer.NoopProgressBar)) if _, err := client.Get(); err == nil { t.Fatal("should error") @@ -121,7 +123,7 @@ func TestDownloadClient_checksumGood(t *testing.T) { Hash: HashForType("md5"), Checksum: checksum, CopyFile: true, - }, nil) + }, new(packer.NoopProgressBar)) path, err := client.Get() if err != nil { @@ -153,7 +155,7 @@ func TestDownloadClient_checksumNoDownload(t *testing.T) { Hash: HashForType("md5"), Checksum: checksum, CopyFile: true, - }, nil) + }, new(packer.NoopProgressBar)) path, err := client.Get() if err != nil { t.Fatalf("err: %s", err) @@ -183,7 +185,7 @@ func TestDownloadClient_notFound(t *testing.T) { client := NewDownloadClient(&DownloadConfig{ Url: ts.URL + "/not-found.txt", TargetPath: tf.Name(), - }, nil) + }, new(packer.NoopProgressBar)) if _, err := client.Get(); err == nil { t.Fatal("should error") @@ -211,7 +213,7 @@ func TestDownloadClient_resume(t *testing.T) { Url: ts.URL, TargetPath: tf.Name(), CopyFile: true, - }, nil) + }, new(packer.NoopProgressBar)) path, err := client.Get() if err != nil { @@ -273,7 +275,7 @@ func TestDownloadClient_usesDefaultUserAgent(t *testing.T) { CopyFile: true, } - client := NewDownloadClient(config, nil) + client := NewDownloadClient(config, new(packer.NoopProgressBar)) _, err = client.Get() if err != nil { t.Fatal(err) @@ -306,7 +308,7 @@ func TestDownloadClient_setsUserAgent(t *testing.T) { CopyFile: true, } - client := NewDownloadClient(config, nil) + client := NewDownloadClient(config, new(packer.NoopProgressBar)) _, err = client.Get() if err != nil { t.Fatal(err) @@ -405,7 +407,7 @@ func TestDownloadFileUrl(t *testing.T) { CopyFile: false, } - client := NewDownloadClient(config, nil) + client := NewDownloadClient(config, new(packer.NoopProgressBar)) // Verify that we fail to match the checksum _, err = client.Get() @@ -436,7 +438,7 @@ func SimulateFileUriDownload(t *testing.T, uri string) (string, error) { } // go go go - client := NewDownloadClient(config, nil) + client := NewDownloadClient(config, new(packer.NoopProgressBar)) path, err := client.Get() // ignore any non-important checksum errors if it's not a unc path diff --git a/common/progress.go b/common/progress.go index 83b6d0e2b..6cb91b08b 100644 --- a/common/progress.go +++ b/common/progress.go @@ -2,13 +2,14 @@ package common import ( "fmt" + "log" + "reflect" + "time" + "github.com/cheggaaa/pb" "github.com/hashicorp/packer/helper/multistep" "github.com/hashicorp/packer/packer" "github.com/hashicorp/packer/packer/rpc" - "log" - "reflect" - "time" ) // This is the arrow from packer/ui.go -> TargetedUI.prefixLines diff --git a/common/step_download.go b/common/step_download.go index c55e12632..26f6d6b65 100644 --- a/common/step_download.go +++ b/common/step_download.go @@ -64,7 +64,7 @@ func (s *StepDownload) Run(_ context.Context, state multistep.StateBag) multiste ui.Say(fmt.Sprintf("Retrieving %s", s.Description)) // Get a progress bar from the ui so we can hand it off to the download client - bar := GetProgressBar(ui, GetPackerConfigFromStateBag(state)) + bar := ui.ProgressBar() // First try to use any already downloaded file // If it fails, proceed to regular download logic @@ -144,7 +144,8 @@ func (s *StepDownload) download(config *DownloadConfig, state multistep.StateBag ui := state.Get("ui").(packer.Ui) // Get a progress bar and hand it off to the download client - bar := GetProgressBar(ui, GetPackerConfigFromStateBag(state)) + bar := ui.ProgressBar() + log.Printf("new progress bar: %#v, %t", bar, bar == nil) // Create download client with config and progress bar download := NewDownloadClient(config, bar) diff --git a/packer/progressbar.go b/packer/progressbar.go new file mode 100644 index 000000000..1f3784fc1 --- /dev/null +++ b/packer/progressbar.go @@ -0,0 +1,39 @@ +package packer + +import ( + "github.com/cheggaaa/pb" +) + +// ProgressBar allows to graphically display +// a self refreshing progress bar. +// No-op When in machine readable mode. +type ProgressBar interface { + Start(total uint64) + Set(current uint64) + Finish() +} + +type BasicProgressBar struct { + *pb.ProgressBar +} + +func (bpb *BasicProgressBar) Start(total uint64) { + bpb.SetTotal64(int64(total)) + bpb.ProgressBar.Start() +} + +func (bpb *BasicProgressBar) Set(current uint64) { + bpb.ProgressBar.Set64(int64(current)) +} + +var _ ProgressBar = new(BasicProgressBar) + +// NoopProgressBar is a silent progress bar +type NoopProgressBar struct { +} + +func (bpb *NoopProgressBar) Start(_ uint64) {} +func (bpb *NoopProgressBar) Set(_ uint64) {} +func (bpb *NoopProgressBar) Finish() {} + +var _ ProgressBar = new(NoopProgressBar) diff --git a/packer/rpc/server.go b/packer/rpc/server.go index 8982ec0e7..dff8854e2 100644 --- a/packer/rpc/server.go +++ b/packer/rpc/server.go @@ -114,7 +114,8 @@ func (s *Server) RegisterProvisioner(p packer.Provisioner) { func (s *Server) RegisterUi(ui packer.Ui) { s.server.RegisterName(DefaultUiEndpoint, &UiServer{ - ui: ui, + ui: ui, + register: s.server.RegisterName, }) } diff --git a/packer/rpc/ui.go b/packer/rpc/ui.go index 1c6356a65..50dce61fa 100644 --- a/packer/rpc/ui.go +++ b/packer/rpc/ui.go @@ -2,6 +2,7 @@ package rpc import ( "log" + "math/rand" "net/rpc" "github.com/hashicorp/packer/packer" @@ -14,10 +15,13 @@ type Ui struct { endpoint string } +var _ packer.Ui = new(Ui) + // UiServer wraps a packer.Ui implementation and makes it exportable // as part of a Golang RPC server. type UiServer struct { - ui packer.Ui + ui packer.Ui + register func(name string, rcvr interface{}) error } // The arguments sent to Ui.Machine @@ -60,6 +64,38 @@ func (u *Ui) Say(message string) { } } +func (u *Ui) ProgressBar() packer.ProgressBar { + var callMeMaybe string + if err := u.client.Call("Ui.ProgressBar", nil, &callMeMaybe); err != nil { + log.Printf("Error in Ui RPC call: %s", err) + return new(packer.NoopProgressBar) + } + + return &RemoteProgressBarClient{ + id: callMeMaybe, + client: u.client, + } +} + +type RemoteProgressBarClient struct { + id string + client *rpc.Client +} + +var _ packer.ProgressBar = new(RemoteProgressBarClient) + +func (pb *RemoteProgressBarClient) Start(total uint64) { + pb.client.Call(pb.id+".Start", total, new(interface{})) +} + +func (pb *RemoteProgressBarClient) Set(current uint64) { + pb.client.Call(pb.id+".Set", current, new(interface{})) +} + +func (pb *RemoteProgressBarClient) Finish() { + pb.client.Call(pb.id+".Finish", nil, new(interface{})) +} + func (u *UiServer) Ask(query string, reply *string) (err error) { *reply, err = u.ui.Ask(query) return @@ -91,3 +127,47 @@ func (u *UiServer) Say(message *string, reply *interface{}) error { *reply = nil return nil } + +func RandStringBytes(n int) string { + const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + + b := make([]byte, n) + for i := range b { + b[i] = letterBytes[rand.Intn(len(letterBytes))] + } + return string(b) +} + +func (u *UiServer) ProgressBar(_ *string, reply *interface{}) error { + bar := u.ui.ProgressBar() + + callbackName := RandStringBytes(6) + + log.Printf("registering progressbar %s", callbackName) + err := u.register(callbackName, &RemoteProgressBarServer{bar}) + if err != nil { + log.Printf("failed to register a new progress bar rpc server, %s", err) + return err + } + *reply = callbackName + return nil +} + +type RemoteProgressBarServer struct { + pb packer.ProgressBar +} + +func (pb *RemoteProgressBarServer) Finish(_ string, _ *interface{}) error { + pb.pb.Finish() + return nil +} + +func (pb *RemoteProgressBarServer) Start(total uint64, _ *interface{}) error { + pb.pb.Start(total) + return nil +} + +func (pb *RemoteProgressBarServer) Set(current uint64, _ *interface{}) error { + pb.pb.Set(current) + return nil +} diff --git a/packer/ui.go b/packer/ui.go index c16a2eae0..e6ceba0e3 100644 --- a/packer/ui.go +++ b/packer/ui.go @@ -15,6 +15,8 @@ import ( "syscall" "time" "unicode" + + "github.com/cheggaaa/pb" ) type UiColor uint @@ -37,6 +39,7 @@ type Ui interface { Message(string) Error(string) Machine(string, ...string) + ProgressBar() ProgressBar } // ColoredUi is a UI that is colored using terminal colors. @@ -46,6 +49,8 @@ type ColoredUi struct { Ui Ui } +var _ Ui = new(ColoredUi) + // TargetedUI is a UI that wraps another UI implementation and modifies // the output to indicate a specific target. Specifically, all Say output // is prefixed with the target name. Message output is not prefixed but @@ -56,6 +61,8 @@ type TargetedUI struct { Ui Ui } +var _ Ui = new(TargetedUI) + // The BasicUI is a UI that reads and writes from a standard Go reader // and writer. It is safe to be called from multiple goroutines. Machine // readable output is simply logged for this UI. @@ -68,12 +75,21 @@ type BasicUi struct { scanner *bufio.Scanner } +var _ Ui = new(BasicUi) + +func (bu *BasicUi) ProgressBar() ProgressBar { + log.Printf("hehey !") + return &BasicProgressBar{ProgressBar: pb.New(0)} +} + // MachineReadableUi is a UI that only outputs machine-readable output // to the given Writer. type MachineReadableUi struct { Writer io.Writer } +var _ Ui = new(MachineReadableUi) + func (u *ColoredUi) Ask(query string) (string, error) { return u.Ui.Ask(u.colorize(query, u.Color, true)) } @@ -100,6 +116,10 @@ func (u *ColoredUi) Machine(t string, args ...string) { u.Ui.Machine(t, args...) } +func (u *ColoredUi) ProgressBar() ProgressBar { + return u.Ui.ProgressBar() //TODO(adrien): color me +} + func (u *ColoredUi) colorize(message string, color UiColor, bold bool) string { if !u.supportsColors() { return message @@ -153,6 +173,10 @@ func (u *TargetedUI) Machine(t string, args ...string) { u.Ui.Machine(fmt.Sprintf("%s,%s", u.Target, t), args...) } +func (u *TargetedUI) ProgressBar() ProgressBar { + return u.Ui.ProgressBar() +} + func (u *TargetedUI) prefixLines(arrow bool, message string) string { arrowText := "==>" if !arrow { @@ -305,3 +329,8 @@ func (u *MachineReadableUi) Machine(category string, args ...string) { } } } + +func (u *MachineReadableUi) ProgressBar() ProgressBar { + panic("MachineReadableUi") + return nil // no-op +} diff --git a/provisioner/ansible/provisioner.go b/provisioner/ansible/provisioner.go index 574ec2b39..d8d82289f 100644 --- a/provisioner/ansible/provisioner.go +++ b/provisioner/ansible/provisioner.go @@ -612,3 +612,8 @@ func (ui *Ui) Machine(t string, args ...string) { ui.ui.Machine(t, args...) <-ui.sem } + +func (ui *Ui) ProgressBar() packer.ProgressBar { + panic("to implement") + return nil // TODO +} diff --git a/provisioner/file/provisioner.go b/provisioner/file/provisioner.go index 13e2aa8e6..a05c2ed9c 100644 --- a/provisioner/file/provisioner.go +++ b/provisioner/file/provisioner.go @@ -127,12 +127,12 @@ func (p *Provisioner) ProvisionDownload(ui packer.Ui, comm packer.Communicator) defer f.Close() // Get a default progress bar - pb := common.GetProgressBar(ui, &p.config.PackerConfig) - bar := pb.Start() - defer bar.Finish() + pb := packer.NoopProgressBar{} + pb.Start(0) + defer pb.Finish() // Create MultiWriter for the current progress - pf := io.MultiWriter(f, bar) + pf := io.MultiWriter(f) // Download the file if err = comm.Download(src, pf); err != nil { @@ -176,8 +176,8 @@ func (p *Provisioner) ProvisionUpload(ui packer.Ui, comm packer.Communicator) er } // Get a default progress bar - pb := common.GetProgressBar(ui, &p.config.PackerConfig) - bar := pb.Start() + bar := ui.ProgressBar() + bar.Start(0) defer bar.Finish() // Create ProxyReader for the current progress