From 2e94eb6a8744b159ebc371091479510c2a7056b1 Mon Sep 17 00:00:00 2001 From: Justin Hawkins Date: Thu, 7 Apr 2022 20:39:14 +0930 Subject: [PATCH] Create a ConfigService struct to handle managing our config. --- config/config.go | 65 +++++++++++++++++++++++---------------- download/download_test.go | 4 ++- main.go | 43 ++++++++++++++------------ 3 files changed, 66 insertions(+), 46 deletions(-) diff --git a/config/config.go b/config/config.go index 8d3e1d0..9ca45f6 100644 --- a/config/config.go +++ b/config/config.go @@ -37,13 +37,18 @@ type Config struct { DownloadProfiles []DownloadProfile `yaml:"profiles" json:"profiles"` } -func TestConfig() *Config { - config := DefaultConfig() - config.DownloadProfiles = []DownloadProfile{{Name: "test profile", Command: "sleep", Args: []string{"5"}}} - return config +// ConfigService is a struct to handle configuration requests, allowing for the +// location that config files are loaded to be customised. +type ConfigService struct { + Config *Config } -func DefaultConfig() *Config { +func (cs *ConfigService) LoadTestConfig() { + cs.LoadDefaultConfig() + cs.Config.DownloadProfiles = []DownloadProfile{{Name: "test profile", Command: "sleep", Args: []string{"5"}}} +} + +func (cs *ConfigService) LoadDefaultConfig() { defaultConfig := Config{} stdProfile := DownloadProfile{Name: "standard video", Command: "youtube-dl", Args: []string{ "--newline", @@ -72,7 +77,9 @@ func DefaultConfig() *Config { defaultConfig.ConfigVersion = 2 - return &defaultConfig + cs.Config = &defaultConfig + + return } func (c *Config) ProfileCalled(name string) *DownloadProfile { @@ -153,14 +160,15 @@ func (c *Config) UpdateFromJSON(j []byte) error { if err != nil { return fmt.Errorf("Could not find %s on the path", newConfig.DownloadProfiles[i].Command) } - } *c = newConfig return nil } -func configPath() string { +// configPath returns the full path to the config file (which may or may +// not yet exist) and also creates the subdir if needed (one level) +func (cs *ConfigService) configPath() string { dir, err := os.UserConfigDir() if err != nil { log.Fatalf("cannot find a directory to store config: %v", err) @@ -181,33 +189,35 @@ func configPath() string { return fullFilename } -func ConfigFileExists() bool { - info, err := os.Stat(configPath()) +// ConfigFileExists checks if the config file already exists, and also checks +// if there is an error accessing it +func (cs *ConfigService) ConfigFileExists() (bool, error) { + path := cs.configPath() + info, err := os.Stat(path) if os.IsNotExist(err) { - return false + return false, nil } if err != nil { - log.Fatal(err) + return false, fmt.Errorf("could not check if '%s' exists: %s", path, err) } if info.Size() == 0 { - log.Print("config file is 0 bytes?") - return false + return false, errors.New("config file is 0 bytes") } - return true + return true, nil } -func LoadConfig() (*Config, error) { - path := configPath() +// LoadConfig loads the configuration from disk, migrating and updating it to the +// latest version if needed. +func (cs *ConfigService) LoadConfig() error { + path := cs.configPath() b, err := os.ReadFile(path) if err != nil { - log.Printf("Could not read config '%s': %v", path, err) - return nil, err + return fmt.Errorf("Could not read config '%s': %v", path, err) } c := Config{} err = yaml.Unmarshal(b, &c) if err != nil { - log.Printf("Could not parse YAML config '%s': %v", path, err) - return nil, err + return fmt.Errorf("Could not parse YAML config '%s': %v", path, err) } // do migrations @@ -221,19 +231,22 @@ func LoadConfig() (*Config, error) { if configMigrated { log.Print("Writing new config after version migration") - c.WriteConfig() + cs.WriteConfig() } - return &c, nil + cs.Config = &c + + return nil } -func (c *Config) WriteConfig() { - s, err := yaml.Marshal(c) +// WriteConfig writes the in-memory config to disk. +func (cs *ConfigService) WriteConfig() { + s, err := yaml.Marshal(cs.Config) if err != nil { panic(err) } - path := configPath() + path := cs.configPath() file, err := os.Create( path, ) diff --git a/download/download_test.go b/download/download_test.go index a949fab..d12abe5 100644 --- a/download/download_test.go +++ b/download/download_test.go @@ -67,7 +67,9 @@ func TestUpdateMetadata(t *testing.T) { // [ffmpeg] Merging formats into "Halo Infinite Flight 4K Gameplay-wi7Agv1M6PY.mp4" func TestQueue(t *testing.T) { - conf := config.TestConfig() + cs := config.ConfigService{} + cs.LoadTestConfig() + conf := cs.Config new1 := Download{Id: 1, Url: "http://sub.example.org/foo1", State: "queued", DownloadProfile: conf.DownloadProfiles[0], Config: conf} new2 := Download{Id: 2, Url: "http://sub.example.org/foo2", State: "queued", DownloadProfile: conf.DownloadProfiles[0], Config: conf} diff --git a/main.go b/main.go index 4784656..f9dd562 100644 --- a/main.go +++ b/main.go @@ -21,7 +21,7 @@ import ( var downloads download.Downloads var downloadId = 0 -var conf *config.Config +var configService *config.ConfigService var versionInfo = version.Info{CurrentVersion: "v0.5.4"} @@ -39,16 +39,20 @@ type errorResponse struct { } func main() { - if !config.ConfigFileExists() { + cs := config.ConfigService{} + exists, err := cs.ConfigFileExists() + if err != nil { + log.Fatal(err) + } + if !exists { log.Print("No config file - creating default config") - conf = config.DefaultConfig() - conf.WriteConfig() + cs.LoadDefaultConfig() + cs.WriteConfig() } else { - loadedConfig, err := config.LoadConfig() + err := cs.LoadConfig() if err != nil { log.Fatal(err) } - conf = loadedConfig } r := mux.NewRouter() @@ -69,7 +73,7 @@ func main() { srv := &http.Server{ Handler: r, - Addr: fmt.Sprintf(":%d", conf.Server.Port), + Addr: fmt.Sprintf(":%d", configService.Config.Server.Port), // Good practice: enforce timeouts for servers you create! WriteTimeout: 5 * time.Second, ReadTimeout: 5 * time.Second, @@ -87,15 +91,16 @@ func main() { // old entries go func() { for { - downloads.StartQueued(conf.Server.MaximumActiveDownloads) + downloads.StartQueued(configService.Config.Server.MaximumActiveDownloads) downloads = downloads.Cleanup() time.Sleep(time.Second) } }() log.Printf("starting gropple %s - https://github.com/tardisx/gropple", versionInfo.CurrentVersion) - log.Printf("go to %s for details on installing the bookmarklet and to check status", conf.Server.Address) + log.Printf("go to %s for details on installing the bookmarklet and to check status", configService.Config.Server.Address) log.Fatal(srv.ListenAndServe()) + } // versionRESTHandler returns the version information, if we have up-to-date info from github @@ -112,7 +117,7 @@ func versionRESTHandler(w http.ResponseWriter, r *http.Request) { func homeHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - bookmarkletURL := fmt.Sprintf("javascript:(function(f,s,n,o){window.open(f+encodeURIComponent(s),n,o)}('%s/fetch?url=',window.location,'yourform','width=%d,height=%d'));", conf.Server.Address, conf.UI.PopupWidth, conf.UI.PopupHeight) + bookmarkletURL := fmt.Sprintf("javascript:(function(f,s,n,o){window.open(f+encodeURIComponent(s),n,o)}('%s/fetch?url=',window.location,'yourform','width=%d,height=%d'));", configService.Config.Server.Address, configService.Config.UI.PopupWidth, configService.Config.UI.PopupHeight) t, err := template.ParseFS(webFS, "web/layout.tmpl", "web/menu.tmpl", "web/index.html") if err != nil { @@ -128,7 +133,7 @@ func homeHandler(w http.ResponseWriter, r *http.Request) { info := Info{ Downloads: downloads, BookmarkletURL: template.URL(bookmarkletURL), - Config: conf, + Config: configService.Config, } err = t.ExecuteTemplate(w, "layout", info) @@ -180,7 +185,7 @@ func configRESTHandler(w http.ResponseWriter, r *http.Request) { if err != nil { panic(err) } - err = conf.UpdateFromJSON(b) + err = configService.Config.UpdateFromJSON(b) if err != nil { errorRes := errorResponse{Success: false, Error: err.Error()} @@ -189,9 +194,9 @@ func configRESTHandler(w http.ResponseWriter, r *http.Request) { w.Write(errorResB) return } - conf.WriteConfig() + configService.WriteConfig() } - b, _ := json.Marshal(conf) + b, _ := json.Marshal(configService.Config) w.Write(b) } @@ -243,7 +248,7 @@ func fetchInfoOneRESTHandler(w http.ResponseWriter, r *http.Request) { if thisReq.Action == "start" { // find the profile they asked for - profile := conf.ProfileCalled(thisReq.Profile) + profile := configService.Config.ProfileCalled(thisReq.Profile) if profile == nil { panic("bad profile name?") } @@ -296,7 +301,7 @@ func fetchHandler(w http.ResponseWriter, r *http.Request) { panic(err) } - templateData := map[string]interface{}{"dl": dl, "config": conf, "canStop": download.CanStopDownload} + templateData := map[string]interface{}{"dl": dl, "config": configService.Config, "canStop": download.CanStopDownload} err = t.ExecuteTemplate(w, "layout", templateData) if err != nil { @@ -318,14 +323,14 @@ func fetchHandler(w http.ResponseWriter, r *http.Request) { } else { // check the URL for a sudden but inevitable betrayal - if strings.Contains(url[0], conf.Server.Address) { + if strings.Contains(url[0], configService.Config.Server.Address) { w.WriteHeader(400) fmt.Fprint(w, "you mustn't gropple your gropple :-)") return } // create the record - newDownload := download.NewDownload(conf, url[0]) + newDownload := download.NewDownload(configService.Config, url[0]) downloads = append(downloads, newDownload) // XXX atomic ^^ @@ -340,7 +345,7 @@ func fetchHandler(w http.ResponseWriter, r *http.Request) { panic(err) } - templateData := map[string]interface{}{"dl": newDownload, "config": conf, "canStop": download.CanStopDownload} + templateData := map[string]interface{}{"dl": newDownload, "config": configService.Config, "canStop": download.CanStopDownload} err = t.ExecuteTemplate(w, "layout", templateData) if err != nil {