socket.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673
  1. // mgo - MongoDB driver for Go
  2. //
  3. // Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
  4. //
  5. // All rights reserved.
  6. //
  7. // Redistribution and use in source and binary forms, with or without
  8. // modification, are permitted provided that the following conditions are met:
  9. //
  10. // 1. Redistributions of source code must retain the above copyright notice, this
  11. // list of conditions and the following disclaimer.
  12. // 2. Redistributions in binary form must reproduce the above copyright notice,
  13. // this list of conditions and the following disclaimer in the documentation
  14. // and/or other materials provided with the distribution.
  15. //
  16. // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  17. // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  18. // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  19. // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
  20. // ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  21. // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  22. // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  23. // ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  24. // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  25. // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  26. package mgo
  27. import (
  28. "errors"
  29. "labix.org/v2/mgo/bson"
  30. "net"
  31. "sync"
  32. "time"
  33. )
  34. type replyFunc func(err error, reply *replyOp, docNum int, docData []byte)
  35. type mongoSocket struct {
  36. sync.Mutex
  37. server *mongoServer // nil when cached
  38. conn net.Conn
  39. timeout time.Duration
  40. addr string // For debugging only.
  41. nextRequestId uint32
  42. replyFuncs map[uint32]replyFunc
  43. references int
  44. creds []Credential
  45. logout []Credential
  46. cachedNonce string
  47. gotNonce sync.Cond
  48. dead error
  49. serverInfo *mongoServerInfo
  50. }
  51. type queryOpFlags uint32
  52. const (
  53. _ queryOpFlags = 1 << iota
  54. flagTailable
  55. flagSlaveOk
  56. flagLogReplay
  57. flagNoCursorTimeout
  58. flagAwaitData
  59. )
  60. type queryOp struct {
  61. collection string
  62. query interface{}
  63. skip int32
  64. limit int32
  65. selector interface{}
  66. flags queryOpFlags
  67. replyFunc replyFunc
  68. options queryWrapper
  69. hasOptions bool
  70. serverTags []bson.D
  71. }
  72. type queryWrapper struct {
  73. Query interface{} "$query"
  74. OrderBy interface{} "$orderby,omitempty"
  75. Hint interface{} "$hint,omitempty"
  76. Explain bool "$explain,omitempty"
  77. Snapshot bool "$snapshot,omitempty"
  78. ReadPreference bson.D "$readPreference,omitempty"
  79. }
  80. func (op *queryOp) finalQuery(socket *mongoSocket) interface{} {
  81. if op.flags&flagSlaveOk != 0 && len(op.serverTags) > 0 && socket.ServerInfo().Mongos {
  82. op.hasOptions = true
  83. op.options.ReadPreference = bson.D{{"mode", "secondaryPreferred"}, {"tags", op.serverTags}}
  84. }
  85. if op.hasOptions {
  86. if op.query == nil {
  87. var empty bson.D
  88. op.options.Query = empty
  89. } else {
  90. op.options.Query = op.query
  91. }
  92. debugf("final query is %#v\n", &op.options)
  93. return &op.options
  94. }
  95. return op.query
  96. }
  97. type getMoreOp struct {
  98. collection string
  99. limit int32
  100. cursorId int64
  101. replyFunc replyFunc
  102. }
  103. type replyOp struct {
  104. flags uint32
  105. cursorId int64
  106. firstDoc int32
  107. replyDocs int32
  108. }
  109. type insertOp struct {
  110. collection string // "database.collection"
  111. documents []interface{} // One or more documents to insert
  112. flags uint32
  113. }
  114. type updateOp struct {
  115. collection string // "database.collection"
  116. selector interface{}
  117. update interface{}
  118. flags uint32
  119. }
  120. type deleteOp struct {
  121. collection string // "database.collection"
  122. selector interface{}
  123. flags uint32
  124. }
  125. type killCursorsOp struct {
  126. cursorIds []int64
  127. }
  128. type requestInfo struct {
  129. bufferPos int
  130. replyFunc replyFunc
  131. }
  132. func newSocket(server *mongoServer, conn net.Conn, timeout time.Duration) *mongoSocket {
  133. socket := &mongoSocket{
  134. conn: conn,
  135. addr: server.Addr,
  136. server: server,
  137. replyFuncs: make(map[uint32]replyFunc),
  138. }
  139. socket.gotNonce.L = &socket.Mutex
  140. if err := socket.InitialAcquire(server.Info(), timeout); err != nil {
  141. panic("newSocket: InitialAcquire returned error: " + err.Error())
  142. }
  143. stats.socketsAlive(+1)
  144. debugf("Socket %p to %s: initialized", socket, socket.addr)
  145. socket.resetNonce()
  146. go socket.readLoop()
  147. return socket
  148. }
  149. // Server returns the server that the socket is associated with.
  150. // It returns nil while the socket is cached in its respective server.
  151. func (socket *mongoSocket) Server() *mongoServer {
  152. socket.Lock()
  153. server := socket.server
  154. socket.Unlock()
  155. return server
  156. }
  157. // ServerInfo returns details for the server at the time the socket
  158. // was initially acquired.
  159. func (socket *mongoSocket) ServerInfo() *mongoServerInfo {
  160. socket.Lock()
  161. serverInfo := socket.serverInfo
  162. socket.Unlock()
  163. return serverInfo
  164. }
  165. // InitialAcquire obtains the first reference to the socket, either
  166. // right after the connection is made or once a recycled socket is
  167. // being put back in use.
  168. func (socket *mongoSocket) InitialAcquire(serverInfo *mongoServerInfo, timeout time.Duration) error {
  169. socket.Lock()
  170. if socket.references > 0 {
  171. panic("Socket acquired out of cache with references")
  172. }
  173. if socket.dead != nil {
  174. dead := socket.dead
  175. socket.Unlock()
  176. return dead
  177. }
  178. socket.references++
  179. socket.serverInfo = serverInfo
  180. socket.timeout = timeout
  181. stats.socketsInUse(+1)
  182. stats.socketRefs(+1)
  183. socket.Unlock()
  184. return nil
  185. }
  186. // Acquire obtains an additional reference to the socket.
  187. // The socket will only be recycled when it's released as many
  188. // times as it's been acquired.
  189. func (socket *mongoSocket) Acquire() (info *mongoServerInfo) {
  190. socket.Lock()
  191. if socket.references == 0 {
  192. panic("Socket got non-initial acquire with references == 0")
  193. }
  194. // We'll track references to dead sockets as well.
  195. // Caller is still supposed to release the socket.
  196. socket.references++
  197. stats.socketRefs(+1)
  198. serverInfo := socket.serverInfo
  199. socket.Unlock()
  200. return serverInfo
  201. }
  202. // Release decrements a socket reference. The socket will be
  203. // recycled once its released as many times as it's been acquired.
  204. func (socket *mongoSocket) Release() {
  205. socket.Lock()
  206. if socket.references == 0 {
  207. panic("socket.Release() with references == 0")
  208. }
  209. socket.references--
  210. stats.socketRefs(-1)
  211. if socket.references == 0 {
  212. stats.socketsInUse(-1)
  213. server := socket.server
  214. socket.Unlock()
  215. socket.LogoutAll()
  216. // If the socket is dead server is nil.
  217. if server != nil {
  218. server.RecycleSocket(socket)
  219. }
  220. } else {
  221. socket.Unlock()
  222. }
  223. }
  224. // SetTimeout changes the timeout used on socket operations.
  225. func (socket *mongoSocket) SetTimeout(d time.Duration) {
  226. socket.Lock()
  227. socket.timeout = d
  228. socket.Unlock()
  229. }
  230. type deadlineType int
  231. const (
  232. readDeadline deadlineType = 1
  233. writeDeadline deadlineType = 2
  234. )
  235. func (socket *mongoSocket) updateDeadline(which deadlineType) {
  236. var when time.Time
  237. if socket.timeout > 0 {
  238. when = time.Now().Add(socket.timeout)
  239. }
  240. whichstr := ""
  241. switch which {
  242. case readDeadline | writeDeadline:
  243. whichstr = "read/write"
  244. socket.conn.SetDeadline(when)
  245. case readDeadline:
  246. whichstr = "read"
  247. socket.conn.SetReadDeadline(when)
  248. case writeDeadline:
  249. whichstr = "write"
  250. socket.conn.SetWriteDeadline(when)
  251. default:
  252. panic("invalid parameter to updateDeadline")
  253. }
  254. debugf("Socket %p to %s: updated %s deadline to %s ahead (%s)", socket, socket.addr, whichstr, socket.timeout, when)
  255. }
  256. // Close terminates the socket use.
  257. func (socket *mongoSocket) Close() {
  258. socket.kill(errors.New("Closed explicitly"), false)
  259. }
  260. func (socket *mongoSocket) kill(err error, abend bool) {
  261. socket.Lock()
  262. if socket.dead != nil {
  263. debugf("Socket %p to %s: killed again: %s (previously: %s)", socket, socket.addr, err.Error(), socket.dead.Error())
  264. socket.Unlock()
  265. return
  266. }
  267. logf("Socket %p to %s: closing: %s (abend=%v)", socket, socket.addr, err.Error(), abend)
  268. socket.dead = err
  269. socket.conn.Close()
  270. stats.socketsAlive(-1)
  271. replyFuncs := socket.replyFuncs
  272. socket.replyFuncs = make(map[uint32]replyFunc)
  273. server := socket.server
  274. socket.server = nil
  275. socket.gotNonce.Broadcast()
  276. socket.Unlock()
  277. for _, replyFunc := range replyFuncs {
  278. logf("Socket %p to %s: notifying replyFunc of closed socket: %s", socket, socket.addr, err.Error())
  279. replyFunc(err, nil, -1, nil)
  280. }
  281. if abend {
  282. server.AbendSocket(socket)
  283. }
  284. }
  285. func (socket *mongoSocket) SimpleQuery(op *queryOp) (data []byte, err error) {
  286. var wait, change sync.Mutex
  287. var replyDone bool
  288. var replyData []byte
  289. var replyErr error
  290. wait.Lock()
  291. op.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) {
  292. change.Lock()
  293. if !replyDone {
  294. replyDone = true
  295. replyErr = err
  296. if err == nil {
  297. replyData = docData
  298. }
  299. }
  300. change.Unlock()
  301. wait.Unlock()
  302. }
  303. err = socket.Query(op)
  304. if err != nil {
  305. return nil, err
  306. }
  307. wait.Lock()
  308. change.Lock()
  309. data = replyData
  310. err = replyErr
  311. change.Unlock()
  312. return data, err
  313. }
  314. func (socket *mongoSocket) Query(ops ...interface{}) (err error) {
  315. if lops := socket.flushLogout(); len(lops) > 0 {
  316. ops = append(lops, ops...)
  317. }
  318. buf := make([]byte, 0, 256)
  319. // Serialize operations synchronously to avoid interrupting
  320. // other goroutines while we can't really be sending data.
  321. // Also, record id positions so that we can compute request
  322. // ids at once later with the lock already held.
  323. requests := make([]requestInfo, len(ops))
  324. requestCount := 0
  325. for _, op := range ops {
  326. debugf("Socket %p to %s: serializing op: %#v", socket, socket.addr, op)
  327. start := len(buf)
  328. var replyFunc replyFunc
  329. switch op := op.(type) {
  330. case *updateOp:
  331. buf = addHeader(buf, 2001)
  332. buf = addInt32(buf, 0) // Reserved
  333. buf = addCString(buf, op.collection)
  334. buf = addInt32(buf, int32(op.flags))
  335. debugf("Socket %p to %s: serializing selector document: %#v", socket, socket.addr, op.selector)
  336. buf, err = addBSON(buf, op.selector)
  337. if err != nil {
  338. return err
  339. }
  340. debugf("Socket %p to %s: serializing update document: %#v", socket, socket.addr, op.update)
  341. buf, err = addBSON(buf, op.update)
  342. if err != nil {
  343. return err
  344. }
  345. case *insertOp:
  346. buf = addHeader(buf, 2002)
  347. buf = addInt32(buf, int32(op.flags))
  348. buf = addCString(buf, op.collection)
  349. for _, doc := range op.documents {
  350. debugf("Socket %p to %s: serializing document for insertion: %#v", socket, socket.addr, doc)
  351. buf, err = addBSON(buf, doc)
  352. if err != nil {
  353. return err
  354. }
  355. }
  356. case *queryOp:
  357. buf = addHeader(buf, 2004)
  358. buf = addInt32(buf, int32(op.flags))
  359. buf = addCString(buf, op.collection)
  360. buf = addInt32(buf, op.skip)
  361. buf = addInt32(buf, op.limit)
  362. buf, err = addBSON(buf, op.finalQuery(socket))
  363. if err != nil {
  364. return err
  365. }
  366. if op.selector != nil {
  367. buf, err = addBSON(buf, op.selector)
  368. if err != nil {
  369. return err
  370. }
  371. }
  372. replyFunc = op.replyFunc
  373. case *getMoreOp:
  374. buf = addHeader(buf, 2005)
  375. buf = addInt32(buf, 0) // Reserved
  376. buf = addCString(buf, op.collection)
  377. buf = addInt32(buf, op.limit)
  378. buf = addInt64(buf, op.cursorId)
  379. replyFunc = op.replyFunc
  380. case *deleteOp:
  381. buf = addHeader(buf, 2006)
  382. buf = addInt32(buf, 0) // Reserved
  383. buf = addCString(buf, op.collection)
  384. buf = addInt32(buf, int32(op.flags))
  385. debugf("Socket %p to %s: serializing selector document: %#v", socket, socket.addr, op.selector)
  386. buf, err = addBSON(buf, op.selector)
  387. if err != nil {
  388. return err
  389. }
  390. case *killCursorsOp:
  391. buf = addHeader(buf, 2007)
  392. buf = addInt32(buf, 0) // Reserved
  393. buf = addInt32(buf, int32(len(op.cursorIds)))
  394. for _, cursorId := range op.cursorIds {
  395. buf = addInt64(buf, cursorId)
  396. }
  397. default:
  398. panic("internal error: unknown operation type")
  399. }
  400. setInt32(buf, start, int32(len(buf)-start))
  401. if replyFunc != nil {
  402. request := &requests[requestCount]
  403. request.replyFunc = replyFunc
  404. request.bufferPos = start
  405. requestCount++
  406. }
  407. }
  408. // Buffer is ready for the pipe. Lock, allocate ids, and enqueue.
  409. socket.Lock()
  410. if socket.dead != nil {
  411. dead := socket.dead
  412. socket.Unlock()
  413. debugf("Socket %p to %s: failing query, already closed: %s", socket, socket.addr, socket.dead.Error())
  414. // XXX This seems necessary in case the session is closed concurrently
  415. // with a query being performed, but it's not yet tested:
  416. for i := 0; i != requestCount; i++ {
  417. request := &requests[i]
  418. if request.replyFunc != nil {
  419. request.replyFunc(dead, nil, -1, nil)
  420. }
  421. }
  422. return dead
  423. }
  424. wasWaiting := len(socket.replyFuncs) > 0
  425. // Reserve id 0 for requests which should have no responses.
  426. requestId := socket.nextRequestId + 1
  427. if requestId == 0 {
  428. requestId++
  429. }
  430. socket.nextRequestId = requestId + uint32(requestCount)
  431. for i := 0; i != requestCount; i++ {
  432. request := &requests[i]
  433. setInt32(buf, request.bufferPos+4, int32(requestId))
  434. socket.replyFuncs[requestId] = request.replyFunc
  435. requestId++
  436. }
  437. debugf("Socket %p to %s: sending %d op(s) (%d bytes)", socket, socket.addr, len(ops), len(buf))
  438. stats.sentOps(len(ops))
  439. socket.updateDeadline(writeDeadline)
  440. _, err = socket.conn.Write(buf)
  441. if !wasWaiting && requestCount > 0 {
  442. socket.updateDeadline(readDeadline)
  443. }
  444. socket.Unlock()
  445. return err
  446. }
  447. func fill(r net.Conn, b []byte) error {
  448. l := len(b)
  449. n, err := r.Read(b)
  450. for n != l && err == nil {
  451. var ni int
  452. ni, err = r.Read(b[n:])
  453. n += ni
  454. }
  455. return err
  456. }
  457. // Estimated minimum cost per socket: 1 goroutine + memory for the largest
  458. // document ever seen.
  459. func (socket *mongoSocket) readLoop() {
  460. p := make([]byte, 36) // 16 from header + 20 from OP_REPLY fixed fields
  461. s := make([]byte, 4)
  462. conn := socket.conn // No locking, conn never changes.
  463. for {
  464. // XXX Handle timeouts, , etc
  465. err := fill(conn, p)
  466. if err != nil {
  467. socket.kill(err, true)
  468. return
  469. }
  470. totalLen := getInt32(p, 0)
  471. responseTo := getInt32(p, 8)
  472. opCode := getInt32(p, 12)
  473. // Don't use socket.server.Addr here. socket is not
  474. // locked and socket.server may go away.
  475. debugf("Socket %p to %s: got reply (%d bytes)", socket, socket.addr, totalLen)
  476. _ = totalLen
  477. if opCode != 1 {
  478. socket.kill(errors.New("opcode != 1, corrupted data?"), true)
  479. return
  480. }
  481. reply := replyOp{
  482. flags: uint32(getInt32(p, 16)),
  483. cursorId: getInt64(p, 20),
  484. firstDoc: getInt32(p, 28),
  485. replyDocs: getInt32(p, 32),
  486. }
  487. stats.receivedOps(+1)
  488. stats.receivedDocs(int(reply.replyDocs))
  489. socket.Lock()
  490. replyFunc, ok := socket.replyFuncs[uint32(responseTo)]
  491. if ok {
  492. delete(socket.replyFuncs, uint32(responseTo))
  493. }
  494. socket.Unlock()
  495. if replyFunc != nil && reply.replyDocs == 0 {
  496. replyFunc(nil, &reply, -1, nil)
  497. } else {
  498. for i := 0; i != int(reply.replyDocs); i++ {
  499. err := fill(conn, s)
  500. if err != nil {
  501. if replyFunc != nil {
  502. replyFunc(err, nil, -1, nil)
  503. }
  504. socket.kill(err, true)
  505. return
  506. }
  507. b := make([]byte, int(getInt32(s, 0)))
  508. // copy(b, s) in an efficient way.
  509. b[0] = s[0]
  510. b[1] = s[1]
  511. b[2] = s[2]
  512. b[3] = s[3]
  513. err = fill(conn, b[4:])
  514. if err != nil {
  515. if replyFunc != nil {
  516. replyFunc(err, nil, -1, nil)
  517. }
  518. socket.kill(err, true)
  519. return
  520. }
  521. if globalDebug && globalLogger != nil {
  522. m := bson.M{}
  523. if err := bson.Unmarshal(b, m); err == nil {
  524. debugf("Socket %p to %s: received document: %#v", socket, socket.addr, m)
  525. }
  526. }
  527. if replyFunc != nil {
  528. replyFunc(nil, &reply, i, b)
  529. }
  530. // XXX Do bound checking against totalLen.
  531. }
  532. }
  533. socket.Lock()
  534. if len(socket.replyFuncs) == 0 {
  535. // Nothing else to read for now. Disable deadline.
  536. socket.conn.SetReadDeadline(time.Time{})
  537. } else {
  538. socket.updateDeadline(readDeadline)
  539. }
  540. socket.Unlock()
  541. // XXX Do bound checking against totalLen.
  542. }
  543. }
  544. var emptyHeader = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
  545. func addHeader(b []byte, opcode int) []byte {
  546. i := len(b)
  547. b = append(b, emptyHeader...)
  548. // Enough for current opcodes.
  549. b[i+12] = byte(opcode)
  550. b[i+13] = byte(opcode >> 8)
  551. return b
  552. }
  553. func addInt32(b []byte, i int32) []byte {
  554. return append(b, byte(i), byte(i>>8), byte(i>>16), byte(i>>24))
  555. }
  556. func addInt64(b []byte, i int64) []byte {
  557. return append(b, byte(i), byte(i>>8), byte(i>>16), byte(i>>24),
  558. byte(i>>32), byte(i>>40), byte(i>>48), byte(i>>56))
  559. }
  560. func addCString(b []byte, s string) []byte {
  561. b = append(b, []byte(s)...)
  562. b = append(b, 0)
  563. return b
  564. }
  565. func addBSON(b []byte, doc interface{}) ([]byte, error) {
  566. if doc == nil {
  567. return append(b, 5, 0, 0, 0, 0), nil
  568. }
  569. data, err := bson.Marshal(doc)
  570. if err != nil {
  571. return b, err
  572. }
  573. return append(b, data...), nil
  574. }
  575. func setInt32(b []byte, pos int, i int32) {
  576. b[pos] = byte(i)
  577. b[pos+1] = byte(i >> 8)
  578. b[pos+2] = byte(i >> 16)
  579. b[pos+3] = byte(i >> 24)
  580. }
  581. func getInt32(b []byte, pos int) int32 {
  582. return (int32(b[pos+0])) |
  583. (int32(b[pos+1]) << 8) |
  584. (int32(b[pos+2]) << 16) |
  585. (int32(b[pos+3]) << 24)
  586. }
  587. func getInt64(b []byte, pos int) int64 {
  588. return (int64(b[pos+0])) |
  589. (int64(b[pos+1]) << 8) |
  590. (int64(b[pos+2]) << 16) |
  591. (int64(b[pos+3]) << 24) |
  592. (int64(b[pos+4]) << 32) |
  593. (int64(b[pos+5]) << 40) |
  594. (int64(b[pos+6]) << 48) |
  595. (int64(b[pos+7]) << 56)
  596. }