sasl.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. // Package sasl is an implementation detail of the mgo package.
  2. //
  3. // This package is not meant to be used by itself.
  4. //
  5. package sasl
  6. // #cgo LDFLAGS: -lsasl2
  7. //
  8. // struct sasl_conn {};
  9. //
  10. // #include <stdlib.h>
  11. // #include <sasl/sasl.h>
  12. //
  13. // sasl_callback_t *mgo_sasl_callbacks(const char *username, const char *password);
  14. //
  15. import "C"
  16. import (
  17. "fmt"
  18. "strings"
  19. "sync"
  20. "unsafe"
  21. )
  22. type saslStepper interface {
  23. Step(serverData []byte) (clientData []byte, done bool, err error)
  24. Close()
  25. }
  26. type saslSession struct {
  27. conn *C.sasl_conn_t
  28. step int
  29. mech string
  30. cstrings []*C.char
  31. callbacks *C.sasl_callback_t
  32. }
  33. var initError error
  34. var initOnce sync.Once
  35. func initSASL() {
  36. rc := C.sasl_client_init(nil)
  37. if rc != C.SASL_OK {
  38. initError = saslError(rc, nil, "cannot initialize SASL library")
  39. }
  40. }
  41. func New(username, password, mechanism, service, host string) (saslStepper, error) {
  42. initOnce.Do(initSASL)
  43. if initError != nil {
  44. return nil, initError
  45. }
  46. ss := &saslSession{mech: mechanism}
  47. if service == "" {
  48. service = "mongodb"
  49. }
  50. if i := strings.Index(host, ":"); i >= 0 {
  51. host = host[:i]
  52. }
  53. ss.callbacks = C.mgo_sasl_callbacks(ss.cstr(username), ss.cstr(password))
  54. rc := C.sasl_client_new(ss.cstr(service), ss.cstr(host), nil, nil, ss.callbacks, 0, &ss.conn)
  55. if rc != C.SASL_OK {
  56. ss.Close()
  57. return nil, saslError(rc, nil, "cannot create new SASL client")
  58. }
  59. return ss, nil
  60. }
  61. func (ss *saslSession) cstr(s string) *C.char {
  62. cstr := C.CString(s)
  63. ss.cstrings = append(ss.cstrings, cstr)
  64. return cstr
  65. }
  66. func (ss *saslSession) Close() {
  67. for _, cstr := range ss.cstrings {
  68. C.free(unsafe.Pointer(cstr))
  69. }
  70. ss.cstrings = nil
  71. if ss.callbacks != nil {
  72. C.free(unsafe.Pointer(ss.callbacks))
  73. }
  74. // The documentation of SASL dispose makes it clear that this should only
  75. // be done when the connection is done, not when the authentication phase
  76. // is done, because an encryption layer may have been negotiated.
  77. // Even then, we'll do this for now, because it's simpler and prevents
  78. // keeping track of this state for every socket. If it breaks, we'll fix it.
  79. C.sasl_dispose(&ss.conn)
  80. }
  81. func (ss *saslSession) Step(serverData []byte) (clientData []byte, done bool, err error) {
  82. ss.step++
  83. if ss.step > 10 {
  84. return nil, false, fmt.Errorf("too many SASL steps without authentication")
  85. }
  86. var cclientData *C.char
  87. var cclientDataLen C.uint
  88. var rc C.int
  89. if ss.step == 1 {
  90. var mechanism *C.char // ignored - must match cred
  91. rc = C.sasl_client_start(ss.conn, ss.cstr(ss.mech), nil, &cclientData, &cclientDataLen, &mechanism)
  92. } else {
  93. var cserverData *C.char
  94. var cserverDataLen C.uint
  95. if len(serverData) > 0 {
  96. cserverData = (*C.char)(unsafe.Pointer(&serverData[0]))
  97. cserverDataLen = C.uint(len(serverData))
  98. }
  99. rc = C.sasl_client_step(ss.conn, cserverData, cserverDataLen, nil, &cclientData, &cclientDataLen)
  100. }
  101. if cclientData != nil && cclientDataLen > 0 {
  102. clientData = C.GoBytes(unsafe.Pointer(cclientData), C.int(cclientDataLen))
  103. }
  104. if rc == C.SASL_OK {
  105. return clientData, true, nil
  106. }
  107. if rc == C.SASL_CONTINUE {
  108. return clientData, false, nil
  109. }
  110. return nil, false, saslError(rc, ss.conn, "cannot establish SASL session")
  111. }
  112. func saslError(rc C.int, conn *C.sasl_conn_t, msg string) error {
  113. var detail string
  114. if conn == nil {
  115. detail = C.GoString(C.sasl_errstring(rc, nil, nil))
  116. } else {
  117. detail = C.GoString(C.sasl_errdetail(conn))
  118. }
  119. return fmt.Errorf(msg + ": " + detail)
  120. }