compression.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
  5. import (
  6. "compress/flate"
  7. "errors"
  8. "io"
  9. "log"
  10. "strings"
  11. "sync"
  12. )
  // Compression levels accepted by compressNoContextTakeover.
const (
	// minCompressionLevel is -2, i.e. flate.HuffmanOnly, written as a
	// literal because flate.HuffmanOnly is not defined in Go < 1.6.
	minCompressionLevel = -2

	// maxCompressionLevel is flate.BestCompression (9).
	maxCompressionLevel = flate.BestCompression

	// defaultCompressionLevel is 1 (flate.BestSpeed); presumably the level
	// used when none is configured — confirm at its use sites.
	defaultCompressionLevel = 1
)

var (
	// flateWriterPools keeps one pool of *flate.Writer per compression
	// level, indexed by level-minCompressionLevel.
	flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool

	// flateReaderPool recycles flate decompressors. New builds one with no
	// input; callers point it at a real stream through flate.Resetter.
	flateReaderPool = sync.Pool{
		New: func() interface{} { return flate.NewReader(nil) },
	}
)
  24. func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
  25. const tail =
  26. // Add four bytes as specified in RFC
  27. "\x00\x00\xff\xff" +
  28. // Add final block to squelch unexpected EOF error from flate reader.
  29. "\x01\x00\x00\xff\xff"
  30. fr, _ := flateReaderPool.Get().(io.ReadCloser)
  31. if err := fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil); err != nil {
  32. panic(err)
  33. }
  34. return &flateReadWrapper{fr}
  35. }
  36. func isValidCompressionLevel(level int) bool {
  37. return minCompressionLevel <= level && level <= maxCompressionLevel
  38. }
  39. func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
  40. p := &flateWriterPools[level-minCompressionLevel]
  41. tw := &truncWriter{w: w}
  42. fw, _ := p.Get().(*flate.Writer)
  43. if fw == nil {
  44. fw, _ = flate.NewWriter(tw, level)
  45. } else {
  46. fw.Reset(tw)
  47. }
  48. return &flateWriteWrapper{fw: fw, tw: tw, p: p}
  49. }
  // truncWriter is an io.Writer that writes all but the last four bytes of the
// stream to another io.Writer. The withheld bytes stay in p so the owner can
// inspect them after the stream ends.
type truncWriter struct {
	w io.WriteCloser // destination for everything except the trailing bytes
	n int            // number of valid bytes in p while it is still filling
	p [4]byte        // most recent bytes of the stream, withheld from w
}

// Write forwards p to w.w while always holding back the final len(w.p) bytes
// of the stream seen so far.
//
// On success it reports len(p) bytes consumed, as io.Writer requires, even
// though up to four of them are only buffered. On error, the count covers the
// bytes of p that were buffered or forwarded before the failure.
func (w *truncWriter) Write(p []byte) (int, error) {
	n := 0

	// Fill the holdback buffer first for simplicity.
	if w.n < len(w.p) {
		n = copy(w.p[w.n:], p)
		p = p[n:]
		w.n += n
	}
	if len(p) == 0 {
		// Nothing left to forward; skip zero-length writes to w.w.
		return n, nil
	}

	// The last m bytes of p displace the m oldest withheld bytes, which can
	// now be released to the destination.
	m := len(p)
	if m > len(w.p) {
		m = len(w.p)
	}

	if _, err := w.w.Write(w.p[:m]); err != nil {
		// None of the remaining bytes of p were consumed.
		return n, err
	}

	copy(w.p[:], w.p[m:])
	copy(w.p[len(w.p)-m:], p[len(p)-m:])
	nn, err := w.w.Write(p[:len(p)-m])
	if err != nil {
		return n + nn, err
	}
	return n + len(p), nil
}
  80. type flateWriteWrapper struct {
  81. fw *flate.Writer
  82. tw *truncWriter
  83. p *sync.Pool
  84. }
  85. func (w *flateWriteWrapper) Write(p []byte) (int, error) {
  86. if w.fw == nil {
  87. return 0, errWriteClosed
  88. }
  89. return w.fw.Write(p)
  90. }
  91. func (w *flateWriteWrapper) Close() error {
  92. if w.fw == nil {
  93. return errWriteClosed
  94. }
  95. err1 := w.fw.Flush()
  96. w.p.Put(w.fw)
  97. w.fw = nil
  98. if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
  99. return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
  100. }
  101. err2 := w.tw.w.Close()
  102. if err1 != nil {
  103. return err1
  104. }
  105. return err2
  106. }
  107. type flateReadWrapper struct {
  108. fr io.ReadCloser
  109. }
  110. func (r *flateReadWrapper) Read(p []byte) (int, error) {
  111. if r.fr == nil {
  112. return 0, io.ErrClosedPipe
  113. }
  114. n, err := r.fr.Read(p)
  115. if err == io.EOF {
  116. // Preemptively place the reader back in the pool. This helps with
  117. // scenarios where the application does not call NextReader() soon after
  118. // this final read.
  119. if err := r.Close(); err != nil {
  120. log.Printf("websocket: flateReadWrapper.Close() returned error: %v", err)
  121. }
  122. }
  123. return n, err
  124. }
  125. func (r *flateReadWrapper) Close() error {
  126. if r.fr == nil {
  127. return io.ErrClosedPipe
  128. }
  129. err := r.fr.Close()
  130. flateReaderPool.Put(r.fr)
  131. r.fr = nil
  132. return err
  133. }