decode.go 10 KB


  1. package protocol
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "fmt"
  6. "hash/crc32"
  7. "io"
  8. "io/ioutil"
  9. "math"
  10. "reflect"
  11. "sync"
  12. "sync/atomic"
  13. )
  // discarder is implemented by readers that can skip input without copying
  // it out. decoder.discard uses it when the underlying reader provides it.
  type discarder interface {
  	Discard(n int) (discarded int, err error)
  }
  // decoder reads values from an underlying reader while enforcing a byte
  // budget. The first error encountered is kept in err, and Read refuses to
  // return data once it is set.
  type decoder struct {
  	// reader is the underlying source of bytes.
  	reader io.Reader
  	// remain is the number of bytes that may still be consumed from reader.
  	remain int
  	// buffer is scratch space used when reading fixed-size values.
  	buffer [8]byte
  	// err is the first error recorded through setError.
  	err error
  	// table, when non-nil, enables checksum tracking of the bytes returned by Read.
  	table *crc32.Table
  	// crc32 is the running checksum computed with table.
  	crc32 uint32
  }
  25. func (d *decoder) Reset(r io.Reader, n int) {
  26. d.reader = r
  27. d.remain = n
  28. d.buffer = [8]byte{}
  29. d.err = nil
  30. d.table = nil
  31. d.crc32 = 0
  32. }
  33. func (d *decoder) Read(b []byte) (int, error) {
  34. if d.err != nil {
  35. return 0, d.err
  36. }
  37. if d.remain == 0 {
  38. return 0, io.EOF
  39. }
  40. if len(b) > d.remain {
  41. b = b[:d.remain]
  42. }
  43. n, err := d.reader.Read(b)
  44. if n > 0 && d.table != nil {
  45. d.crc32 = crc32.Update(d.crc32, d.table, b[:n])
  46. }
  47. d.remain -= n
  48. return n, err
  49. }
  50. func (d *decoder) ReadByte() (byte, error) {
  51. c := d.readByte()
  52. return c, d.err
  53. }
  54. func (d *decoder) done() bool {
  55. return d.remain == 0 || d.err != nil
  56. }
  57. func (d *decoder) setCRC(table *crc32.Table) {
  58. d.table, d.crc32 = table, 0
  59. }
  60. func (d *decoder) decodeBool(v value) {
  61. v.setBool(d.readBool())
  62. }
  63. func (d *decoder) decodeInt8(v value) {
  64. v.setInt8(d.readInt8())
  65. }
  66. func (d *decoder) decodeInt16(v value) {
  67. v.setInt16(d.readInt16())
  68. }
  69. func (d *decoder) decodeInt32(v value) {
  70. v.setInt32(d.readInt32())
  71. }
  72. func (d *decoder) decodeInt64(v value) {
  73. v.setInt64(d.readInt64())
  74. }
  75. func (d *decoder) decodeFloat64(v value) {
  76. v.setFloat64(d.readFloat64())
  77. }
  78. func (d *decoder) decodeString(v value) {
  79. v.setString(d.readString())
  80. }
  81. func (d *decoder) decodeCompactString(v value) {
  82. v.setString(d.readCompactString())
  83. }
  84. func (d *decoder) decodeBytes(v value) {
  85. v.setBytes(d.readBytes())
  86. }
  87. func (d *decoder) decodeCompactBytes(v value) {
  88. v.setBytes(d.readCompactBytes())
  89. }
  90. func (d *decoder) decodeArray(v value, elemType reflect.Type, decodeElem decodeFunc) {
  91. if n := d.readInt32(); n < 0 {
  92. v.setArray(array{})
  93. } else {
  94. a := makeArray(elemType, int(n))
  95. for i := 0; i < int(n) && d.remain > 0; i++ {
  96. decodeElem(d, a.index(i))
  97. }
  98. v.setArray(a)
  99. }
  100. }
  101. func (d *decoder) decodeCompactArray(v value, elemType reflect.Type, decodeElem decodeFunc) {
  102. if n := d.readUnsignedVarInt(); n < 1 {
  103. v.setArray(array{})
  104. } else {
  105. a := makeArray(elemType, int(n-1))
  106. for i := 0; i < int(n-1) && d.remain > 0; i++ {
  107. decodeElem(d, a.index(i))
  108. }
  109. v.setArray(a)
  110. }
  111. }
  112. func (d *decoder) discardAll() {
  113. d.discard(d.remain)
  114. }
  115. func (d *decoder) discard(n int) {
  116. if n > d.remain {
  117. n = d.remain
  118. }
  119. var err error
  120. if r, _ := d.reader.(discarder); r != nil {
  121. n, err = r.Discard(n)
  122. d.remain -= n
  123. } else {
  124. _, err = io.Copy(ioutil.Discard, d)
  125. }
  126. d.setError(err)
  127. }
  128. func (d *decoder) read(n int) []byte {
  129. b := make([]byte, n)
  130. n, err := io.ReadFull(d, b)
  131. b = b[:n]
  132. d.setError(err)
  133. return b
  134. }
  135. func (d *decoder) writeTo(w io.Writer, n int) {
  136. limit := d.remain
  137. if n < limit {
  138. d.remain = n
  139. }
  140. c, err := io.Copy(w, d)
  141. if int(c) < n && err == nil {
  142. err = io.ErrUnexpectedEOF
  143. }
  144. d.remain = limit - int(c)
  145. d.setError(err)
  146. }
  147. func (d *decoder) setError(err error) {
  148. if d.err == nil && err != nil {
  149. d.err = err
  150. d.discardAll()
  151. }
  152. }
  153. func (d *decoder) readFull(b []byte) bool {
  154. n, err := io.ReadFull(d, b)
  155. d.setError(err)
  156. return n == len(b)
  157. }
  158. func (d *decoder) readByte() byte {
  159. if d.readFull(d.buffer[:1]) {
  160. return d.buffer[0]
  161. }
  162. return 0
  163. }
  164. func (d *decoder) readBool() bool {
  165. return d.readByte() != 0
  166. }
  167. func (d *decoder) readInt8() int8 {
  168. if d.readFull(d.buffer[:1]) {
  169. return readInt8(d.buffer[:1])
  170. }
  171. return 0
  172. }
  173. func (d *decoder) readInt16() int16 {
  174. if d.readFull(d.buffer[:2]) {
  175. return readInt16(d.buffer[:2])
  176. }
  177. return 0
  178. }
  179. func (d *decoder) readInt32() int32 {
  180. if d.readFull(d.buffer[:4]) {
  181. return readInt32(d.buffer[:4])
  182. }
  183. return 0
  184. }
  185. func (d *decoder) readInt64() int64 {
  186. if d.readFull(d.buffer[:8]) {
  187. return readInt64(d.buffer[:8])
  188. }
  189. return 0
  190. }
  191. func (d *decoder) readFloat64() float64 {
  192. if d.readFull(d.buffer[:8]) {
  193. return readFloat64(d.buffer[:8])
  194. }
  195. return 0
  196. }
  197. func (d *decoder) readString() string {
  198. if n := d.readInt16(); n < 0 {
  199. return ""
  200. } else {
  201. return bytesToString(d.read(int(n)))
  202. }
  203. }
  204. func (d *decoder) readVarString() string {
  205. if n := d.readVarInt(); n < 0 {
  206. return ""
  207. } else {
  208. return bytesToString(d.read(int(n)))
  209. }
  210. }
  211. func (d *decoder) readCompactString() string {
  212. if n := d.readUnsignedVarInt(); n < 1 {
  213. return ""
  214. } else {
  215. return bytesToString(d.read(int(n - 1)))
  216. }
  217. }
  218. func (d *decoder) readBytes() []byte {
  219. if n := d.readInt32(); n < 0 {
  220. return nil
  221. } else {
  222. return d.read(int(n))
  223. }
  224. }
  225. func (d *decoder) readVarBytes() []byte {
  226. if n := d.readVarInt(); n < 0 {
  227. return nil
  228. } else {
  229. return d.read(int(n))
  230. }
  231. }
  232. func (d *decoder) readCompactBytes() []byte {
  233. if n := d.readUnsignedVarInt(); n < 1 {
  234. return nil
  235. } else {
  236. return d.read(int(n - 1))
  237. }
  238. }
  239. func (d *decoder) readVarInt() int64 {
  240. n := 11 // varints are at most 11 bytes
  241. if n > d.remain {
  242. n = d.remain
  243. }
  244. x := uint64(0)
  245. s := uint(0)
  246. for n > 0 {
  247. b := d.readByte()
  248. if (b & 0x80) == 0 {
  249. x |= uint64(b) << s
  250. return int64(x>>1) ^ -(int64(x) & 1)
  251. }
  252. x |= uint64(b&0x7f) << s
  253. s += 7
  254. n--
  255. }
  256. d.setError(fmt.Errorf("cannot decode varint from input stream"))
  257. return 0
  258. }
  259. func (d *decoder) readUnsignedVarInt() uint64 {
  260. n := 11 // varints are at most 11 bytes
  261. if n > d.remain {
  262. n = d.remain
  263. }
  264. x := uint64(0)
  265. s := uint(0)
  266. for n > 0 {
  267. b := d.readByte()
  268. if (b & 0x80) == 0 {
  269. x |= uint64(b) << s
  270. return x
  271. }
  272. x |= uint64(b&0x7f) << s
  273. s += 7
  274. n--
  275. }
  276. d.setError(fmt.Errorf("cannot decode unsigned varint from input stream"))
  277. return 0
  278. }
  279. type decodeFunc func(*decoder, value)
  280. var (
  281. _ io.Reader = (*decoder)(nil)
  282. _ io.ByteReader = (*decoder)(nil)
  283. readerFrom = reflect.TypeOf((*io.ReaderFrom)(nil)).Elem()
  284. )
  285. func decodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) decodeFunc {
  286. if reflect.PtrTo(typ).Implements(readerFrom) {
  287. return readerDecodeFuncOf(typ)
  288. }
  289. switch typ.Kind() {
  290. case reflect.Bool:
  291. return (*decoder).decodeBool
  292. case reflect.Int8:
  293. return (*decoder).decodeInt8
  294. case reflect.Int16:
  295. return (*decoder).decodeInt16
  296. case reflect.Int32:
  297. return (*decoder).decodeInt32
  298. case reflect.Int64:
  299. return (*decoder).decodeInt64
  300. case reflect.Float64:
  301. return (*decoder).decodeFloat64
  302. case reflect.String:
  303. return stringDecodeFuncOf(flexible, tag)
  304. case reflect.Struct:
  305. return structDecodeFuncOf(typ, version, flexible)
  306. case reflect.Slice:
  307. if typ.Elem().Kind() == reflect.Uint8 { // []byte
  308. return bytesDecodeFuncOf(flexible, tag)
  309. }
  310. return arrayDecodeFuncOf(typ, version, flexible, tag)
  311. default:
  312. panic("unsupported type: " + typ.String())
  313. }
  314. }
  315. func stringDecodeFuncOf(flexible bool, tag structTag) decodeFunc {
  316. if flexible {
  317. // In flexible messages, all strings are compact
  318. return (*decoder).decodeCompactString
  319. }
  320. return (*decoder).decodeString
  321. }
  322. func bytesDecodeFuncOf(flexible bool, tag structTag) decodeFunc {
  323. if flexible {
  324. // In flexible messages, all arrays are compact
  325. return (*decoder).decodeCompactBytes
  326. }
  327. return (*decoder).decodeBytes
  328. }
  329. func structDecodeFuncOf(typ reflect.Type, version int16, flexible bool) decodeFunc {
  330. type field struct {
  331. decode decodeFunc
  332. index index
  333. tagID int
  334. }
  335. var fields []field
  336. taggedFields := map[int]*field{}
  337. forEachStructField(typ, func(typ reflect.Type, index index, tag string) {
  338. forEachStructTag(tag, func(tag structTag) bool {
  339. if tag.MinVersion <= version && version <= tag.MaxVersion {
  340. f := field{
  341. decode: decodeFuncOf(typ, version, flexible, tag),
  342. index: index,
  343. tagID: tag.TagID,
  344. }
  345. if tag.TagID < -1 {
  346. // Normal required field
  347. fields = append(fields, f)
  348. } else {
  349. // Optional tagged field (flexible messages only)
  350. taggedFields[tag.TagID] = &f
  351. }
  352. return false
  353. }
  354. return true
  355. })
  356. })
  357. return func(d *decoder, v value) {
  358. for i := range fields {
  359. f := &fields[i]
  360. f.decode(d, v.fieldByIndex(f.index))
  361. }
  362. if flexible {
  363. // See https://cwiki.apache.org/confluence/display/KAFKA/KIP-482%3A+The+Kafka+Protocol+should+Support+Optional+Tagged+Fields
  364. // for details of tag buffers in "flexible" messages.
  365. n := int(d.readUnsignedVarInt())
  366. for i := 0; i < n; i++ {
  367. tagID := int(d.readUnsignedVarInt())
  368. size := int(d.readUnsignedVarInt())
  369. f, ok := taggedFields[tagID]
  370. if ok {
  371. f.decode(d, v.fieldByIndex(f.index))
  372. } else {
  373. d.read(size)
  374. }
  375. }
  376. }
  377. }
  378. }
  379. func arrayDecodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) decodeFunc {
  380. elemType := typ.Elem()
  381. elemFunc := decodeFuncOf(elemType, version, flexible, tag)
  382. if flexible {
  383. // In flexible messages, all arrays are compact
  384. return func(d *decoder, v value) { d.decodeCompactArray(v, elemType, elemFunc) }
  385. }
  386. return func(d *decoder, v value) { d.decodeArray(v, elemType, elemFunc) }
  387. }
  388. func readerDecodeFuncOf(typ reflect.Type) decodeFunc {
  389. typ = reflect.PtrTo(typ)
  390. return func(d *decoder, v value) {
  391. if d.err == nil {
  392. _, err := v.iface(typ).(io.ReaderFrom).ReadFrom(d)
  393. if err != nil {
  394. d.setError(err)
  395. }
  396. }
  397. }
  398. }
  // readInt8 interprets the first byte of b as a signed 8-bit integer.
  // It panics if b is empty.
  func readInt8(b []byte) int8 {
  	first := b[0]
  	return int8(first)
  }
  // readInt16 interprets the first two bytes of b as a big-endian signed
  // 16-bit integer. It panics if b is shorter than two bytes.
  func readInt16(b []byte) int16 {
  	u := binary.BigEndian.Uint16(b)
  	return int16(u)
  }
  // readInt32 interprets the first four bytes of b as a big-endian signed
  // 32-bit integer. It panics if b is shorter than four bytes.
  func readInt32(b []byte) int32 {
  	u := binary.BigEndian.Uint32(b)
  	return int32(u)
  }
  // readInt64 interprets the first eight bytes of b as a big-endian signed
  // 64-bit integer. It panics if b is shorter than eight bytes.
  func readInt64(b []byte) int64 {
  	u := binary.BigEndian.Uint64(b)
  	return int64(u)
  }
  // readFloat64 interprets the first eight bytes of b as a big-endian IEEE 754
  // 64-bit float. It panics if b is shorter than eight bytes.
  func readFloat64(b []byte) float64 {
  	bits := binary.BigEndian.Uint64(b)
  	return math.Float64frombits(bits)
  }
  414. func Unmarshal(data []byte, version int16, value interface{}) error {
  415. typ := elemTypeOf(value)
  416. cache, _ := unmarshalers.Load().(map[versionedType]decodeFunc)
  417. key := versionedType{typ: typ, version: version}
  418. decode := cache[key]
  419. if decode == nil {
  420. decode = decodeFuncOf(reflect.TypeOf(value).Elem(), version, false, structTag{
  421. MinVersion: -1,
  422. MaxVersion: -1,
  423. TagID: -2,
  424. Compact: true,
  425. Nullable: true,
  426. })
  427. newCache := make(map[versionedType]decodeFunc, len(cache)+1)
  428. newCache[key] = decode
  429. for typ, fun := range cache {
  430. newCache[typ] = fun
  431. }
  432. unmarshalers.Store(newCache)
  433. }
  434. d, _ := decoders.Get().(*decoder)
  435. if d == nil {
  436. d = &decoder{reader: bytes.NewReader(nil)}
  437. }
  438. d.remain = len(data)
  439. r, _ := d.reader.(*bytes.Reader)
  440. r.Reset(data)
  441. defer func() {
  442. r.Reset(nil)
  443. d.Reset(r, 0)
  444. decoders.Put(d)
  445. }()
  446. decode(d, valueOf(value))
  447. return dontExpectEOF(d.err)
  448. }
  // decoders recycles decoder instances between Unmarshal calls.
  var decoders sync.Pool // *decoder

  // unmarshalers holds the cache of decode functions keyed by type and version.
  var unmarshalers atomic.Value // map[versionedType]decodeFunc