AESCrypt.cpp 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. /*
  2. * Tencent is pleased to support the open source community by making
  3. * MMKV available.
  4. *
  5. * Copyright (C) 2018 THL A29 Limited, a Tencent company.
  6. * All rights reserved.
  7. *
  8. * Licensed under the BSD 3-Clause License (the "License"); you may not use
  9. * this file except in compliance with the License. You may obtain a copy of
  10. * the License at
  11. *
  12. * https://opensource.org/licenses/BSD-3-Clause
  13. *
  14. * Unless required by applicable law or agreed to in writing, software
  15. * distributed under the License is distributed on an "AS IS" BASIS,
  16. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. * See the License for the specific language governing permissions and
  18. * limitations under the License.
  19. */
#include "AESCrypt.h"
#include "openssl/openssl_aes.h"
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <random>
  26. #ifndef MMKV_DISABLE_CRYPT
  27. using namespace openssl;
  28. namespace mmkv {
  29. AESCrypt::AESCrypt(const void *key, size_t keyLength, const void *iv, size_t ivLength) {
  30. if (key && keyLength > 0) {
  31. memcpy(m_key, key, (keyLength > AES_KEY_LEN) ? AES_KEY_LEN : keyLength);
  32. resetIV(iv, ivLength);
  33. m_aesKey = new AES_KEY;
  34. memset(m_aesKey, 0, sizeof(AES_KEY));
  35. int ret = AES_set_encrypt_key(m_key, AES_KEY_BITSET_LEN, m_aesKey);
  36. MMKV_ASSERT(ret == 0);
  37. }
  38. }
  39. AESCrypt::AESCrypt(const AESCrypt &other, const AESCryptStatus &status) : m_isClone(true), m_number(status.m_number) {
  40. //memcpy(m_key, other.m_key, sizeof(m_key));
  41. memcpy(m_vector, status.m_vector, sizeof(m_vector));
  42. m_aesKey = other.m_aesKey;
  43. }
  44. AESCrypt::~AESCrypt() {
  45. if (!m_isClone) {
  46. delete m_aesKey;
  47. delete m_aesRollbackKey;
  48. }
  49. }
  50. void AESCrypt::resetIV(const void *iv, size_t ivLength) {
  51. m_number = 0;
  52. if (iv && ivLength > 0) {
  53. memcpy(m_vector, iv, (ivLength > AES_KEY_LEN) ? AES_KEY_LEN : ivLength);
  54. } else {
  55. memcpy(m_vector, m_key, AES_KEY_LEN);
  56. }
  57. }
  58. void AESCrypt::resetStatus(const AESCryptStatus &status) {
  59. m_number = status.m_number;
  60. memcpy(m_vector, status.m_vector, AES_KEY_LEN);
  61. }
  62. void AESCrypt::getKey(void *output) const {
  63. if (output) {
  64. memcpy(output, m_key, AES_KEY_LEN);
  65. }
  66. }
  67. void AESCrypt::encrypt(const void *input, void *output, size_t length) {
  68. if (!input || !output || length == 0) {
  69. return;
  70. }
  71. AES_cfb128_encrypt((const uint8_t *) input, (uint8_t *) output, length, m_aesKey, m_vector, &m_number);
  72. }
  73. void AESCrypt::decrypt(const void *input, void *output, size_t length) {
  74. if (!input || !output || length == 0) {
  75. return;
  76. }
  77. AES_cfb128_decrypt((const uint8_t *) input, (uint8_t *) output, length, m_aesKey, m_vector, &m_number);
  78. }
  79. void AESCrypt::fillRandomIV(void *vector) {
  80. if (!vector) {
  81. return;
  82. }
  83. srand((unsigned) time(nullptr));
  84. int *ptr = (int *) vector;
  85. for (uint32_t i = 0; i < AES_KEY_LEN / sizeof(int); i++) {
  86. ptr[i] = rand();
  87. }
  88. }
  89. static inline void
  90. Rollback_cfb_decrypt(const uint8_t *input, const uint8_t *output, size_t len, AES_KEY *key, AESCryptStatus &status) {
  91. auto ivec = status.m_vector;
  92. auto n = status.m_number;
  93. while (n && len) {
  94. auto c = *(--output);
  95. ivec[--n] = *(--input) ^ c;
  96. len--;
  97. }
  98. if (n == 0 && (status.m_number != 0)) {
  99. AES_decrypt(ivec, ivec, key);
  100. }
  101. while (len >= 16) {
  102. len -= 16;
  103. output -= 16;
  104. input -= 16;
  105. for (; n < 16; n += sizeof(size_t)) {
  106. size_t t = *(size_t *) (output + n);
  107. *(size_t *) (ivec + n) = *(size_t *) (input + n) ^ t;
  108. }
  109. n = 0;
  110. AES_decrypt(ivec, ivec, key);
  111. }
  112. if (len) {
  113. n = 16;
  114. do {
  115. auto c = *(--output);
  116. ivec[--n] = *(--input) ^ c;
  117. len--;
  118. } while (len);
  119. }
  120. status.m_number = n;
  121. }
  122. void AESCrypt::statusBeforeDecrypt(const void *input, const void *output, size_t length, AESCryptStatus &status) {
  123. if (length == 0) {
  124. return;
  125. }
  126. if (!m_aesRollbackKey) {
  127. m_aesRollbackKey = new AES_KEY;
  128. memset(m_aesRollbackKey, 0, sizeof(AES_KEY));
  129. int ret = AES_set_decrypt_key(m_key, AES_KEY_BITSET_LEN, m_aesRollbackKey);
  130. MMKV_ASSERT(ret == 0);
  131. }
  132. getCurStatus(status);
  133. Rollback_cfb_decrypt((const uint8_t *) input, (const uint8_t *) output, length, m_aesRollbackKey, status);
  134. }
  135. void AESCrypt::getCurStatus(AESCryptStatus &status) {
  136. status.m_number = static_cast<uint8_t>(m_number);
  137. memcpy(status.m_vector, m_vector, sizeof(m_vector));
  138. }
  139. AESCrypt AESCrypt::cloneWithStatus(const AESCryptStatus &status) const {
  140. return AESCrypt(*this, status);
  141. }
  142. } // namespace mmkv
  143. # ifdef MMKV_DEBUG
  144. # include "../MMKVLog.h"
  145. # include "../MemoryFile.h"
  146. namespace mmkv {
  147. // check if AESCrypt is encrypt-decrypt full-duplex
  148. void AESCrypt::testAESCrypt() {
  149. const uint8_t plainText[] = "Hello, OpenSSL-mmkv::AESCrypt::testAESCrypt() with AES CFB 128.";
  150. constexpr size_t textLength = sizeof(plainText) - 1;
  151. const uint8_t key[] = "TheAESKey";
  152. constexpr size_t keyLength = sizeof(key) - 1;
  153. uint8_t iv[AES_KEY_LEN];
  154. srand((unsigned) time(nullptr));
  155. for (uint32_t i = 0; i < AES_KEY_LEN; i++) {
  156. iv[i] = (uint8_t) rand();
  157. }
  158. AESCrypt crypt1(key, keyLength, iv, sizeof(iv));
  159. AESCrypt crypt2(key, keyLength, iv, sizeof(iv));
  160. auto encryptText = new uint8_t[DEFAULT_MMAP_SIZE];
  161. auto decryptText = new uint8_t[DEFAULT_MMAP_SIZE];
  162. memset(encryptText, 0, DEFAULT_MMAP_SIZE);
  163. memset(decryptText, 0, DEFAULT_MMAP_SIZE);
  164. /* in-place encryption & decryption
  165. memcpy(encryptText, plainText, textLength);
  166. crypt1.encrypt(encryptText, encryptText, textLength);
  167. crypt2.decrypt(encryptText, encryptText, textLength);
  168. return;
  169. */
  170. AES_KEY decryptKey;
  171. AES_set_decrypt_key(crypt1.m_key, AES_KEY_BITSET_LEN, &decryptKey);
  172. size_t actualSize = 0;
  173. bool flip = false;
  174. for (const uint8_t *ptr = plainText; ptr < plainText + textLength;) {
  175. auto tokenPtr = (const uint8_t *) strchr((const char *) ptr, ' ');
  176. size_t size = 0;
  177. if (!tokenPtr) {
  178. size = static_cast<size_t>(plainText + textLength - ptr);
  179. } else {
  180. size = static_cast<size_t>(tokenPtr - ptr + 1);
  181. }
  182. AESCrypt *decrypter;
  183. uint32_t oldNum;
  184. uint8_t oldVector[sizeof(crypt1.m_vector)];
  185. flip = !flip;
  186. if (flip) {
  187. crypt1.encrypt(plainText + actualSize, encryptText + actualSize, size);
  188. decrypter = &crypt2;
  189. oldNum = decrypter->m_number;
  190. memcpy(oldVector, decrypter->m_vector, sizeof(oldVector));
  191. crypt2.decrypt(encryptText + actualSize, decryptText + actualSize, size);
  192. } else {
  193. crypt2.encrypt(plainText + actualSize, encryptText + actualSize, size);
  194. decrypter = &crypt1;
  195. oldNum = decrypter->m_number;
  196. memcpy(oldVector, decrypter->m_vector, sizeof(oldVector));
  197. crypt1.decrypt(encryptText + actualSize, decryptText + actualSize, size);
  198. }
  199. // that's why AESCrypt can be full-duplex
  200. assert(crypt1.m_number == crypt2.m_number);
  201. assert(0 == memcmp(crypt1.m_vector, crypt2.m_vector, sizeof(crypt1.m_vector)));
  202. // how rollback works
  203. AESCryptStatus status;
  204. decrypter->statusBeforeDecrypt(encryptText + actualSize + size, decryptText + actualSize + size, size, status);
  205. assert(oldNum == status.m_number);
  206. assert(0 == memcmp(oldVector, status.m_vector, sizeof(oldVector)));
  207. actualSize += size;
  208. ptr += size;
  209. }
  210. MMKVInfo("AES CFB decode: %s", decryptText);
  211. delete[] encryptText;
  212. delete[] decryptText;
  213. }
  214. } // namespace mmkv
  215. # endif // MMKV_DEBUG
  216. #endif // MMKV_DISABLE_CRYPT