provider.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. package sessions
  2. import (
  3. "errors"
  4. "sync"
  5. "time"
  6. )
  7. type (
  8. // provider contains the sessions and external databases (load and update).
  9. // It's the session memory manager
  10. provider struct {
  11. mu sync.RWMutex
  12. sessions map[string]*Session
  13. db Database
  14. destroyListeners []DestroyListener
  15. }
  16. )
  17. // newProvider returns a new sessions provider
  18. func newProvider() *provider {
  19. p := &provider{
  20. sessions: make(map[string]*Session),
  21. db: newMemDB(),
  22. }
  23. return p
  24. }
  25. // RegisterDatabase sets a session database.
  26. func (p *provider) RegisterDatabase(db Database) {
  27. if db == nil {
  28. return
  29. }
  30. p.mu.Lock() // for any case
  31. p.db = db
  32. p.mu.Unlock()
  33. }
  34. // newSession returns a new session from sessionid
  35. func (p *provider) newSession(man *Sessions, sid string, expires time.Duration) *Session {
  36. sess := &Session{
  37. sid: sid,
  38. Man: man,
  39. provider: p,
  40. flashes: make(map[string]*flashMessage),
  41. }
  42. onExpire := func() {
  43. p.mu.Lock()
  44. p.deleteSession(sess)
  45. p.mu.Unlock()
  46. }
  47. lifetime := p.db.Acquire(sid, expires)
  48. // simple and straight:
  49. if !lifetime.IsZero() {
  50. // if stored time is not zero
  51. // start a timer based on the stored time, if not expired.
  52. lifetime.Revive(onExpire)
  53. } else {
  54. // Remember: if db not exist or it has been expired
  55. // then the stored time will be zero(see loadSessionFromDB) and the values will be empty.
  56. //
  57. // Even if the database has an unlimited session (possible by a previous app run)
  58. // priority to the "expires" is given,
  59. // again if <=0 then it does nothing.
  60. lifetime.Begin(expires, onExpire)
  61. }
  62. sess.Lifetime = lifetime
  63. return sess
  64. }
  65. // Init creates the session and returns it
  66. func (p *provider) Init(man *Sessions, sid string, expires time.Duration) *Session {
  67. newSession := p.newSession(man, sid, expires)
  68. p.mu.Lock()
  69. p.sessions[sid] = newSession
  70. p.mu.Unlock()
  71. return newSession
  72. }
  // ErrNotFound may be returned from `UpdateExpiration` for a non-existing or
// invalid session entry in the memory storage or a database.
//
// Usage:
//
//	if errors.Is(err, sessions.ErrNotFound) {
//		// [handle error...]
//	}
var ErrNotFound = errors.New("session not found")
  80. // UpdateExpiration resets the expiration of a session.
  81. // if expires > 0 then it will try to update the expiration and destroy task is delayed.
  82. // if expires <= 0 then it does nothing it returns nil, to destroy a session call the `Destroy` func instead.
  83. //
  84. // If the session is not found, it returns a `NotFound` error, this can only happen when you restart the server and you used the memory-based storage(default),
  85. // because the call of the provider's `UpdateExpiration` is always called when the client has a valid session cookie.
  86. //
  87. // If a backend database is used then it may return an `ErrNotImplemented` error if the underline database does not support this operation.
  88. func (p *provider) UpdateExpiration(sid string, expires time.Duration) error {
  89. if expires <= 0 {
  90. return nil
  91. }
  92. p.mu.RLock()
  93. sess, found := p.sessions[sid]
  94. p.mu.RUnlock()
  95. if !found {
  96. return ErrNotFound
  97. }
  98. sess.Lifetime.Shift(expires)
  99. return p.db.OnUpdateExpiration(sid, expires)
  100. }
  101. // Read returns the store which sid parameter belongs
  102. func (p *provider) Read(man *Sessions, sid string, expires time.Duration) *Session {
  103. p.mu.RLock()
  104. if sess, found := p.sessions[sid]; found {
  105. sess.runFlashGC() // run the flash messages GC, new request here of existing session
  106. p.mu.RUnlock()
  107. return sess
  108. }
  109. p.mu.RUnlock()
  110. return p.Init(man, sid, expires) // if not found create new
  111. }
  112. func (p *provider) registerDestroyListener(ln DestroyListener) {
  113. if ln == nil {
  114. return
  115. }
  116. p.destroyListeners = append(p.destroyListeners, ln)
  117. }
  118. func (p *provider) fireDestroy(sid string) {
  119. for _, ln := range p.destroyListeners {
  120. ln(sid)
  121. }
  122. }
  123. // Destroy destroys the session, removes all sessions and flash values,
  124. // the session itself and updates the registered session databases,
  125. // this called from sessionManager which removes the client's cookie also.
  126. func (p *provider) Destroy(sid string) {
  127. p.mu.Lock()
  128. if sess, found := p.sessions[sid]; found {
  129. p.deleteSession(sess)
  130. }
  131. p.mu.Unlock()
  132. }
  133. // DestroyAll removes all sessions
  134. // from the server-side memory (and database if registered).
  135. // Client's session cookie will still exist but it will be reseted on the next request.
  136. func (p *provider) DestroyAll() {
  137. p.mu.Lock()
  138. for _, sess := range p.sessions {
  139. p.deleteSession(sess)
  140. }
  141. p.mu.Unlock()
  142. }
  143. func (p *provider) deleteSession(sess *Session) {
  144. sid := sess.sid
  145. delete(p.sessions, sid)
  146. p.db.Release(sid)
  147. p.fireDestroy(sid)
  148. }