CodedInputDataCrypt.cpp 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. /*
  2. * Tencent is pleased to support the open source community by making
  3. * MMKV available.
  4. *
  5. * Copyright (C) 2020 THL A29 Limited, a Tencent company.
  6. * All rights reserved.
  7. *
  8. * Licensed under the BSD 3-Clause License (the "License"); you may not use
  9. * this file except in compliance with the License. You may obtain a copy of
  10. * the License at
  11. *
  12. * https://opensource.org/licenses/BSD-3-Clause
  13. *
  14. * Unless required by applicable law or agreed to in writing, software
  15. * distributed under the License is distributed on an "AS IS" BASIS,
  16. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. * See the License for the specific language governing permissions and
  18. * limitations under the License.
  19. */
  20. #include "CodedInputDataCrypt.h"
  21. #include "MMKVLog.h"
  22. #include "PBUtility.h"
  23. #include <cassert>
  24. #include <cerrno>
  25. #include <cstring>
  26. #include <stdexcept>
  27. #ifdef MMKV_APPLE
  28. # if __has_feature(objc_arc)
  29. # error This file must be compiled with MRC. Use -fno-objc-arc flag.
  30. # endif
  31. #endif // MMKV_APPLE
  32. #ifndef MMKV_DISABLE_CRYPT
  33. using namespace std;
  34. namespace mmkv {
  35. CodedInputDataCrypt::CodedInputDataCrypt(const void *oData, size_t length, AESCrypt &crypt)
  36. : m_ptr((uint8_t *) oData), m_size(length), m_position(0), m_decryptPosition(0), m_decrypter(crypt) {
  37. m_decryptBufferSize = AES_KEY_LEN * 2;
  38. m_decryptBufferPosition = static_cast<size_t>(crypt.m_number);
  39. m_decryptBufferDiscardPosition = m_decryptBufferPosition;
  40. m_decryptBufferDecryptLength = m_decryptBufferPosition;
  41. m_decryptBuffer = (uint8_t *) malloc(m_decryptBufferSize);
  42. if (!m_decryptBuffer) {
  43. throw runtime_error(strerror(errno));
  44. }
  45. }
  46. CodedInputDataCrypt::~CodedInputDataCrypt() {
  47. if (m_decryptBuffer) {
  48. free(m_decryptBuffer);
  49. }
  50. }
  51. void CodedInputDataCrypt::seek(size_t addedSize) {
  52. m_position += addedSize;
  53. m_decryptPosition += addedSize;
  54. if (m_position > m_size) {
  55. throw out_of_range("OutOfSpace");
  56. }
  57. assert(m_position % AES_KEY_LEN == m_decrypter.m_number);
  58. }
  59. void CodedInputDataCrypt::consumeBytes(size_t length, bool discardPreData) {
  60. if (discardPreData) {
  61. m_decryptBufferDiscardPosition = m_decryptBufferPosition;
  62. }
  63. auto decryptedBytesLeft = m_decryptBufferDecryptLength - m_decryptBufferPosition;
  64. if (decryptedBytesLeft >= length) {
  65. return;
  66. }
  67. length -= decryptedBytesLeft;
  68. // if there's some data left inside m_decrypter.m_vector, use them first
  69. // it will be faster when always decrypt with (n * AES_KEY_LEN) bytes
  70. if (m_decrypter.m_number != 0) {
  71. auto alignDecrypter = AES_KEY_LEN - m_decrypter.m_number;
  72. // make sure no data left inside m_decrypter.m_vector after decrypt
  73. if (length < alignDecrypter) {
  74. length = alignDecrypter;
  75. } else {
  76. length -= alignDecrypter;
  77. length = ((length + AES_KEY_LEN - 1) / AES_KEY_LEN) * AES_KEY_LEN;
  78. length += alignDecrypter;
  79. }
  80. } else {
  81. length = ((length + AES_KEY_LEN - 1) / AES_KEY_LEN) * AES_KEY_LEN;
  82. }
  83. auto bytesLeftInSrc = m_size - m_decryptPosition;
  84. length = min(bytesLeftInSrc, length);
  85. auto bytesLeftInBuffer = m_decryptBufferSize - m_decryptBufferDecryptLength;
  86. // try move some space
  87. if (bytesLeftInBuffer < length && m_decryptBufferDiscardPosition > 0) {
  88. auto posToMove = (m_decryptBufferDiscardPosition / AES_KEY_LEN) * AES_KEY_LEN;
  89. if (posToMove) {
  90. auto sizeToMove = m_decryptBufferDecryptLength - posToMove;
  91. memmove(m_decryptBuffer, m_decryptBuffer + posToMove, sizeToMove);
  92. m_decryptBufferPosition -= posToMove;
  93. m_decryptBufferDecryptLength -= posToMove;
  94. m_decryptBufferDiscardPosition = 0;
  95. bytesLeftInBuffer = m_decryptBufferSize - m_decryptBufferDecryptLength;
  96. }
  97. }
  98. // still no enough space, try realloc()
  99. if (bytesLeftInBuffer < length) {
  100. auto newSize = m_decryptBufferSize + length;
  101. auto newBuffer = realloc(m_decryptBuffer, newSize);
  102. if (!newBuffer) {
  103. throw runtime_error(strerror(errno));
  104. }
  105. m_decryptBuffer = (uint8_t *) newBuffer;
  106. m_decryptBufferSize = newSize;
  107. }
  108. m_decrypter.decrypt(m_ptr + m_decryptPosition, m_decryptBuffer + m_decryptBufferDecryptLength, length);
  109. m_decryptPosition += length;
  110. m_decryptBufferDecryptLength += length;
  111. assert(m_decryptPosition == m_size || m_decrypter.m_number == 0);
  112. }
  113. void CodedInputDataCrypt::skipBytes(size_t length) {
  114. m_position += length;
  115. auto decryptedBytesLeft = m_decryptBufferDecryptLength - m_decryptBufferPosition;
  116. if (decryptedBytesLeft >= length) {
  117. m_decryptBufferPosition += length;
  118. return;
  119. }
  120. length -= decryptedBytesLeft;
  121. // if this happens, we need optimization like the alignDecrypter above
  122. assert(m_decrypter.m_number == 0);
  123. size_t alignSize = ((length + AES_KEY_LEN - 1) / AES_KEY_LEN) * AES_KEY_LEN;
  124. auto bytesLeftInSrc = m_size - m_decryptPosition;
  125. auto size = min(alignSize, bytesLeftInSrc);
  126. decryptedBytesLeft = size - length;
  127. for (size_t index = 0, round = size / AES_KEY_LEN; index < round; index++) {
  128. m_decrypter.decrypt(m_ptr + m_decryptPosition, m_decryptBuffer, AES_KEY_LEN);
  129. m_decryptPosition += AES_KEY_LEN;
  130. size -= AES_KEY_LEN;
  131. }
  132. if (size) {
  133. m_decrypter.decrypt(m_ptr + m_decryptPosition, m_decryptBuffer, size);
  134. m_decryptPosition += size;
  135. m_decryptBufferPosition = size - decryptedBytesLeft;
  136. m_decryptBufferDecryptLength = size;
  137. } else {
  138. m_decryptBufferPosition = AES_KEY_LEN - decryptedBytesLeft;
  139. m_decryptBufferDecryptLength = AES_KEY_LEN;
  140. }
  141. assert(m_decryptBufferPosition <= m_decryptBufferDecryptLength);
  142. assert(m_decryptPosition - m_decryptBufferDecryptLength + m_decryptBufferPosition == m_position);
  143. }
  144. inline void CodedInputDataCrypt::statusBeforeDecrypt(size_t rollbackSize, AESCryptStatus &status) {
  145. rollbackSize += m_decryptBufferDecryptLength - m_decryptBufferPosition;
  146. m_decrypter.statusBeforeDecrypt(m_ptr + m_decryptPosition, m_decryptBuffer + m_decryptBufferDecryptLength,
  147. rollbackSize, status);
  148. }
  149. int8_t CodedInputDataCrypt::readRawByte() {
  150. assert(m_position <= m_decryptPosition);
  151. if (m_position == m_size) {
  152. auto msg = "reach end, m_position: " + to_string(m_position) + ", m_size: " + to_string(m_size);
  153. throw out_of_range(msg);
  154. }
  155. m_position++;
  156. assert(m_decryptBufferPosition < m_decryptBufferSize);
  157. auto *bytes = (int8_t *) m_decryptBuffer;
  158. return bytes[m_decryptBufferPosition++];
  159. }
  160. int32_t CodedInputDataCrypt::readRawVarint32(bool discardPreData) {
  161. consumeBytes(10, discardPreData);
  162. int8_t tmp = this->readRawByte();
  163. if (tmp >= 0) {
  164. return tmp;
  165. }
  166. int32_t result = tmp & 0x7f;
  167. if ((tmp = this->readRawByte()) >= 0) {
  168. result |= tmp << 7;
  169. } else {
  170. result |= (tmp & 0x7f) << 7;
  171. if ((tmp = this->readRawByte()) >= 0) {
  172. result |= tmp << 14;
  173. } else {
  174. result |= (tmp & 0x7f) << 14;
  175. if ((tmp = this->readRawByte()) >= 0) {
  176. result |= tmp << 21;
  177. } else {
  178. result |= (tmp & 0x7f) << 21;
  179. result |= (tmp = this->readRawByte()) << 28;
  180. if (tmp < 0) {
  181. // discard upper 32 bits
  182. for (int i = 0; i < 5; i++) {
  183. if (this->readRawByte() >= 0) {
  184. return result;
  185. }
  186. }
  187. throw invalid_argument("InvalidProtocolBuffer malformed varint32");
  188. }
  189. }
  190. }
  191. }
  192. return result;
  193. }
  194. int32_t CodedInputDataCrypt::readInt32() {
  195. return this->readRawVarint32();
  196. }
  197. # ifndef MMKV_APPLE
  198. string CodedInputDataCrypt::readString(KeyValueHolderCrypt &kvHolder) {
  199. kvHolder.offset = static_cast<uint32_t>(m_position);
  200. int32_t size = this->readRawVarint32(true);
  201. if (size < 0) {
  202. throw length_error("InvalidProtocolBuffer negativeSize");
  203. }
  204. auto s_size = static_cast<size_t>(size);
  205. if (s_size <= m_size - m_position) {
  206. consumeBytes(s_size);
  207. kvHolder.keySize = static_cast<uint16_t>(s_size);
  208. string result((char *) (m_decryptBuffer + m_decryptBufferPosition), s_size);
  209. m_position += s_size;
  210. m_decryptBufferPosition += s_size;
  211. return result;
  212. } else {
  213. throw out_of_range("InvalidProtocolBuffer truncatedMessage");
  214. }
  215. }
  216. # endif
  217. void CodedInputDataCrypt::readData(KeyValueHolderCrypt &kvHolder) {
  218. int32_t size = this->readRawVarint32();
  219. if (size < 0) {
  220. throw length_error("InvalidProtocolBuffer negativeSize");
  221. }
  222. auto s_size = static_cast<size_t>(size);
  223. if (s_size <= m_size - m_position) {
  224. if (KeyValueHolderCrypt::isValueStoredAsOffset(s_size)) {
  225. kvHolder.type = KeyValueHolderType_Offset;
  226. kvHolder.valueSize = static_cast<uint32_t>(s_size);
  227. kvHolder.pbKeyValueSize =
  228. static_cast<uint8_t>(pbRawVarint32Size(kvHolder.valueSize) + pbRawVarint32Size(kvHolder.keySize));
  229. size_t rollbackSize = kvHolder.pbKeyValueSize + kvHolder.keySize;
  230. statusBeforeDecrypt(rollbackSize, kvHolder.cryptStatus);
  231. skipBytes(s_size);
  232. } else {
  233. consumeBytes(s_size);
  234. kvHolder.type = KeyValueHolderType_Direct;
  235. kvHolder = KeyValueHolderCrypt(m_decryptBuffer + m_decryptBufferPosition, s_size);
  236. m_decryptBufferPosition += s_size;
  237. m_position += s_size;
  238. }
  239. } else {
  240. throw out_of_range("InvalidProtocolBuffer truncatedMessage");
  241. }
  242. }
  243. } // namespace mmkv
  244. #endif // MMKV_DISABLE_CRYPT