diff --git a/cmd/um/main.go b/cmd/um/main.go index 1730f38..aa2cccd 100644 --- a/cmd/um/main.go +++ b/cmd/um/main.go @@ -8,6 +8,7 @@ import ( "io" "os" "os/signal" + "path" "path/filepath" "runtime" "runtime/debug" @@ -84,6 +85,11 @@ func printSupportedExtensions() { } func appMain(c *cli.Context) (err error) { + cwd, err := os.Getwd() + if err != nil { + return err + } + if c.Bool("supported-ext") { printSupportedExtensions() return nil @@ -92,10 +98,7 @@ func appMain(c *cli.Context) (err error) { if input == "" { switch c.Args().Len() { case 0: - input, err = os.Getwd() - if err != nil { - return err - } + input = cwd case 1: input = c.Args().Get(0) default: @@ -104,22 +107,20 @@ func appMain(c *cli.Context) (err error) { } output := c.String("output") - if output == "" { - var err error - output, err = os.Getwd() - if err != nil { - return err - } - if input == output { - return errors.New("input and output path are same") - } - } - inputStat, err := os.Stat(input) if err != nil { return err } + if output == "" { + // Default to where the input is + if inputStat.IsDir() { + output = input + } else { + output = path.Dir(input) + } + } + outputStat, err := os.Stat(output) if err != nil { if errors.Is(err, os.ErrNotExist) {