diff --git a/interface.go b/interface.go index 2c32c5e..8387b32 100644 --- a/interface.go +++ b/interface.go @@ -5,6 +5,7 @@ type Manager interface { // If the vault does not exist, it will be created. // If id is empty, DefaultVaultID will be used. OpenVault(id string) (Vault, error) + OpenVaultCrypto(id string, cryptoKey string) (Vault, error) } type Vault interface { diff --git a/manager.go b/manager.go index 31ad59a..0e30382 100644 --- a/manager.go +++ b/manager.go @@ -42,7 +42,7 @@ func (m *manager) OpenVault(id string) (Vault, error) { return v, nil } - vault, err := m.openVault(id) + vault, err := m.openVault(id, "") if err != nil { return nil, fmt.Errorf("failed to open vault: %w", err) } @@ -51,7 +51,25 @@ func (m *manager) OpenVault(id string) (Vault, error) { return vault, nil } -func (m *manager) openVault(id string) (Vault, error) { +func (m *manager) OpenVaultCrypto(id string, cryptoKey string) (Vault, error) { + if id == "" { + id = DefaultVaultID + } + + if v, ok := m.vaults[id]; ok { + return v, nil + } + + vault, err := m.openVault(id, cryptoKey) + if err != nil { + return nil, fmt.Errorf("failed to open vault: %w", err) + } + m.vaults[id] = vault + + return vault, nil +} + +func (m *manager) openVault(id string, cryptoKey string) (Vault, error) { metaFile, err := os.Open(path.Join(m.dir, id+".crc")) if err != nil { return nil, fmt.Errorf("failed to open metadata file: %w", err) @@ -69,7 +87,7 @@ func (m *manager) openVault(id string) (Vault, error) { return nil, fmt.Errorf("failed to load metadata: %w", err) } - v, err := loadVault(vaultFile, meta) + v, err := loadVault(vaultFile, meta, cryptoKey) if err != nil { return nil, fmt.Errorf("failed to load vault: %w", err) } diff --git a/manager_test.go b/manager_test.go index 9bdd64d..e5488d9 100644 --- a/manager_test.go +++ b/manager_test.go @@ -7,11 +7,28 @@ import ( ) func TestNewManager(t *testing.T) { - mgr, err := NewManager("./testdata") - assert.NoError(t, err) - assert.NotNil(t, mgr) + t.Run("Default", func(t *testing.T) { + mgr, err := NewManager("./testdata") + assert.NoError(t, err) + assert.NotNil(t, mgr) - vault, err := mgr.OpenVault("") - assert.NoError(t, err) - assert.NotNil(t, vault) + vault, err := mgr.OpenVault("") + assert.NoError(t, err) + assert.NotNil(t, vault) + }) + t.Run("Crypto", func(t *testing.T) { + mgr, err := NewManager("./testdata") + assert.NoError(t, err) + assert.NotNil(t, mgr) + + vault, err := mgr.OpenVaultCrypto("crypto", "123456") + val, err := vault.GetString("world") + assert.NotNil(t, vault) + + assert.Equal(t, "hello", val) + assert.NoError(t, err) + + _, err = vault.GetBytes("foo") + assert.Error(t, err) + }) } diff --git a/testdata/crypto b/testdata/crypto new file mode 100644 index 0000000..a3a396e Binary files /dev/null and b/testdata/crypto differ diff --git a/testdata/crypto.crc b/testdata/crypto.crc new file mode 100644 index 0000000..1b6459d Binary files /dev/null and b/testdata/crypto.crc differ diff --git a/vault.go b/vault.go index 131b5d8..df5b57b 100644 --- a/vault.go +++ b/vault.go @@ -1,6 +1,8 @@ package mmkv import ( + "crypto/aes" + "crypto/cipher" "encoding/binary" "fmt" "hash/crc32" @@ -44,8 +46,8 @@ func (v vault) GetString(key string) (string, error) { return string(val), nil } -// metadata is optional. but if it exists, validate with it. -func loadVault(src io.Reader, m *metadata) (Vault, error) { +// metadata and cryptoKey are optional. but if it exists, validate with them. +func loadVault(src io.Reader, m *metadata, cryptoKey string) (Vault, error) { fileSizeBuf := make([]byte, 4) _, err := io.ReadFull(src, fileSizeBuf) if err != nil { @@ -67,6 +69,19 @@ func loadVault(src io.Reader, m *metadata) (Vault, error) { return nil, fmt.Errorf("metadata and vault payload crc32 mismatch") } + // 将数据库完整解密 + if len(cryptoKey) > 0 { + m_key := make([]byte, aes.BlockSize) // 16 bytes key + copy(m_key, cryptoKey) + + block, err := aes.NewCipher(m_key) + if err != nil { + return nil, fmt.Errorf("failed to create aes cipher") + } + stream := cipher.NewCFBDecrypter(block, m.aesVector) + stream.XORKeyStream(buf, buf) + } + v := make(vault) // mmkv is not really protobuf compatible, diff --git a/vault_test.go b/vault_test.go index d3e79fc..578498e 100644 --- a/vault_test.go +++ b/vault_test.go @@ -12,7 +12,7 @@ func Test_loadVault(t *testing.T) { file, err := os.Open("./testdata/mmkv.default") require.NoError(t, err) - v, err := loadVault(file, nil) + v, err := loadVault(file, nil, "") require.NoError(t, err) assert.Equal(t, 2, len(v.Keys()))