zstd.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. // Package zstd implements Zstandard compression.
  2. package zstd
  3. import (
  4. "io"
  5. "sync"
  6. "github.com/klauspost/compress/zstd"
  7. )
  8. // Codec is the implementation of a compress.Codec which supports creating
  9. // readers and writers for kafka messages compressed with zstd.
  10. type Codec struct {
  11. // The compression level configured on writers created by the codec.
  12. //
  13. // Default to 3.
  14. Level int
  15. encoderPool sync.Pool // *encoder
  16. }
  17. // Code implements the compress.Codec interface.
  18. func (c *Codec) Code() int8 { return 4 }
  19. // Name implements the compress.Codec interface.
  20. func (c *Codec) Name() string { return "zstd" }
  21. // NewReader implements the compress.Codec interface.
  22. func (c *Codec) NewReader(r io.Reader) io.ReadCloser {
  23. p := new(reader)
  24. if p.dec, _ = decoderPool.Get().(*zstd.Decoder); p.dec != nil {
  25. p.dec.Reset(r)
  26. } else {
  27. z, err := zstd.NewReader(r,
  28. zstd.WithDecoderConcurrency(1),
  29. )
  30. if err != nil {
  31. p.err = err
  32. } else {
  33. p.dec = z
  34. }
  35. }
  36. return p
  37. }
  38. func (c *Codec) level() int {
  39. if c.Level != 0 {
  40. return c.Level
  41. }
  42. return 3
  43. }
  44. func (c *Codec) zstdLevel() zstd.EncoderLevel {
  45. return zstd.EncoderLevelFromZstd(c.level())
  46. }
  47. var decoderPool sync.Pool // *zstd.Decoder
  48. type reader struct {
  49. dec *zstd.Decoder
  50. err error
  51. }
  52. // Close implements the io.Closer interface.
  53. func (r *reader) Close() error {
  54. if r.dec != nil {
  55. r.dec.Reset(devNull{}) // don't retain the underlying reader
  56. decoderPool.Put(r.dec)
  57. r.dec = nil
  58. r.err = io.ErrClosedPipe
  59. }
  60. return nil
  61. }
  62. // Read implements the io.Reader interface.
  63. func (r *reader) Read(p []byte) (int, error) {
  64. if r.err != nil {
  65. return 0, r.err
  66. }
  67. if r.dec == nil {
  68. return 0, io.EOF
  69. }
  70. return r.dec.Read(p)
  71. }
  72. // WriteTo implements the io.WriterTo interface.
  73. func (r *reader) WriteTo(w io.Writer) (int64, error) {
  74. if r.err != nil {
  75. return 0, r.err
  76. }
  77. if r.dec == nil {
  78. return 0, io.ErrClosedPipe
  79. }
  80. return r.dec.WriteTo(w)
  81. }
  82. // NewWriter implements the compress.Codec interface.
  83. func (c *Codec) NewWriter(w io.Writer) io.WriteCloser {
  84. p := new(writer)
  85. if enc, _ := c.encoderPool.Get().(*zstd.Encoder); enc == nil {
  86. z, err := zstd.NewWriter(w,
  87. zstd.WithEncoderLevel(c.zstdLevel()),
  88. zstd.WithEncoderConcurrency(1),
  89. zstd.WithZeroFrames(true),
  90. )
  91. if err != nil {
  92. p.err = err
  93. } else {
  94. p.enc = z
  95. }
  96. } else {
  97. p.enc = enc
  98. p.enc.Reset(w)
  99. }
  100. p.c = c
  101. return p
  102. }
  103. type writer struct {
  104. c *Codec
  105. enc *zstd.Encoder
  106. err error
  107. }
  108. // Close implements the io.Closer interface.
  109. func (w *writer) Close() error {
  110. if w.enc != nil {
  111. // Close needs to be called to write the end of stream marker and flush
  112. // the buffers. The zstd package documents that the encoder is re-usable
  113. // after being closed.
  114. err := w.enc.Close()
  115. if err != nil {
  116. w.err = err
  117. }
  118. w.enc.Reset(devNull{}) // don't retain the underlying writer
  119. w.c.encoderPool.Put(w.enc)
  120. w.enc = nil
  121. return err
  122. }
  123. return w.err
  124. }
  125. // WriteTo implements the io.WriterTo interface.
  126. func (w *writer) Write(p []byte) (int, error) {
  127. if w.err != nil {
  128. return 0, w.err
  129. }
  130. if w.enc == nil {
  131. return 0, io.ErrClosedPipe
  132. }
  133. return w.enc.Write(p)
  134. }
  135. // ReadFrom implements the io.ReaderFrom interface.
  136. func (w *writer) ReadFrom(r io.Reader) (int64, error) {
  137. if w.err != nil {
  138. return 0, w.err
  139. }
  140. if w.enc == nil {
  141. return 0, io.ErrClosedPipe
  142. }
  143. return w.enc.ReadFrom(r)
  144. }
  145. type devNull struct{}
  146. func (devNull) Read([]byte) (int, error) { return 0, io.EOF }
  147. func (devNull) Write([]byte) (int, error) { return 0, nil }