From a0c5a497c722d5a337d8b006ce3eab8b5642d1cb Mon Sep 17 00:00:00 2001 From: awalol Date: Mon, 5 Feb 2024 00:47:24 +0800 Subject: [PATCH] support encrypted database --- interface.go | 1 + manager.go | 24 +++++++++++++++++++++--- manager_test.go | 29 +++++++++++++++++++++++------ testdata/crypto | Bin 0 -> 4096 bytes testdata/crypto.crc | Bin 0 -> 4096 bytes vault.go | 19 +++++++++++++++++-- vault_test.go | 2 +- 7 files changed, 63 insertions(+), 12 deletions(-) create mode 100644 testdata/crypto create mode 100644 testdata/crypto.crc diff --git a/interface.go b/interface.go index 2c32c5e..8387b32 100644 --- a/interface.go +++ b/interface.go @@ -5,6 +5,7 @@ type Manager interface { // If the vault does not exist, it will be created. // If id is empty, DefaultVaultID will be used. OpenVault(id string) (Vault, error) + OpenVaultCrypto(id string, cryptoKey string) (Vault, error) } type Vault interface { diff --git a/manager.go b/manager.go index 31ad59a..0e30382 100644 --- a/manager.go +++ b/manager.go @@ -42,7 +42,7 @@ func (m *manager) OpenVault(id string) (Vault, error) { return v, nil } - vault, err := m.openVault(id) + vault, err := m.openVault(id, "") if err != nil { return nil, fmt.Errorf("failed to open vault: %w", err) } @@ -51,7 +51,25 @@ func (m *manager) OpenVault(id string) (Vault, error) { return vault, nil } -func (m *manager) openVault(id string) (Vault, error) { +func (m *manager) OpenVaultCrypto(id string, cryptoKey string) (Vault, error) { + if id == "" { + id = DefaultVaultID + } + + if v, ok := m.vaults[id]; ok { + return v, nil + } + + vault, err := m.openVault(id, cryptoKey) + if err != nil { + return nil, fmt.Errorf("failed to open vault: %w", err) + } + m.vaults[id] = vault + + return vault, nil +} + +func (m *manager) openVault(id string, cryptoKey string) (Vault, error) { metaFile, err := os.Open(path.Join(m.dir, id+".crc")) if err != nil { return nil, fmt.Errorf("failed to open metadata file: %w", err) @@ -69,7 +87,7 @@ func (m *manager) openVault(id string) (Vault, error) { return nil, fmt.Errorf("failed to load metadata: %w", err) } - v, err := loadVault(vaultFile, meta) + v, err := loadVault(vaultFile, meta, cryptoKey) if err != nil { return nil, fmt.Errorf("failed to load vault: %w", err) } diff --git a/manager_test.go b/manager_test.go index 9bdd64d..e5488d9 100644 --- a/manager_test.go +++ b/manager_test.go @@ -7,11 +7,28 @@ import ( ) func TestNewManager(t *testing.T) { - mgr, err := NewManager("./testdata") - assert.NoError(t, err) - assert.NotNil(t, mgr) + t.Run("Default", func(t *testing.T) { + mgr, err := NewManager("./testdata") + assert.NoError(t, err) + assert.NotNil(t, mgr) - vault, err := mgr.OpenVault("") - assert.NoError(t, err) - assert.NotNil(t, vault) + vault, err := mgr.OpenVault("") + assert.NoError(t, err) + assert.NotNil(t, vault) + }) + t.Run("Crypto", func(t *testing.T) { + mgr, err := NewManager("./testdata") + assert.NoError(t, err) + assert.NotNil(t, mgr) + + vault, err := mgr.OpenVaultCrypto("crypto", "123456") + val, err := vault.GetString("world") + assert.NotNil(t, vault) + + assert.Equal(t, "hello", val) + assert.NoError(t, err) + + _, err = vault.GetBytes("foo") + assert.Error(t, err) + }) } diff --git a/testdata/crypto b/testdata/crypto new file mode 100644 index 0000000000000000000000000000000000000000..a3a396e09516b8e35bf0babea06ed50342704fdb GIT binary patch literal 4096 zcmWe+U|`7P;5l+g(5HuE*N64lcg|}3U^9Wr4 literal 0 HcmV?d00001 diff --git a/testdata/crypto.crc b/testdata/crypto.crc new file mode 100644 index 0000000000000000000000000000000000000000..1b6459d67582f01f42c8a142367e6d61656029f6 GIT binary patch literal 4096 zcmeDEw|x%_0|Nsi5WkIQV0fg%z+hF$z);J{z#s?|1IcyNAB_Z3qhK@yMnhmU1V%$( gGz3ONU^E0qLtr!nMnhmU1V%$(Gz3ONV2FkQ0EEv7(EtDd literal 0 HcmV?d00001 diff --git a/vault.go b/vault.go index 131b5d8..df5b57b 100644 --- a/vault.go +++ b/vault.go @@ -1,6 +1,8 @@ package mmkv import ( + "crypto/aes" + "crypto/cipher" "encoding/binary" "fmt" "hash/crc32" @@ -44,8 +46,8 @@ func (v vault) GetString(key string) (string, error) { return string(val), nil } -// metadata is optional. but if it exists, validate with it. -func loadVault(src io.Reader, m *metadata) (Vault, error) { +// metadata and cryptoKey are optional. but if it exists, validate with them. +func loadVault(src io.Reader, m *metadata, cryptoKey string) (Vault, error) { fileSizeBuf := make([]byte, 4) _, err := io.ReadFull(src, fileSizeBuf) if err != nil { @@ -67,6 +69,19 @@ func loadVault(src io.Reader, m *metadata) (Vault, error) { return nil, fmt.Errorf("metadata and vault payload crc32 mismatch") } + // 将数据库完整解密 + if len(cryptoKey) > 0 { + m_key := make([]byte, aes.BlockSize) // 16 bytes key + copy(m_key, cryptoKey) + + block, err := aes.NewCipher(m_key) + if err != nil { + return nil, fmt.Errorf("failed to create aes cipher") + } + stream := cipher.NewCFBDecrypter(block, m.aesVector) + stream.XORKeyStream(buf, buf) + } + v := make(vault) // mmkv is not really protobuf compatible, diff --git a/vault_test.go b/vault_test.go index d3e79fc..578498e 100644 --- a/vault_test.go +++ b/vault_test.go @@ -12,7 +12,7 @@ func Test_loadVault(t *testing.T) { file, err := os.Open("./testdata/mmkv.default") require.NoError(t, err) - v, err := loadVault(file, nil) + v, err := loadVault(file, nil, "") require.NoError(t, err) assert.Equal(t, 2, len(v.Keys()))