From ee7b8565cc45b5daeced15ee6e93cb58ea7424eb Mon Sep 17 00:00:00 2001 From: Justin Hawkins Date: Mon, 18 Apr 2022 19:53:07 +0930 Subject: [PATCH] Refactor to prevent races on access of downloads --- download/download.go | 125 ++++++++++++++++++++++++++----------------- main.go | 82 +++++++++++++--------------- 2 files changed, 114 insertions(+), 93 deletions(-) diff --git a/download/download.go b/download/download.go index 619896c..f71907b 100644 --- a/download/download.go +++ b/download/download.go @@ -34,78 +34,92 @@ type Download struct { Percent float32 `json:"percent"` Log []string `json:"log"` Config *config.Config - mutex sync.Mutex } -type Downloads []*Download +type Manager struct { + Downloads []*Download + MaxPerDomain int + Lock sync.Mutex +} var CanStopDownload = false var downloadId int32 = 0 -// StartQueued starts any downloads that have been queued, we would not exceed +func (m *Manager) ManageQueue() { + for { + + m.Lock.Lock() + + m.startQueued(m.MaxPerDomain) + m.cleanup() + m.Lock.Unlock() + + time.Sleep(time.Second) + } +} + +// startQueued starts any downloads that have been queued, we would not exceed // maxRunning. If maxRunning is 0, there is no limit. -func (dls Downloads) StartQueued(maxRunning int) { +func (m *Manager) startQueued(maxRunning int) { active := make(map[string]int) - for _, dl := range dls { - - dl.mutex.Lock() + for _, dl := range m.Downloads { if dl.State == "downloading" { active[dl.domain()]++ } - dl.mutex.Unlock() } - for _, dl := range dls { - - dl.mutex.Lock() + for _, dl := range m.Downloads { if dl.State == "queued" && (maxRunning == 0 || active[dl.domain()] < maxRunning) { dl.State = "downloading" active[dl.domain()]++ log.Printf("Starting download for id:%d (%s)", dl.Id, dl.Url) - dl.mutex.Unlock() - go func() { dl.Begin() }() - } else { - dl.mutex.Unlock() + go func() { + m.Begin(dl.Id) + }() } } } -// Cleanup removes old downloads from the list. Hardcoded to remove them one hour +// cleanup removes old downloads from the list. Hardcoded to remove them one hour // completion. -func (dls Downloads) Cleanup() Downloads { - newDLs := Downloads{} - for _, dl := range dls { - - dl.mutex.Lock() +func (m *Manager) cleanup() { + newDLs := []*Download{} + for _, dl := range m.Downloads { if dl.Finished && time.Since(dl.FinishedTS) > time.Duration(time.Hour) { // do nothing } else { newDLs = append(newDLs, dl) } - dl.mutex.Unlock() } - return newDLs + m.Downloads = newDLs +} + +func (m *Manager) DlById(id int) *Download { + for _, dl := range m.Downloads { + if dl.Id == id { + return dl + } + } + return nil } // Queue queues a download -func (dl *Download) Queue() { - - dl.mutex.Lock() - defer dl.mutex.Unlock() +func (m *Manager) Queue(id int) { + dl := m.DlById(id) dl.State = "queued" } -func NewDownload(conf *config.Config, url string) *Download { +func (m *Manager) NewDownload(conf *config.Config, url string) int { atomic.AddInt32(&downloadId, 1) dl := Download{ Config: conf, @@ -119,24 +133,30 @@ func NewDownload(conf *config.Config, url string) *Download { Percent: 0.0, Log: make([]string, 0, 1000), } - return &dl + m.Downloads = append(m.Downloads, &dl) + return int(downloadId) } -func (dl *Download) Stop() { +func (m *Manager) AppendLog(id int, text string) { + dl := m.DlById(id) + dl.Log = append(dl.Log, text) +} + +// Stop the download. +func (m *Manager) Stop(id int) { if !CanStopDownload { log.Print("attempted to stop download on a platform that it is not currently supported on - please report this as a bug") os.Exit(1) } + dl := m.DlById(id) + log.Printf("stopping the download") - dl.mutex.Lock() dl.Log = append(dl.Log, "aborted by user") - defer dl.mutex.Unlock() dl.Process.Kill() } func (dl *Download) domain() string { - // note that we expect to already have the mutex locked by the caller url, err := url.Parse(dl.Url) if err != nil { log.Printf("Unknown domain for url: %s", dl.Url) @@ -149,9 +169,10 @@ func (dl *Download) domain() string { // Begin starts a download, by starting the command specified in the DownloadProfile. // It blocks until the download is complete. -func (dl *Download) Begin() { +func (m *Manager) Begin(id int) { + m.Lock.Lock() - dl.mutex.Lock() + dl := m.DlById(id) dl.State = "downloading" cmdSlice := []string{} @@ -171,6 +192,8 @@ func (dl *Download) Begin() { dl.Finished = true dl.FinishedTS = time.Now() dl.Log = append(dl.Log, fmt.Sprintf("error setting up stdout pipe: %v", err)) + m.Lock.Unlock() + return } @@ -180,6 +203,8 @@ func (dl *Download) Begin() { dl.Finished = true dl.FinishedTS = time.Now() dl.Log = append(dl.Log, fmt.Sprintf("error setting up stderr pipe: %v", err)) + m.Lock.Unlock() + return } @@ -190,31 +215,35 @@ func (dl *Download) Begin() { dl.Finished = true dl.FinishedTS = time.Now() dl.Log = append(dl.Log, fmt.Sprintf("error starting command '%s': %v", dl.DownloadProfile.Command, err)) + m.Lock.Unlock() + return } dl.Process = cmd.Process var wg sync.WaitGroup - dl.mutex.Unlock() - wg.Add(2) + + m.Lock.Unlock() + go func() { defer wg.Done() - dl.updateDownload(stdout) + m.updateDownload(dl, stdout) }() go func() { defer wg.Done() - dl.updateDownload(stderr) + m.updateDownload(dl, stderr) }() wg.Wait() cmd.Wait() - dl.mutex.Lock() log.Printf("Process finished for id: %d (%v)", dl.Id, cmd) + m.Lock.Lock() + dl.State = "complete" dl.Finished = true dl.FinishedTS = time.Now() @@ -223,11 +252,12 @@ func (dl *Download) Begin() { if dl.ExitCode != 0 { dl.State = "failed" } - dl.mutex.Unlock() + + m.Lock.Unlock() } -func (dl *Download) updateDownload(r io.Reader) { +func (m *Manager) updateDownload(dl *Download, r io.Reader) { // XXX not sure if we might get a partial line? buf := make([]byte, 1024) for { @@ -242,15 +272,16 @@ func (dl *Download) updateDownload(r io.Reader) { continue } - dl.mutex.Lock() + m.Lock.Lock() // append the raw log dl.Log = append(dl.Log, l) - dl.mutex.Unlock() - // look for the percent and eta and other metadata dl.updateMetadata(l) + + m.Lock.Unlock() + } } if err != nil { @@ -261,10 +292,6 @@ func (dl *Download) updateDownload(r io.Reader) { func (dl *Download) updateMetadata(s string) { - dl.mutex.Lock() - - defer dl.mutex.Unlock() - // [download] 49.7% of ~15.72MiB at 5.83MiB/s ETA 00:07 // [download] 99.3% of ~1.42GiB at 320.87KiB/s ETA 00:07 (frag 212/214) etaRE := regexp.MustCompile(`download.+ETA +(\d\d:\d\d(?::\d\d)?)`) diff --git a/main.go b/main.go index 2eea312..2023764 100644 --- a/main.go +++ b/main.go @@ -19,7 +19,7 @@ import ( "github.com/tardisx/gropple/version" ) -var downloads download.Downloads +var dm *download.Manager var downloadId = 0 var configService *config.ConfigService @@ -60,9 +60,11 @@ func main() { log.Fatal(err) } log.Printf("Configuration loaded from %s", configService.ConfigPath) - } + // create the download manager + dm = &download.Manager{MaxPerDomain: configService.Config.Server.MaximumActiveDownloads} + r := mux.NewRouter() r.HandleFunc("/", homeHandler) r.HandleFunc("/static/{filename}", staticHandler) @@ -97,13 +99,7 @@ func main() { // start downloading queued downloads when slots available, and clean up // old entries - go func() { - for { - downloads.StartQueued(configService.Config.Server.MaximumActiveDownloads) - downloads = downloads.Cleanup() - time.Sleep(time.Second) - } - }() + go dm.ManageQueue() log.Printf("Visit %s for details on installing the bookmarklet and to check status", configService.Config.Server.Address) log.Fatal(srv.ListenAndServe()) @@ -138,8 +134,11 @@ func homeHandler(w http.ResponseWriter, r *http.Request) { Version version.Info } + dm.Lock.Lock() + defer dm.Lock.Unlock() + info := Info{ - Downloads: downloads, + Downloads: dm.Downloads, BookmarkletURL: template.URL(bookmarkletURL), Config: configService.Config, Version: versionInfo.GetInfo(), @@ -220,13 +219,10 @@ func fetchInfoOneRESTHandler(w http.ResponseWriter, r *http.Request) { return } - // find the download - var thisDownload *download.Download - for _, dl := range downloads { - if dl.Id == id { - thisDownload = dl - } - } + dm.Lock.Lock() + defer dm.Lock.Unlock() + + thisDownload := dm.DlById(id) if thisDownload == nil { http.NotFound(w, r) return @@ -263,8 +259,8 @@ func fetchInfoOneRESTHandler(w http.ResponseWriter, r *http.Request) { } // set the profile thisDownload.DownloadProfile = *profile + dm.Queue(thisDownload.Id) - thisDownload.Queue() succRes := successResponse{Success: true, Message: "download started"} succResB, _ := json.Marshal(succRes) w.Write(succResB) @@ -272,7 +268,7 @@ func fetchInfoOneRESTHandler(w http.ResponseWriter, r *http.Request) { } if thisReq.Action == "stop" { - thisDownload.Stop() + dm.Stop(thisDownload.Id) succRes := successResponse{Success: true, Message: "download stopped"} succResB, _ := json.Marshal(succRes) w.Write(succResB) @@ -290,7 +286,10 @@ func fetchInfoOneRESTHandler(w http.ResponseWriter, r *http.Request) { } func fetchInfoRESTHandler(w http.ResponseWriter, r *http.Request) { - b, _ := json.Marshal(downloads) + + dm.Lock.Lock() + defer dm.Lock.Unlock() + b, _ := json.Marshal(dm.Downloads) w.Write(b) } @@ -301,25 +300,26 @@ func fetchHandler(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) idString := vars["id"] + dm.Lock.Lock() + defer dm.Lock.Unlock() + idInt, err := strconv.ParseInt(idString, 10, 32) + + // existing, load it up if err == nil && idInt > 0 { - for _, dl := range downloads { - if dl.Id == int(idInt) { - t, err := template.ParseFS(webFS, "web/layout.tmpl", "web/popup.html") - if err != nil { - panic(err) - } - - templateData := map[string]interface{}{"dl": dl, "config": configService.Config, "canStop": download.CanStopDownload} - - err = t.ExecuteTemplate(w, "layout", templateData) - if err != nil { - panic(err) - } - return - } + dl := dm.DlById(int(idInt)) + t, err := template.ParseFS(webFS, "web/layout.tmpl", "web/popup.html") + if err != nil { + panic(err) } + templateData := map[string]interface{}{"dl": dl, "config": configService.Config, "canStop": download.CanStopDownload} + + err = t.ExecuteTemplate(w, "layout", templateData) + if err != nil { + panic(err) + } + return } query := r.URL.Query() @@ -339,22 +339,16 @@ func fetchHandler(w http.ResponseWriter, r *http.Request) { } // create the record - newDownload := download.NewDownload(configService.Config, url[0]) - downloads = append(downloads, newDownload) - // XXX atomic ^^ - newDownload.Log = append(newDownload.Log, "start of log...") - - // go func() { - // newDownload.Begin() - // }() + newDownloadId := dm.NewDownload(configService.Config, url[0]) + dm.AppendLog(newDownloadId, "start of log...") t, err := template.ParseFS(webFS, "web/layout.tmpl", "web/popup.html") if err != nil { panic(err) } - templateData := map[string]interface{}{"Version": versionInfo, "dl": newDownload, "config": configService.Config, "canStop": download.CanStopDownload} + templateData := map[string]interface{}{"Version": versionInfo.GetInfo(), "dl": dm.DlById(newDownloadId), "config": configService.Config, "canStop": download.CanStopDownload} err = t.ExecuteTemplate(w, "layout", templateData) if err != nil {