diff --git a/algo/common/sniff.go b/algo/common/sniff.go new file mode 100644 index 0000000..650e3ae --- /dev/null +++ b/algo/common/sniff.go @@ -0,0 +1,44 @@ +package common + +import "bytes" + +type Sniffer func(header []byte) bool + +var snifferRegistry = map[string]Sniffer{ + ".m4a": SnifferM4A, + ".ogg": SnifferOGG, + ".flac": SnifferFLAC, + ".wav": SnifferWAV, + ".wma": SnifferWMA, + ".mp3": SnifferMP3, +} + +func SniffAll(header []byte) (string, bool) { + for ext, sniffer := range snifferRegistry { + if sniffer(header) { + return ext, true + } + } + return "", false +} + +func SnifferM4A(header []byte) bool { + return len(header) >= 8 && bytes.Equal([]byte("ftyp"), header[4:8]) +} + +func SnifferOGG(header []byte) bool { + return bytes.HasPrefix(header, []byte("OggS")) +} + +func SnifferFLAC(header []byte) bool { + return bytes.HasPrefix(header, []byte("fLaC")) +} +func SnifferMP3(header []byte) bool { + return bytes.HasPrefix(header, []byte("ID3")) +} +func SnifferWAV(header []byte) bool { + return bytes.HasPrefix(header, []byte("RIFF")) +} +func SnifferWMA(header []byte) bool { + return bytes.HasPrefix(header, []byte("\x30\x26\xb2\x75\x8e\x66\xcf\x11\xa6\xd9\x00\xaa\x00\x62\xce\x6c")) +} diff --git a/algo/qmc/consts.go b/algo/qmc/consts.go index 1de93e5..8ad25bf 100644 --- a/algo/qmc/consts.go +++ b/algo/qmc/consts.go @@ -38,8 +38,6 @@ var ( 0x92, 0x62, 0xf3, 0x74, 0xa1, 0x9f, 0xf4, 0xa0, 0x1d, 0x3f, 0x5b, 0xf0, 0x13, 0x0e, 0x09, 0x3d, 0xf9, 0xbc, 0x00, 0x11} - headerFlac = []byte{'f', 'L', 'a', 'C'} - headerOgg = []byte{'O', 'g', 'g', 'S'} ) var key256MappingAll [][]int //[idx256][idx128]idx44 var key256Mapping128to44 map[int]int diff --git a/algo/qmc/mask_key256.go b/algo/qmc/mask_key256.go index a2f5073..d9491f4 100644 --- a/algo/qmc/mask_key256.go +++ b/algo/qmc/mask_key256.go @@ -3,6 +3,7 @@ package qmc import ( "bytes" "errors" + "github.com/unlock-music/cli/algo/common" "github.com/unlock-music/cli/internal/logging" "go.uber.org/zap" ) @@ -116,7 +117,7 @@ func detectMflac256Mask(input []byte) (*Key256Mask, error) { if err != nil { continue } - if bytes.Equal(headerFlac, q.Decrypt(input[:len(headerFlac)])) { + if common.SnifferFLAC(q.Decrypt(input[:4])) { rtErr = nil break } @@ -164,7 +165,7 @@ func detectMgg256Mask(input []byte) (*Key256Mask, error) { if err != nil { return nil, err } - if bytes.Equal(headerOgg, q.Decrypt(input[:len(headerOgg)])) { + if common.SnifferOGG(q.Decrypt(input[:4])) { return q, nil } return nil, ErrDetectMggMask diff --git a/cmd/um/main.go b/cmd/um/main.go index 8722bb9..efedbe4 100644 --- a/cmd/um/main.go +++ b/cmd/um/main.go @@ -117,14 +117,19 @@ func tryDecFile(inputFile string, outputDir string, allDec []common.NewDecoderFu return errors.New("failed while decoding: " + err.Error()) } + outData := dec.GetAudioData() outExt := dec.GetAudioExt() if outExt == "" { - outExt = ".mp3" + if ext, ok := common.SniffAll(outData); ok { + outExt = ext + } else { + outExt = ".mp3" + } } filenameOnly := strings.TrimSuffix(filepath.Base(inputFile), filepath.Ext(inputFile)) - outPath := filepath.Join(outputDir, filenameOnly+"."+outExt) - err = os.WriteFile(outPath, dec.GetAudioData(), 0644) + outPath := filepath.Join(outputDir, filenameOnly+outExt) + err = os.WriteFile(outPath, outData, 0644) if err != nil { return err }