compress.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. package context
  2. import (
  3. "errors"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "sync"
  8. "github.com/andybalholm/brotli"
  9. "github.com/golang/snappy"
  10. "github.com/klauspost/compress/flate"
  11. "github.com/klauspost/compress/gzip"
  12. "github.com/klauspost/compress/s2" // snappy output but likely faster decompression.
  13. )
  14. // The available builtin compression algorithms.
  15. const (
  16. GZIP = "gzip"
  17. DEFLATE = "deflate"
  18. BROTLI = "br"
  19. SNAPPY = "snappy"
  20. S2 = "s2"
  21. )
  22. // IDENTITY no transformation whatsoever.
  23. const IDENTITY = "identity"
  24. var (
  25. // ErrResponseNotCompressed returned from AcquireCompressResponseWriter
  26. // when response's Content-Type header is missing due to golang/go/issues/31753 or
  27. // when accept-encoding is empty. The caller should fallback to the original response writer.
  28. ErrResponseNotCompressed = errors.New("compress: response will not be compressed")
  29. // ErrRequestNotCompressed returned from NewCompressReader
  30. // when request is not compressed.
  31. ErrRequestNotCompressed = errors.New("compress: request is not compressed")
  32. // ErrNotSupportedCompression returned from
  33. // AcquireCompressResponseWriter, NewCompressWriter and NewCompressReader
  34. // when the request's Accept-Encoding was not found in the server's supported
  35. // compression algorithms. Check that error with `errors.Is`.
  36. ErrNotSupportedCompression = errors.New("compress: unsupported compression")
  37. )
  38. // AllEncodings is a slice of default content encodings.
  39. // See `AcquireCompressResponseWriter`.
  40. var AllEncodings = []string{GZIP, DEFLATE, BROTLI, SNAPPY}
  41. // GetEncoding extracts the best available encoding from the request.
  42. func GetEncoding(r *http.Request, offers []string) (string, error) {
  43. acceptEncoding := r.Header[AcceptEncodingHeaderKey]
  44. if len(acceptEncoding) == 0 {
  45. return "", ErrResponseNotCompressed
  46. }
  47. encoding := negotiateAcceptHeader(acceptEncoding, offers, IDENTITY)
  48. if encoding == "" {
  49. return "", fmt.Errorf("%w: %s", ErrNotSupportedCompression, encoding)
  50. }
  51. return encoding, nil
  52. }
  53. type (
  54. noOpWriter struct{}
  55. noOpReadCloser struct {
  56. io.Reader
  57. }
  58. )
  59. var (
  60. _ io.ReadCloser = (*noOpReadCloser)(nil)
  61. _ io.Writer = (*noOpWriter)(nil)
  62. )
  63. func (w *noOpWriter) Write(p []byte) (int, error) { return 0, nil }
  64. func (r *noOpReadCloser) Close() error {
  65. return nil
  66. }
  67. // CompressWriter is an interface which all compress writers should implement.
  68. type CompressWriter interface {
  69. io.WriteCloser
  70. // All known implementations contain `Flush` and `Reset` methods,
  71. // so we wanna declare them upfront.
  72. Flush() error
  73. Reset(io.Writer)
  74. }
  75. // NewCompressWriter returns a CompressWriter of "w" based on the given "encoding".
  76. func NewCompressWriter(w io.Writer, encoding string, level int) (cw CompressWriter, err error) {
  77. switch encoding {
  78. case GZIP:
  79. cw, err = gzip.NewWriterLevel(w, level)
  80. case DEFLATE: // -1 default level, same for gzip.
  81. cw, err = flate.NewWriter(w, level)
  82. case BROTLI: // 6 default level.
  83. if level == -1 {
  84. level = 6
  85. }
  86. cw = brotli.NewWriterLevel(w, level)
  87. case SNAPPY:
  88. cw = snappy.NewBufferedWriter(w)
  89. case S2:
  90. cw = s2.NewWriter(w)
  91. default:
  92. // Throw if "identity" is given. As this is not acceptable on "Content-Encoding" header.
  93. // Only Accept-Encoding (client) can use that; it means, no transformation whatsoever.
  94. err = ErrNotSupportedCompression
  95. }
  96. return
  97. }
  98. // CompressReader is a structure which wraps a compressed reader.
  99. // It is used for determination across common request body and a compressed one.
  100. type CompressReader struct {
  101. io.ReadCloser
  102. // We need this to reset the body to its original state, if requested.
  103. Src io.ReadCloser
  104. // Encoding is the compression alogirthm is used to decompress and read the data.
  105. Encoding string
  106. }
  107. // NewCompressReader returns a new "compressReader" wrapper of "src".
  108. // It returns `ErrRequestNotCompressed` if client's request data are not compressed
  109. // or `ErrNotSupportedCompression` if server missing the decompression algorithm.
  110. // Note: on server-side the request body (src) will be closed automaticaly.
  111. func NewCompressReader(src io.Reader, encoding string) (*CompressReader, error) {
  112. if encoding == "" || src == nil {
  113. return nil, ErrRequestNotCompressed
  114. }
  115. var (
  116. rc io.ReadCloser
  117. err error
  118. )
  119. switch encoding {
  120. case GZIP:
  121. rc, err = gzip.NewReader(src)
  122. case DEFLATE:
  123. rc = flate.NewReader(src)
  124. case BROTLI:
  125. rc = &noOpReadCloser{brotli.NewReader(src)}
  126. case SNAPPY:
  127. rc = &noOpReadCloser{snappy.NewReader(src)}
  128. case S2:
  129. rc = &noOpReadCloser{s2.NewReader(src)}
  130. default:
  131. err = ErrNotSupportedCompression
  132. }
  133. if err != nil {
  134. return nil, err
  135. }
  136. srcReadCloser, ok := src.(io.ReadCloser)
  137. if !ok {
  138. srcReadCloser = &noOpReadCloser{src}
  139. }
  140. return &CompressReader{
  141. ReadCloser: rc,
  142. Src: srcReadCloser,
  143. Encoding: encoding,
  144. }, nil
  145. }
  146. var compressWritersPool = sync.Pool{New: func() interface{} { return &CompressResponseWriter{} }}
  147. // AddCompressHeaders just adds the headers "Vary" to "Accept-Encoding"
  148. // and "Content-Encoding" to the given encoding.
  149. func AddCompressHeaders(h http.Header, encoding string) {
  150. h.Set(VaryHeaderKey, AcceptEncodingHeaderKey)
  151. h.Set(ContentEncodingHeaderKey, encoding)
  152. }
  153. // CompressResponseWriter is a compressed data http.ResponseWriter.
  154. type CompressResponseWriter struct {
  155. CompressWriter
  156. ResponseWriter
  157. http.Hijacker
  158. Disabled bool
  159. Encoding string
  160. Level int
  161. }
  162. var _ ResponseWriter = (*CompressResponseWriter)(nil)
  163. // AcquireCompressResponseWriter returns a CompressResponseWriter from the pool.
  164. // It accepts an Iris response writer, a net/http request value and
  165. // the level of compression (use -1 for default compression level).
  166. //
  167. // It returns the best candidate among "gzip", "defate", "br", "snappy" and "s2"
  168. // based on the request's "Accept-Encoding" header value.
  169. func AcquireCompressResponseWriter(w ResponseWriter, r *http.Request, level int) (*CompressResponseWriter, error) {
  170. encoding, err := GetEncoding(r, AllEncodings)
  171. if err != nil {
  172. return nil, err
  173. }
  174. v := compressWritersPool.Get().(*CompressResponseWriter)
  175. if h, ok := w.(http.Hijacker); ok {
  176. v.Hijacker = h
  177. } else {
  178. v.Hijacker = nil
  179. }
  180. // The Naive() should be used to check for Pusher,
  181. // as examples explicitly says, so don't do it:
  182. // if p, ok := w.Naive().(http.Pusher); ok {
  183. // v.Pusher = p
  184. // } else {
  185. // v.Pusher = nil
  186. // }
  187. v.ResponseWriter = w
  188. v.Disabled = false
  189. if level == -1 && encoding == BROTLI {
  190. level = 6
  191. }
  192. /*
  193. // Writer exists, encoding matching and it's valid because it has a non nil encWriter;
  194. // just reset to reduce allocations.
  195. if v.Encoding == encoding && v.Level == level && v.CompressWriter != nil {
  196. v.CompressWriter.Reset(w)
  197. return v, nil
  198. }
  199. */
  200. v.Encoding = encoding
  201. v.Level = level
  202. encWriter, err := NewCompressWriter(w, encoding, level)
  203. if err != nil {
  204. return nil, err
  205. }
  206. v.CompressWriter = encWriter
  207. AddCompressHeaders(w.Header(), encoding)
  208. return v, nil
  209. }
  210. func releaseCompressResponseWriter(w *CompressResponseWriter) {
  211. compressWritersPool.Put(w)
  212. }
  213. // FlushResponse flushes any data, closes the underline compress writer
  214. // and writes the status code.
  215. // Called automatically before `EndResponse`.
  216. func (w *CompressResponseWriter) FlushResponse() {
  217. w.FlushHeaders()
  218. /* this should NEVER happen, see `context.CompressWriter` method.
  219. if rec, ok := w.ResponseWriter.(*ResponseRecorder); ok {
  220. // Usecase: record, then compression.
  221. w.CompressWriter.Close() // flushes and closes.
  222. rec.FlushResponse()
  223. return
  224. }
  225. */
  226. // write the status, after header set and before any flushed content sent.
  227. w.ResponseWriter.FlushResponse()
  228. if w.IsHijacked() {
  229. // net/http docs:
  230. // It becomes the caller's responsibility to manage
  231. // and close the connection.
  232. return
  233. }
  234. w.CompressWriter.Close() // flushes and closes.
  235. }
  236. // FlushHeaders deletes the encoding headers if
  237. // the compressed writer was disabled otherwise
  238. // removes the content-length so next callers can re-calculate the correct length.
  239. func (w *CompressResponseWriter) FlushHeaders() {
  240. if w.Disabled {
  241. w.Header().Del(VaryHeaderKey)
  242. w.Header().Del(ContentEncodingHeaderKey)
  243. w.CompressWriter.Reset(&noOpWriter{})
  244. } else {
  245. w.ResponseWriter.Header().Del(ContentLengthHeaderKey)
  246. }
  247. }
  248. // EndResponse reeases the writers.
  249. func (w *CompressResponseWriter) EndResponse() {
  250. w.ResponseWriter.EndResponse()
  251. releaseCompressResponseWriter(w)
  252. }
  253. func (w *CompressResponseWriter) Write(p []byte) (int, error) {
  254. if w.Disabled {
  255. // If disabled or the content-type is empty the response will not be compressed (golang/go/issues/31753).
  256. return w.ResponseWriter.Write(p)
  257. }
  258. if w.Header().Get(ContentTypeHeaderKey) == "" {
  259. w.Header().Set(ContentTypeHeaderKey, http.DetectContentType(p))
  260. }
  261. return w.CompressWriter.Write(p)
  262. }
  263. // Flush sends any buffered data to the client.
  264. // Can be called manually.
  265. func (w *CompressResponseWriter) Flush() {
  266. // if w.Disabled {
  267. // w.Header().Del(VaryHeaderKey)
  268. // w.Header().Del(ContentEncodingHeaderKey)
  269. // } else {
  270. // w.encWriter.Flush()
  271. // }
  272. if !w.Disabled {
  273. w.CompressWriter.Flush()
  274. }
  275. w.ResponseWriter.Flush()
  276. }
  277. // WriteTo writes the "p" to "dest" Writer using the compression that this compress writer was made of.
  278. func (w *CompressResponseWriter) WriteTo(dest io.Writer, p []byte) (int, error) {
  279. if w.Disabled {
  280. return dest.Write(p)
  281. }
  282. cw, err := NewCompressWriter(dest, w.Encoding, w.Level)
  283. if err != nil {
  284. return 0, err
  285. }
  286. n, err := cw.Write(p)
  287. cw.Close()
  288. return n, err
  289. }
  290. // Reset implements the ResponseWriterReseter interface.
  291. func (w *CompressResponseWriter) Reset() bool {
  292. if w.Disabled {
  293. // If it's disabled then the underline one is responsible.
  294. rs, ok := w.ResponseWriter.(ResponseWriterReseter)
  295. return ok && rs.Reset()
  296. }
  297. w.CompressWriter.Reset(w.ResponseWriter)
  298. return true
  299. }