compress.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. package context
  2. import (
  3. "errors"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "sync"
  8. "github.com/andybalholm/brotli"
  9. "github.com/klauspost/compress/flate"
  10. "github.com/klauspost/compress/gzip"
  11. "github.com/klauspost/compress/s2" // snappy output but likely faster decompression.
  12. "github.com/klauspost/compress/snappy"
  13. )
  14. // The available builtin compression algorithms.
  15. const (
  16. GZIP = "gzip"
  17. DEFLATE = "deflate"
  18. BROTLI = "br"
  19. SNAPPY = "snappy"
  20. S2 = "s2"
  21. )
  22. // IDENTITY no transformation whatsoever.
  23. const IDENTITY = "identity"
  24. var (
  25. // ErrResponseNotCompressed returned from AcquireCompressResponseWriter
  26. // when response's Content-Type header is missing due to golang/go/issues/31753 or
  27. // when accept-encoding is empty. The caller should fallback to the original response writer.
  28. ErrResponseNotCompressed = errors.New("compress: response will not be compressed")
  29. // ErrRequestNotCompressed returned from NewCompressReader
  30. // when request is not compressed.
  31. ErrRequestNotCompressed = errors.New("compress: request is not compressed")
  32. // ErrNotSupportedCompression returned from
  33. // AcquireCompressResponseWriter, NewCompressWriter and NewCompressReader
  34. // when the request's Accept-Encoding was not found in the server's supported
  35. // compression algorithms. Check that error with `errors.Is`.
  36. ErrNotSupportedCompression = errors.New("compress: unsupported compression")
  37. )
  38. // AllEncodings is a slice of default content encodings.
  39. // See `AcquireCompressResponseWriter`.
  40. var AllEncodings = []string{GZIP, DEFLATE, BROTLI, SNAPPY}
  41. // GetEncoding extracts the best available encoding from the request.
  42. func GetEncoding(r *http.Request, offers []string) (string, error) {
  43. acceptEncoding := r.Header[AcceptEncodingHeaderKey]
  44. if len(acceptEncoding) == 0 {
  45. return "", ErrResponseNotCompressed
  46. }
  47. encoding := negotiateAcceptHeader(acceptEncoding, offers, IDENTITY)
  48. if encoding == "" {
  49. return "", fmt.Errorf("%w: %s", ErrNotSupportedCompression, encoding)
  50. }
  51. return encoding, nil
  52. }
  53. type (
  54. noOpWriter struct{}
  55. noOpReadCloser struct {
  56. io.Reader
  57. }
  58. )
  59. var (
  60. _ io.ReadCloser = (*noOpReadCloser)(nil)
  61. _ io.Writer = (*noOpWriter)(nil)
  62. )
  63. func (w *noOpWriter) Write(p []byte) (int, error) { return 0, nil }
  64. func (r *noOpReadCloser) Close() error {
  65. return nil
  66. }
  67. // CompressWriter is an interface which all compress writers should implement.
  68. type CompressWriter interface {
  69. io.WriteCloser
  70. // All known implementations contain `Flush` and `Reset` methods,
  71. // so we wanna declare them upfront.
  72. Flush() error
  73. Reset(io.Writer)
  74. }
  75. // NewCompressWriter returns a CompressWriter of "w" based on the given "encoding".
  76. func NewCompressWriter(w io.Writer, encoding string, level int) (cw CompressWriter, err error) {
  77. switch encoding {
  78. case GZIP:
  79. cw, err = gzip.NewWriterLevel(w, level)
  80. case DEFLATE: // -1 default level, same for gzip.
  81. cw, err = flate.NewWriter(w, level)
  82. case BROTLI: // 6 default level.
  83. if level == -1 {
  84. level = 6
  85. }
  86. cw = brotli.NewWriterLevel(w, level)
  87. case SNAPPY:
  88. cw = snappy.NewWriter(w)
  89. case S2:
  90. cw = s2.NewWriter(w)
  91. default:
  92. // Throw if "identity" is given. As this is not acceptable on "Content-Encoding" header.
  93. // Only Accept-Encoding (client) can use that; it means, no transformation whatsoever.
  94. err = ErrNotSupportedCompression
  95. }
  96. return
  97. }
  98. // CompressReader is a structure which wraps a compressed reader.
  99. // It is used for determination across common request body and a compressed one.
  100. type CompressReader struct {
  101. io.ReadCloser
  102. // We need this to reset the body to its original state, if requested.
  103. Src io.ReadCloser
  104. // Encoding is the compression alogirthm is used to decompress and read the data.
  105. Encoding string
  106. }
  107. // NewCompressReader returns a new "compressReader" wrapper of "src".
  108. // It returns `ErrRequestNotCompressed` if client's request data are not compressed
  109. // or `ErrNotSupportedCompression` if server missing the decompression algorithm.
  110. // Note: on server-side the request body (src) will be closed automaticaly.
  111. func NewCompressReader(src io.Reader, encoding string) (*CompressReader, error) {
  112. if encoding == "" || src == nil {
  113. return nil, ErrRequestNotCompressed
  114. }
  115. var (
  116. rc io.ReadCloser
  117. err error
  118. )
  119. switch encoding {
  120. case GZIP:
  121. rc, err = gzip.NewReader(src)
  122. case DEFLATE:
  123. rc = flate.NewReader(src)
  124. case BROTLI:
  125. rc = &noOpReadCloser{brotli.NewReader(src)}
  126. case SNAPPY:
  127. rc = &noOpReadCloser{snappy.NewReader(src)}
  128. case S2:
  129. rc = &noOpReadCloser{s2.NewReader(src)}
  130. default:
  131. err = ErrNotSupportedCompression
  132. }
  133. if err != nil {
  134. return nil, err
  135. }
  136. srcReadCloser, ok := src.(io.ReadCloser)
  137. if !ok {
  138. srcReadCloser = &noOpReadCloser{src}
  139. }
  140. return &CompressReader{
  141. ReadCloser: rc,
  142. Src: srcReadCloser,
  143. Encoding: encoding,
  144. }, nil
  145. }
  146. var compressWritersPool = sync.Pool{New: func() interface{} { return &CompressResponseWriter{} }}
  147. // AddCompressHeaders just adds the headers "Vary" to "Accept-Encoding"
  148. // and "Content-Encoding" to the given encoding.
  149. func AddCompressHeaders(h http.Header, encoding string) {
  150. h.Set(VaryHeaderKey, AcceptEncodingHeaderKey)
  151. h.Set(ContentEncodingHeaderKey, encoding)
  152. }
  153. // CompressResponseWriter is a compressed data http.ResponseWriter.
  154. type CompressResponseWriter struct {
  155. CompressWriter
  156. ResponseWriter
  157. Disabled bool
  158. Encoding string
  159. Level int
  160. }
  161. var _ ResponseWriter = (*CompressResponseWriter)(nil)
  162. // AcquireCompressResponseWriter returns a CompressResponseWriter from the pool.
  163. // It accepts an Iris response writer, a net/http request value and
  164. // the level of compression (use -1 for default compression level).
  165. //
  166. // It returns the best candidate among "gzip", "defate", "br", "snappy" and "s2"
  167. // based on the request's "Accept-Encoding" header value.
  168. func AcquireCompressResponseWriter(w ResponseWriter, r *http.Request, level int) (*CompressResponseWriter, error) {
  169. encoding, err := GetEncoding(r, AllEncodings)
  170. if err != nil {
  171. return nil, err
  172. }
  173. v := compressWritersPool.Get().(*CompressResponseWriter)
  174. v.ResponseWriter = w
  175. v.Disabled = false
  176. if level == -1 && encoding == BROTLI {
  177. level = 6
  178. }
  179. /*
  180. // Writer exists, encoding matching and it's valid because it has a non nil encWriter;
  181. // just reset to reduce allocations.
  182. if v.Encoding == encoding && v.Level == level && v.CompressWriter != nil {
  183. v.CompressWriter.Reset(w)
  184. return v, nil
  185. }
  186. */
  187. v.Encoding = encoding
  188. v.Level = level
  189. encWriter, err := NewCompressWriter(w, encoding, level)
  190. if err != nil {
  191. return nil, err
  192. }
  193. v.CompressWriter = encWriter
  194. AddCompressHeaders(w.Header(), encoding)
  195. return v, nil
  196. }
  197. func releaseCompressResponseWriter(w *CompressResponseWriter) {
  198. compressWritersPool.Put(w)
  199. }
  200. // FlushResponse flushes any data, closes the underline compress writer
  201. // and writes the status code.
  202. // Called automatically before `EndResponse`.
  203. func (w *CompressResponseWriter) FlushResponse() {
  204. w.FlushHeaders()
  205. /* this should NEVER happen, see `context.CompressWriter` method.
  206. if rec, ok := w.ResponseWriter.(*ResponseRecorder); ok {
  207. // Usecase: record, then compression.
  208. w.CompressWriter.Close() // flushes and closes.
  209. rec.FlushResponse()
  210. return
  211. }
  212. */
  213. // write the status, after header set and before any flushed content sent.
  214. w.ResponseWriter.FlushResponse()
  215. w.CompressWriter.Close() // flushes and closes.
  216. }
  217. // FlushHeaders deletes the encoding headers if
  218. // the compressed writer was disabled otherwise
  219. // removes the content-length so next callers can re-calculate the correct length.
  220. func (w *CompressResponseWriter) FlushHeaders() {
  221. if w.Disabled {
  222. w.Header().Del(VaryHeaderKey)
  223. w.Header().Del(ContentEncodingHeaderKey)
  224. w.CompressWriter.Reset(&noOpWriter{})
  225. } else {
  226. w.ResponseWriter.Header().Del(ContentLengthHeaderKey)
  227. }
  228. }
  229. // EndResponse reeases the writers.
  230. func (w *CompressResponseWriter) EndResponse() {
  231. w.ResponseWriter.EndResponse()
  232. releaseCompressResponseWriter(w)
  233. }
  234. func (w *CompressResponseWriter) Write(p []byte) (int, error) {
  235. if w.Disabled {
  236. // If disabled or the content-type is empty the response will not be compressed (golang/go/issues/31753).
  237. return w.ResponseWriter.Write(p)
  238. }
  239. if w.Header().Get(ContentTypeHeaderKey) == "" {
  240. w.Header().Set(ContentTypeHeaderKey, http.DetectContentType(p))
  241. }
  242. return w.CompressWriter.Write(p)
  243. }
  244. // Flush sends any buffered data to the client.
  245. // Can be called manually.
  246. func (w *CompressResponseWriter) Flush() {
  247. // if w.Disabled {
  248. // w.Header().Del(VaryHeaderKey)
  249. // w.Header().Del(ContentEncodingHeaderKey)
  250. // } else {
  251. // w.encWriter.Flush()
  252. // }
  253. if !w.Disabled {
  254. w.CompressWriter.Flush()
  255. }
  256. w.ResponseWriter.Flush()
  257. }
  258. // WriteTo writes the "p" to "dest" Writer using the compression that this compress writer was made of.
  259. func (w *CompressResponseWriter) WriteTo(dest io.Writer, p []byte) (int, error) {
  260. if w.Disabled {
  261. return dest.Write(p)
  262. }
  263. cw, err := NewCompressWriter(dest, w.Encoding, w.Level)
  264. if err != nil {
  265. return 0, err
  266. }
  267. n, err := cw.Write(p)
  268. cw.Close()
  269. return n, err
  270. }