xerial.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. package snappy
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "errors"
  6. "io"
  7. "github.com/klauspost/compress/snappy"
  8. )
  // defaultBufferSize is the initial capacity (32 KiB) of the input buffers
// allocated by the reader and writer below.
const defaultBufferSize = 32 << 10
  // xerialReader reads a stream of xerial-framed, snappy-encoded data and
// provides Read and WriteTo over it. Framing is optional: if the stream does
// not begin with the 16-byte xerial header, its whole content is read as a
// single payload. That payload is decoded when decode is set and forwarded
// unchanged when decode is nil.
type xerialReader struct {
	reader io.Reader // underlying source stream

	header [16]byte // first bytes of the stream, sniffed for the xerial magic

	input, output []byte // raw chunk read from reader / data ready to hand out

	offset, nbytes int64 // read position within output / bytes consumed from reader

	decode func([]byte, []byte) ([]byte, error) // payload decoder; nil means passthrough
}
  22. func (x *xerialReader) Reset(r io.Reader) {
  23. x.reader = r
  24. x.input = x.input[:0]
  25. x.output = x.output[:0]
  26. x.header = [16]byte{}
  27. x.offset = 0
  28. x.nbytes = 0
  29. }
  30. func (x *xerialReader) Read(b []byte) (int, error) {
  31. for {
  32. if x.offset < int64(len(x.output)) {
  33. n := copy(b, x.output[x.offset:])
  34. x.offset += int64(n)
  35. return n, nil
  36. }
  37. n, err := x.readChunk(b)
  38. if err != nil {
  39. return 0, err
  40. }
  41. if n > 0 {
  42. return n, nil
  43. }
  44. }
  45. }
  46. func (x *xerialReader) WriteTo(w io.Writer) (int64, error) {
  47. wn := int64(0)
  48. for {
  49. for x.offset < int64(len(x.output)) {
  50. n, err := w.Write(x.output[x.offset:])
  51. wn += int64(n)
  52. x.offset += int64(n)
  53. if err != nil {
  54. return wn, err
  55. }
  56. }
  57. if _, err := x.readChunk(nil); err != nil {
  58. if errors.Is(err, io.EOF) {
  59. err = nil
  60. }
  61. return wn, err
  62. }
  63. }
  64. }
  65. func (x *xerialReader) readChunk(dst []byte) (int, error) {
  66. x.output = x.output[:0]
  67. x.offset = 0
  68. prefix := 0
  69. if x.nbytes == 0 {
  70. n, err := x.readFull(x.header[:])
  71. if err != nil && n == 0 {
  72. return 0, err
  73. }
  74. prefix = n
  75. }
  76. if isXerialHeader(x.header[:]) {
  77. if cap(x.input) < 4 {
  78. x.input = make([]byte, 4, defaultBufferSize)
  79. } else {
  80. x.input = x.input[:4]
  81. }
  82. _, err := x.readFull(x.input)
  83. if err != nil {
  84. return 0, err
  85. }
  86. frame := int(binary.BigEndian.Uint32(x.input))
  87. if cap(x.input) < frame {
  88. x.input = make([]byte, frame, align(frame, defaultBufferSize))
  89. } else {
  90. x.input = x.input[:frame]
  91. }
  92. if _, err := x.readFull(x.input); err != nil {
  93. return 0, err
  94. }
  95. } else {
  96. if cap(x.input) == 0 {
  97. x.input = make([]byte, 0, defaultBufferSize)
  98. } else {
  99. x.input = x.input[:0]
  100. }
  101. if prefix > 0 {
  102. x.input = append(x.input, x.header[:prefix]...)
  103. }
  104. for {
  105. if len(x.input) == cap(x.input) {
  106. b := make([]byte, len(x.input), 2*cap(x.input))
  107. copy(b, x.input)
  108. x.input = b
  109. }
  110. n, err := x.read(x.input[len(x.input):cap(x.input)])
  111. x.input = x.input[:len(x.input)+n]
  112. if err != nil {
  113. if errors.Is(err, io.EOF) && len(x.input) > 0 {
  114. break
  115. }
  116. return 0, err
  117. }
  118. }
  119. }
  120. var n int
  121. var err error
  122. if x.decode == nil {
  123. x.output, x.input, err = x.input, x.output, nil
  124. } else if n, err = snappy.DecodedLen(x.input); n <= len(dst) && err == nil {
  125. // If the output buffer is large enough to hold the decode value,
  126. // write it there directly instead of using the intermediary output
  127. // buffer.
  128. _, err = x.decode(dst, x.input)
  129. } else {
  130. var b []byte
  131. n = 0
  132. b, err = x.decode(x.output[:cap(x.output)], x.input)
  133. if err == nil {
  134. x.output = b
  135. }
  136. }
  137. return n, err
  138. }
  139. func (x *xerialReader) read(b []byte) (int, error) {
  140. n, err := x.reader.Read(b)
  141. x.nbytes += int64(n)
  142. return n, err
  143. }
  144. func (x *xerialReader) readFull(b []byte) (int, error) {
  145. n, err := io.ReadFull(x.reader, b)
  146. x.nbytes += int64(n)
  147. return n, err
  148. }
  // xerialWriter buffers the data written to it and emits that data to an
// underlying io.Writer on Flush. When framing is enabled, it also flushes
// automatically once the buffer is nearly full. Each flushed chunk is passed
// through encode when encode is set. With framed enabled, the output starts
// with the 16-byte xerial header, and every chunk is preceded by its 4-byte
// big-endian length.
type xerialWriter struct {
	writer io.Writer // destination stream

	header [16]byte // scratch space for the stream header and frame lengths

	input, output []byte // pending unflushed bytes / reusable encode buffer

	nbytes int64 // bytes written to writer so far; the header is emitted while this is 0
	framed bool  // emit xerial framing around each chunk

	encode func([]byte, []byte) []byte // chunk encoder; nil writes chunks verbatim
}
  160. func (x *xerialWriter) Reset(w io.Writer) {
  161. x.writer = w
  162. x.input = x.input[:0]
  163. x.output = x.output[:0]
  164. x.nbytes = 0
  165. }
  166. func (x *xerialWriter) ReadFrom(r io.Reader) (int64, error) {
  167. wn := int64(0)
  168. if cap(x.input) == 0 {
  169. x.input = make([]byte, 0, defaultBufferSize)
  170. }
  171. for {
  172. if x.full() {
  173. x.grow()
  174. }
  175. n, err := r.Read(x.input[len(x.input):cap(x.input)])
  176. wn += int64(n)
  177. x.input = x.input[:len(x.input)+n]
  178. if x.fullEnough() {
  179. if err := x.Flush(); err != nil {
  180. return wn, err
  181. }
  182. }
  183. if err != nil {
  184. if errors.Is(err, io.EOF) {
  185. err = nil
  186. }
  187. return wn, err
  188. }
  189. }
  190. }
  191. func (x *xerialWriter) Write(b []byte) (int, error) {
  192. wn := 0
  193. if cap(x.input) == 0 {
  194. x.input = make([]byte, 0, defaultBufferSize)
  195. }
  196. for len(b) > 0 {
  197. if x.full() {
  198. x.grow()
  199. }
  200. n := copy(x.input[len(x.input):cap(x.input)], b)
  201. b = b[n:]
  202. wn += n
  203. x.input = x.input[:len(x.input)+n]
  204. if x.fullEnough() {
  205. if err := x.Flush(); err != nil {
  206. return wn, err
  207. }
  208. }
  209. }
  210. return wn, nil
  211. }
  212. func (x *xerialWriter) Flush() error {
  213. if len(x.input) == 0 {
  214. return nil
  215. }
  216. var b []byte
  217. if x.encode == nil {
  218. b = x.input
  219. } else {
  220. x.output = x.encode(x.output[:cap(x.output)], x.input)
  221. b = x.output
  222. }
  223. x.input = x.input[:0]
  224. x.output = x.output[:0]
  225. if x.framed && x.nbytes == 0 {
  226. writeXerialHeader(x.header[:])
  227. _, err := x.write(x.header[:])
  228. if err != nil {
  229. return err
  230. }
  231. }
  232. if x.framed {
  233. writeXerialFrame(x.header[:4], len(b))
  234. _, err := x.write(x.header[:4])
  235. if err != nil {
  236. return err
  237. }
  238. }
  239. _, err := x.write(b)
  240. return err
  241. }
  242. func (x *xerialWriter) write(b []byte) (int, error) {
  243. n, err := x.writer.Write(b)
  244. x.nbytes += int64(n)
  245. return n, err
  246. }
  247. func (x *xerialWriter) full() bool {
  248. return len(x.input) == cap(x.input)
  249. }
  250. func (x *xerialWriter) fullEnough() bool {
  251. return x.framed && (cap(x.input)-len(x.input)) < 1024
  252. }
  253. func (x *xerialWriter) grow() {
  254. tmp := make([]byte, len(x.input), 2*cap(x.input))
  255. copy(tmp, x.input)
  256. x.input = tmp
  257. }
  // align returns n unchanged when it is already a multiple of a. For
// non-negative n and positive a, it otherwise rounds n up to the next
// multiple of a. a must be non-zero.
func align(n, a int) int {
	if rem := n % a; rem != 0 {
		return n - rem + a
	}
	return n
}
  var (
	// xerialHeader is the 8-byte magic that opens a xerial-framed stream:
	// 0x82, the ASCII letters "SNAPPY", and a NUL byte.
	xerialHeader = [...]byte{0x82, 'S', 'N', 'A', 'P', 'P', 'Y', 0x00}

	// xerialVersionInfo is the 8 bytes written after the magic to complete
	// the 16-byte header (presumably two big-endian uint32 version fields;
	// confirm against the xerial snappy-java format).
	xerialVersionInfo = [...]byte{0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01}
)
  268. func isXerialHeader(src []byte) bool {
  269. return len(src) >= 16 && bytes.Equal(src[:8], xerialHeader[:])
  270. }
  271. func writeXerialHeader(b []byte) {
  272. copy(b[:8], xerialHeader[:])
  273. copy(b[8:], xerialVersionInfo[:])
  274. }
  // writeXerialFrame encodes the chunk length n as a big-endian uint32 in the
// first four bytes of b. This is the per-frame length prefix of xerial
// framing. b must be at least four bytes long.
func writeXerialFrame(b []byte, n int) {
	length := uint32(n)
	binary.BigEndian.PutUint32(b, length)
}