Create a ConfigService struct to handle managing our config.

This commit is contained in:
Justin Hawkins 2022-04-07 20:39:14 +09:30
parent 4bd38a8635
commit 2e94eb6a87
3 changed files with 66 additions and 46 deletions

View File

@ -37,13 +37,18 @@ type Config struct {
DownloadProfiles []DownloadProfile `yaml:"profiles" json:"profiles"`
}
func TestConfig() *Config {
config := DefaultConfig()
config.DownloadProfiles = []DownloadProfile{{Name: "test profile", Command: "sleep", Args: []string{"5"}}}
return config
// ConfigService is a struct to handle configuration requests, allowing for the
// location that config files are loaded to be customised.
type ConfigService struct {
Config *Config
}
func DefaultConfig() *Config {
func (cs *ConfigService) LoadTestConfig() {
cs.LoadDefaultConfig()
cs.Config.DownloadProfiles = []DownloadProfile{{Name: "test profile", Command: "sleep", Args: []string{"5"}}}
}
func (cs *ConfigService) LoadDefaultConfig() {
defaultConfig := Config{}
stdProfile := DownloadProfile{Name: "standard video", Command: "youtube-dl", Args: []string{
"--newline",
@ -72,7 +77,9 @@ func DefaultConfig() *Config {
defaultConfig.ConfigVersion = 2
return &defaultConfig
cs.Config = &defaultConfig
return
}
func (c *Config) ProfileCalled(name string) *DownloadProfile {
@ -153,14 +160,15 @@ func (c *Config) UpdateFromJSON(j []byte) error {
if err != nil {
return fmt.Errorf("Could not find %s on the path", newConfig.DownloadProfiles[i].Command)
}
}
*c = newConfig
return nil
}
func configPath() string {
// configPath returns the full path to the config file (which may or may
// not yet exist) and also creates the subdir if needed (one level)
func (cs *ConfigService) configPath() string {
dir, err := os.UserConfigDir()
if err != nil {
log.Fatalf("cannot find a directory to store config: %v", err)
@ -181,33 +189,35 @@ func configPath() string {
return fullFilename
}
func ConfigFileExists() bool {
info, err := os.Stat(configPath())
// ConfigFileExists checks if the config file already exists, and also checks
// if there is an error accessing it
func (cs *ConfigService) ConfigFileExists() (bool, error) {
path := cs.configPath()
info, err := os.Stat(path)
if os.IsNotExist(err) {
return false
return false, nil
}
if err != nil {
log.Fatal(err)
return false, fmt.Errorf("could not check if '%s' exists: %s", path, err)
}
if info.Size() == 0 {
log.Print("config file is 0 bytes?")
return false
return false, errors.New("config file is 0 bytes")
}
return true
return true, nil
}
func LoadConfig() (*Config, error) {
path := configPath()
// LoadConfig loads the configuration from disk, migrating and updating it to the
// latest version if needed.
func (cs *ConfigService) LoadConfig() error {
path := cs.configPath()
b, err := os.ReadFile(path)
if err != nil {
log.Printf("Could not read config '%s': %v", path, err)
return nil, err
return fmt.Errorf("Could not read config '%s': %v", path, err)
}
c := Config{}
err = yaml.Unmarshal(b, &c)
if err != nil {
log.Printf("Could not parse YAML config '%s': %v", path, err)
return nil, err
return fmt.Errorf("Could not parse YAML config '%s': %v", path, err)
}
// do migrations
@ -221,19 +231,22 @@ func LoadConfig() (*Config, error) {
if configMigrated {
log.Print("Writing new config after version migration")
c.WriteConfig()
cs.WriteConfig()
}
return &c, nil
cs.Config = &c
return nil
}
func (c *Config) WriteConfig() {
s, err := yaml.Marshal(c)
// WriteConfig writes the in-memory config to disk.
func (cs *ConfigService) WriteConfig() {
s, err := yaml.Marshal(cs.Config)
if err != nil {
panic(err)
}
path := configPath()
path := cs.configPath()
file, err := os.Create(
path,
)

View File

@ -67,7 +67,9 @@ func TestUpdateMetadata(t *testing.T) {
// [ffmpeg] Merging formats into "Halo Infinite Flight 4K Gameplay-wi7Agv1M6PY.mp4"
func TestQueue(t *testing.T) {
conf := config.TestConfig()
cs := config.ConfigService{}
cs.LoadTestConfig()
conf := cs.Config
new1 := Download{Id: 1, Url: "http://sub.example.org/foo1", State: "queued", DownloadProfile: conf.DownloadProfiles[0], Config: conf}
new2 := Download{Id: 2, Url: "http://sub.example.org/foo2", State: "queued", DownloadProfile: conf.DownloadProfiles[0], Config: conf}

47
main.go
View File

@ -21,7 +21,7 @@ import (
var downloads download.Downloads
var downloadId = 0
var conf *config.Config
var configService *config.ConfigService
var versionInfo = version.Info{CurrentVersion: "v0.5.4"}
@ -39,16 +39,20 @@ type errorResponse struct {
}
func main() {
if !config.ConfigFileExists() {
log.Print("No config file - creating default config")
conf = config.DefaultConfig()
conf.WriteConfig()
} else {
loadedConfig, err := config.LoadConfig()
cs := config.ConfigService{}
exists, err := cs.ConfigFileExists()
if err != nil {
log.Fatal(err)
}
if !exists {
log.Print("No config file - creating default config")
cs.LoadDefaultConfig()
cs.WriteConfig()
} else {
err := cs.LoadConfig()
if err != nil {
log.Fatal(err)
}
conf = loadedConfig
}
r := mux.NewRouter()
@ -69,7 +73,7 @@ func main() {
srv := &http.Server{
Handler: r,
Addr: fmt.Sprintf(":%d", conf.Server.Port),
Addr: fmt.Sprintf(":%d", configService.Config.Server.Port),
// Good practice: enforce timeouts for servers you create!
WriteTimeout: 5 * time.Second,
ReadTimeout: 5 * time.Second,
@ -87,15 +91,16 @@ func main() {
// old entries
go func() {
for {
downloads.StartQueued(conf.Server.MaximumActiveDownloads)
downloads.StartQueued(configService.Config.Server.MaximumActiveDownloads)
downloads = downloads.Cleanup()
time.Sleep(time.Second)
}
}()
log.Printf("starting gropple %s - https://github.com/tardisx/gropple", versionInfo.CurrentVersion)
log.Printf("go to %s for details on installing the bookmarklet and to check status", conf.Server.Address)
log.Printf("go to %s for details on installing the bookmarklet and to check status", configService.Config.Server.Address)
log.Fatal(srv.ListenAndServe())
}
// versionRESTHandler returns the version information, if we have up-to-date info from github
@ -112,7 +117,7 @@ func versionRESTHandler(w http.ResponseWriter, r *http.Request) {
func homeHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
bookmarkletURL := fmt.Sprintf("javascript:(function(f,s,n,o){window.open(f+encodeURIComponent(s),n,o)}('%s/fetch?url=',window.location,'yourform','width=%d,height=%d'));", conf.Server.Address, conf.UI.PopupWidth, conf.UI.PopupHeight)
bookmarkletURL := fmt.Sprintf("javascript:(function(f,s,n,o){window.open(f+encodeURIComponent(s),n,o)}('%s/fetch?url=',window.location,'yourform','width=%d,height=%d'));", configService.Config.Server.Address, configService.Config.UI.PopupWidth, configService.Config.UI.PopupHeight)
t, err := template.ParseFS(webFS, "web/layout.tmpl", "web/menu.tmpl", "web/index.html")
if err != nil {
@ -128,7 +133,7 @@ func homeHandler(w http.ResponseWriter, r *http.Request) {
info := Info{
Downloads: downloads,
BookmarkletURL: template.URL(bookmarkletURL),
Config: conf,
Config: configService.Config,
}
err = t.ExecuteTemplate(w, "layout", info)
@ -180,7 +185,7 @@ func configRESTHandler(w http.ResponseWriter, r *http.Request) {
if err != nil {
panic(err)
}
err = conf.UpdateFromJSON(b)
err = configService.Config.UpdateFromJSON(b)
if err != nil {
errorRes := errorResponse{Success: false, Error: err.Error()}
@ -189,9 +194,9 @@ func configRESTHandler(w http.ResponseWriter, r *http.Request) {
w.Write(errorResB)
return
}
conf.WriteConfig()
configService.WriteConfig()
}
b, _ := json.Marshal(conf)
b, _ := json.Marshal(configService.Config)
w.Write(b)
}
@ -243,7 +248,7 @@ func fetchInfoOneRESTHandler(w http.ResponseWriter, r *http.Request) {
if thisReq.Action == "start" {
// find the profile they asked for
profile := conf.ProfileCalled(thisReq.Profile)
profile := configService.Config.ProfileCalled(thisReq.Profile)
if profile == nil {
panic("bad profile name?")
}
@ -296,7 +301,7 @@ func fetchHandler(w http.ResponseWriter, r *http.Request) {
panic(err)
}
templateData := map[string]interface{}{"dl": dl, "config": conf, "canStop": download.CanStopDownload}
templateData := map[string]interface{}{"dl": dl, "config": configService.Config, "canStop": download.CanStopDownload}
err = t.ExecuteTemplate(w, "layout", templateData)
if err != nil {
@ -318,14 +323,14 @@ func fetchHandler(w http.ResponseWriter, r *http.Request) {
} else {
// check the URL for a sudden but inevitable betrayal
if strings.Contains(url[0], conf.Server.Address) {
if strings.Contains(url[0], configService.Config.Server.Address) {
w.WriteHeader(400)
fmt.Fprint(w, "you mustn't gropple your gropple :-)")
return
}
// create the record
newDownload := download.NewDownload(conf, url[0])
newDownload := download.NewDownload(configService.Config, url[0])
downloads = append(downloads, newDownload)
// XXX atomic ^^
@ -340,7 +345,7 @@ func fetchHandler(w http.ResponseWriter, r *http.Request) {
panic(err)
}
templateData := map[string]interface{}{"dl": newDownload, "config": conf, "canStop": download.CanStopDownload}
templateData := map[string]interface{}{"dl": newDownload, "config": configService.Config, "canStop": download.CanStopDownload}
err = t.ExecuteTemplate(w, "layout", templateData)
if err != nil {