block.go

package lz4stream

import (
	"encoding/binary"
	"fmt"
	"io"
	"sync"

	"github.com/pierrec/lz4/v4/internal/lz4block"
	"github.com/pierrec/lz4/v4/internal/lz4errors"
	"github.com/pierrec/lz4/v4/internal/xxh32"
)
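
// Blocks manages the data blocks of a frame. Block is used when processing
// sequentially; in concurrent mode, Blocks queues one channel per in-flight
// block so results are handled in submission order.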
type Blocks struct {
	Block  *FrameDataBlock
	Blocks chan chan *FrameDataBlock
	mu     sync.Mutex
	err    error
}
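
// initW prepares b for writing to dst. With a single worker (num == 1), one
// block is reused sequentially; otherwise a dedicated goroutine drains b.Blocks
// and writes each compressed block in the order its channel was queued.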
func (b *Blocks) initW(f *Frame, dst io.Writer, num int) {
	if num == 1 {
		b.Blocks = nil
		b.Block = NewFrameDataBlock(f)
		return
	}
	b.Block = nil
	if cap(b.Blocks) != num {
		b.Blocks = make(chan chan *FrameDataBlock, num)
	}
	// Goroutine managing the concurrent block compression goroutines.
	go func() {
		// Process the next block compression item.
		for c := range b.Blocks {
			// Read the next compressed block result.
			// Waiting here ensures that the blocks are output in the order they were sent.
			// The incoming channel is always closed as it indicates to the caller that
			// the block has been processed.
			block := <-c
			if block == nil {
				// Notify the block compression routine that we are done with its result.
				// This is used when a sentinel block is sent to terminate the compression.
				close(c)
				return
			}
			// Do not attempt to write the block upon any previous failure.
			if b.err == nil {
				// Write the block.
				if err := block.Write(f, dst); err != nil {
					// Keep the first error.
					b.err = err
					// All pending compression goroutines need to shut down, so we need to keep going.
				}
			}
			close(c)
		}
	}()
}
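
// close returns (and clears) the first error recorded while writing. In
// concurrent mode it first sends a sentinel block through b.Blocks so the
// writer goroutine terminates before the error is read.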
func (b *Blocks) close(f *Frame, num int) error {
	if num == 1 {
		if b.Block != nil {
			b.Block.Close(f)
		}
		err := b.err
		b.err = nil
		return err
	}
	if b.Blocks == nil {
		err := b.err
		b.err = nil
		return err
	}
	c := make(chan *FrameDataBlock)
	b.Blocks <- c
	c <- nil
	<-c
	err := b.err
	b.err = nil
	return err
}

// ErrorR returns any error set while uncompressing a stream.
func (b *Blocks) ErrorR() error {
	b.mu.Lock()
	defer b.mu.Unlock()
	return b.err
}

// initR returns a channel that streams the uncompressed blocks if in concurrent
// mode and no error. When the channel is closed, check for any error with b.ErrorR.
//
// If not in concurrent mode, the uncompressed block is b.Block and the returned
// error needs to be checked.
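//
// A minimal sketch of consuming the concurrent-mode channel (f, num and src are
// assumed to be already set up; names are illustrative only):
//
//	ch, err := b.initR(f, num, src)
//	if err == nil && ch != nil {
//		for buf := range ch {
//			_ = buf // process an uncompressed block, received in order
//		}
//		err = b.ErrorR() // channel closed: check for a deferred error
//	}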
func (b *Blocks) initR(f *Frame, num int, src io.Reader) (chan []byte, error) {
	size := f.Descriptor.Flags.BlockSizeIndex()
	if num == 1 {
		b.Blocks = nil
		b.Block = NewFrameDataBlock(f)
		return nil, nil
	}
	b.Block = nil
	blocks := make(chan chan []byte, num)
	// data receives the uncompressed blocks.
	data := make(chan []byte)
	// Read blocks from the source sequentially
	// and uncompress them concurrently.
	// In legacy mode, accrue the uncompressed sizes in cum.
	var cum uint32
	go func() {
		var cumx uint32
		var err error
		for b.ErrorR() == nil {
			block := NewFrameDataBlock(f)
			cumx, err = block.Read(f, src, 0)
			if err != nil {
				block.Close(f)
				break
			}
			// Recheck for an error as reading may be slow and uncompressing is expensive.
			if b.ErrorR() != nil {
				block.Close(f)
				break
			}
			c := make(chan []byte)
			blocks <- c
			go func() {
				defer block.Close(f)
				data, err := block.Uncompress(f, size.Get(), nil, false)
				if err != nil {
					b.closeR(err)
					// Close the block channel to indicate an error.
					close(c)
				} else {
					c <- data
				}
			}()
		}
		// End the collection loop and the data channel.
		c := make(chan []byte)
		blocks <- c
		c <- nil // signal the collection loop that we are done
		<-c      // wait for the collection loop to complete
		if f.isLegacy() && cum == cumx {
			err = io.EOF
		}
		b.closeR(err)
		close(data)
	}()
	// Collect the uncompressed blocks and make them available
	// on the returned channel.
	go func(leg bool) {
		defer close(blocks)
		skipBlocks := false
		for c := range blocks {
			buf, ok := <-c
			if !ok {
				// A closed channel indicates an error.
				// All remaining channels should be discarded.
				skipBlocks = true
				continue
			}
			if buf == nil {
				// Signal to end the loop.
				close(c)
				return
			}
			if skipBlocks {
				// A previous error has occurred; skip the remaining channels.
				continue
			}
			// Perform the checksum now as the blocks are received in order.
			if f.Descriptor.Flags.ContentChecksum() {
				_, _ = f.checksum.Write(buf)
			}
			if leg {
				cum += uint32(len(buf))
			}
			data <- buf
			close(c)
		}
	}(f.isLegacy())
	return data, nil
}

// closeR safely sets the error on b if not already set.
func (b *Blocks) closeR(err error) {
	b.mu.Lock()
	if b.err == nil {
		b.err = err
	}
	b.mu.Unlock()
}
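
// NewFrameDataBlock returns a FrameDataBlock whose buffer is obtained from the
// frame descriptor's block size index and later released by Close.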
func NewFrameDataBlock(f *Frame) *FrameDataBlock {
	buf := f.Descriptor.Flags.BlockSizeIndex().Get()
	return &FrameDataBlock{Data: buf, data: buf}
}
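
// FrameDataBlock is a single frame data block: its size, payload and optional
// checksum, along with the internal buffers the payload points into.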
type FrameDataBlock struct {
	Size     DataBlockSize
	Data     []byte // compressed or uncompressed data (.data or .src)
	Checksum uint32
	data     []byte // buffer for compressed data
	src      []byte // uncompressed data
	err      error  // used in concurrent mode
}
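
// Close resets the block and releases its buffer via lz4block.Put. It is safe
// to call more than once.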
func (b *FrameDataBlock) Close(f *Frame) {
	b.Size = 0
	b.Checksum = 0
	b.err = nil
	if b.data != nil {
		// Block was not already closed.
		lz4block.Put(b.data)
		b.Data = nil
		b.data = nil
		b.src = nil
	}
}
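
// Compress compresses src into the block's buffer at the given level, storing
// the source unmodified when it turns out to be incompressible.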
// Block compression errors are ignored since the buffer is sized appropriately.
func (b *FrameDataBlock) Compress(f *Frame, src []byte, level lz4block.CompressionLevel) *FrameDataBlock {
	data := b.data
	if f.isLegacy() {
		// In legacy mode, the buffer is sized according to CompressBlockBound,
		// but only 8MB is buffered for compression.
		src = src[:8<<20]
	} else {
		data = data[:len(src)] // trigger the incompressible flag in CompressBlock
	}
	var n int
	switch level {
	case lz4block.Fast:
		n, _ = lz4block.CompressBlock(src, data)
	default:
		n, _ = lz4block.CompressBlockHC(src, data, level)
	}
	if n == 0 {
		b.Size.UncompressedSet(true)
		b.Data = src
	} else {
		b.Size.UncompressedSet(false)
		b.Data = data[:n]
	}
	b.Size.sizeSet(len(b.Data))
	b.src = src // keep track of the source for content checksum
	if f.Descriptor.Flags.BlockChecksum() {
		b.Checksum = xxh32.ChecksumZero(src)
	}
	return b
}
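
// Write outputs the block size, data and optional block checksum to dst,
// updating the frame content checksum from the source data first.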
func (b *FrameDataBlock) Write(f *Frame, dst io.Writer) error {
	// Write is called in the same order as blocks are compressed,
	// so content checksum must be done here.
	if f.Descriptor.Flags.ContentChecksum() {
		_, _ = f.checksum.Write(b.src)
	}
	buf := f.buf[:]
	binary.LittleEndian.PutUint32(buf, uint32(b.Size))
	if _, err := dst.Write(buf[:4]); err != nil {
		return err
	}
	if _, err := dst.Write(b.Data); err != nil {
		return err
	}
	if b.Checksum == 0 {
		return nil
	}
	binary.LittleEndian.PutUint32(buf, b.Checksum)
	_, err := dst.Write(buf[:4])
	return err
}

// Read updates b with the next block data, size and checksum if available.
func (b *FrameDataBlock) Read(f *Frame, src io.Reader, cum uint32) (uint32, error) {
	x, err := f.readUint32(src)
	if err != nil {
		return 0, err
	}
	if f.isLegacy() {
		switch x {
		case frameMagicLegacy:
			// Concatenated legacy frame.
			return b.Read(f, src, cum)
		case cum:
			// Only works in non-concurrent mode; in concurrent mode
			// it is handled separately.
			// The Linux kernel format appends the total uncompressed size at the end.
			return 0, io.EOF
		}
	} else if x == 0 {
		// Marker for the end of the stream.
		return 0, io.EOF
	}
	b.Size = DataBlockSize(x)
	size := b.Size.size()
	if size > cap(b.data) {
		return x, lz4errors.ErrOptionInvalidBlockSize
	}
	b.data = b.data[:size]
	if _, err := io.ReadFull(src, b.data); err != nil {
		return x, err
	}
	if f.Descriptor.Flags.BlockChecksum() {
		sum, err := f.readUint32(src)
		if err != nil {
			return 0, err
		}
		b.Checksum = sum
	}
	return x, nil
}
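
// Uncompress decodes the block into dst, verifying the block checksum when
// enabled and, if sum is true, updating the frame content checksum.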
func (b *FrameDataBlock) Uncompress(f *Frame, dst, dict []byte, sum bool) ([]byte, error) {
	if b.Size.Uncompressed() {
		n := copy(dst, b.data)
		dst = dst[:n]
	} else {
		n, err := lz4block.UncompressBlock(b.data, dst, dict)
		if err != nil {
			return nil, err
		}
		dst = dst[:n]
	}
	if f.Descriptor.Flags.BlockChecksum() {
		if c := xxh32.ChecksumZero(dst); c != b.Checksum {
			err := fmt.Errorf("%w: got %x; expected %x", lz4errors.ErrInvalidBlockChecksum, c, b.Checksum)
			return nil, err
		}
	}
	if sum && f.Descriptor.Flags.ContentChecksum() {
		_, _ = f.checksum.Write(dst)
	}
	return dst, nil
}
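
// readUint32 reads a little-endian uint32 from r using the frame's scratch buffer.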
func (f *Frame) readUint32(r io.Reader) (x uint32, err error) {
	if _, err = io.ReadFull(r, f.buf[:4]); err != nil {
		return
	}
	x = binary.LittleEndian.Uint32(f.buf[:4])
	return
}