redistore.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. // Copyright 2012 Brian "bojo" Jones. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package redistore
  5. import (
  6. "bytes"
  7. "encoding/base32"
  8. "encoding/gob"
  9. "encoding/json"
  10. "errors"
  11. "fmt"
  12. "net/http"
  13. "strings"
  14. "time"
  15. "github.com/gomodule/redigo/redis"
  16. "github.com/gorilla/securecookie"
  17. "github.com/gorilla/sessions"
  18. )
  19. // Amount of time for cookies/redis keys to expire.
  20. var sessionExpire = 86400 * 30
  21. // SessionSerializer provides an interface hook for alternative serializers
  22. type SessionSerializer interface {
  23. Deserialize(d []byte, ss *sessions.Session) error
  24. Serialize(ss *sessions.Session) ([]byte, error)
  25. }
// JSONSerializer is a SessionSerializer that encodes the session values as a
// JSON object. It has no fields, so the zero value is ready to use.
type JSONSerializer struct{}
  28. // Serialize to JSON. Will err if there are unmarshalable key values
  29. func (s JSONSerializer) Serialize(ss *sessions.Session) ([]byte, error) {
  30. m := make(map[string]interface{}, len(ss.Values))
  31. for k, v := range ss.Values {
  32. ks, ok := k.(string)
  33. if !ok {
  34. err := fmt.Errorf("Non-string key value, cannot serialize session to JSON: %v", k)
  35. fmt.Printf("redistore.JSONSerializer.serialize() Error: %v", err)
  36. return nil, err
  37. }
  38. m[ks] = v
  39. }
  40. return json.Marshal(m)
  41. }
  42. // Deserialize back to map[string]interface{}
  43. func (s JSONSerializer) Deserialize(d []byte, ss *sessions.Session) error {
  44. m := make(map[string]interface{})
  45. err := json.Unmarshal(d, &m)
  46. if err != nil {
  47. fmt.Printf("redistore.JSONSerializer.deserialize() Error: %v", err)
  48. return err
  49. }
  50. for k, v := range m {
  51. ss.Values[k] = v
  52. }
  53. return nil
  54. }
// GobSerializer is a SessionSerializer that encodes the session values with
// the encoding/gob package. It has no fields, so the zero value is ready to
// use.
type GobSerializer struct{}
  57. // Serialize using gob
  58. func (s GobSerializer) Serialize(ss *sessions.Session) ([]byte, error) {
  59. buf := new(bytes.Buffer)
  60. enc := gob.NewEncoder(buf)
  61. err := enc.Encode(ss.Values)
  62. if err == nil {
  63. return buf.Bytes(), nil
  64. }
  65. return nil, err
  66. }
  67. // Deserialize back to map[interface{}]interface{}
  68. func (s GobSerializer) Deserialize(d []byte, ss *sessions.Session) error {
  69. dec := gob.NewDecoder(bytes.NewBuffer(d))
  70. return dec.Decode(&ss.Values)
  71. }
// RediStore stores sessions in a redis backend.
//
// The session ID travels in a cookie encoded with Codecs; the serialized
// session values are kept in redis under keyPrefix + session ID.
type RediStore struct {
	// Pool supplies the redis connection used by each store operation.
	Pool *redis.Pool
	// Codecs encode and decode the session ID carried in the cookie.
	Codecs []securecookie.Codec
	// Options is copied into every session created by New.
	Options *sessions.Options // default configuration
	// DefaultMaxAge is used as the redis TTL, in seconds, when a stored
	// session has Options.MaxAge == 0.
	DefaultMaxAge int // default Redis TTL for a MaxAge == 0 session
	// maxLength caps the size in bytes of a serialized session; 0 disables
	// the check. Change it with SetMaxLength.
	maxLength int
	// keyPrefix is prepended to the session ID to form the redis key.
	keyPrefix string
	// serializer converts sessions to and from the bytes stored in redis.
	serializer SessionSerializer
}
  82. // SetMaxLength sets RediStore.maxLength if the `l` argument is greater or equal 0
  83. // maxLength restricts the maximum length of new sessions to l.
  84. // If l is 0 there is no limit to the size of a session, use with caution.
  85. // The default for a new RediStore is 4096. Redis allows for max.
  86. // value sizes of up to 512MB (http://redis.io/topics/data-types)
  87. // Default: 4096,
  88. func (s *RediStore) SetMaxLength(l int) {
  89. if l >= 0 {
  90. s.maxLength = l
  91. }
  92. }
// SetKeyPrefix sets the prefix that is prepended to the session ID to build
// the redis key. A new RediStore uses "session_".
func (s *RediStore) SetKeyPrefix(p string) {
	s.keyPrefix = p
}
// SetSerializer replaces the serializer used to encode and decode session
// data stored in redis. A new RediStore uses GobSerializer.
func (s *RediStore) SetSerializer(ss SessionSerializer) {
	s.serializer = ss
}
  101. // SetMaxAge restricts the maximum age, in seconds, of the session record
  102. // both in database and a browser. This is to change session storage configuration.
  103. // If you want just to remove session use your session `s` object and change it's
  104. // `Options.MaxAge` to -1, as specified in
  105. // http://godoc.org/github.com/gorilla/sessions#Options
  106. //
  107. // Default is the one provided by this package value - `sessionExpire`.
  108. // Set it to 0 for no restriction.
  109. // Because we use `MaxAge` also in SecureCookie crypting algorithm you should
  110. // use this function to change `MaxAge` value.
  111. func (s *RediStore) SetMaxAge(v int) {
  112. var c *securecookie.SecureCookie
  113. var ok bool
  114. s.Options.MaxAge = v
  115. for i := range s.Codecs {
  116. if c, ok = s.Codecs[i].(*securecookie.SecureCookie); ok {
  117. c.MaxAge(v)
  118. } else {
  119. fmt.Printf("Can't change MaxAge on codec %v\n", s.Codecs[i])
  120. }
  121. }
  122. }
  123. func dial(network, address, password string) (redis.Conn, error) {
  124. c, err := redis.Dial(network, address)
  125. if err != nil {
  126. return nil, err
  127. }
  128. if password != "" {
  129. if _, err := c.Do("AUTH", password); err != nil {
  130. c.Close()
  131. return nil, err
  132. }
  133. }
  134. return c, err
  135. }
  136. // NewRediStore returns a new RediStore.
  137. // size: maximum number of idle connections.
  138. func NewRediStore(size int, network, address, password string, keyPairs ...[]byte) (*RediStore, error) {
  139. return NewRediStoreWithPool(&redis.Pool{
  140. MaxIdle: size,
  141. IdleTimeout: 240 * time.Second,
  142. TestOnBorrow: func(c redis.Conn, t time.Time) error {
  143. _, err := c.Do("PING")
  144. return err
  145. },
  146. Dial: func() (redis.Conn, error) {
  147. return dial(network, address, password)
  148. },
  149. }, keyPairs...)
  150. }
  151. func dialWithDB(network, address, password, DB string) (redis.Conn, error) {
  152. c, err := dial(network, address, password)
  153. if err != nil {
  154. return nil, err
  155. }
  156. if _, err := c.Do("SELECT", DB); err != nil {
  157. c.Close()
  158. return nil, err
  159. }
  160. return c, err
  161. }
  162. // NewRediStoreWithDB - like NewRedisStore but accepts `DB` parameter to select
  163. // redis DB instead of using the default one ("0")
  164. func NewRediStoreWithDB(size int, network, address, password, DB string, keyPairs ...[]byte) (*RediStore, error) {
  165. return NewRediStoreWithPool(&redis.Pool{
  166. MaxIdle: size,
  167. IdleTimeout: 240 * time.Second,
  168. TestOnBorrow: func(c redis.Conn, t time.Time) error {
  169. _, err := c.Do("PING")
  170. return err
  171. },
  172. Dial: func() (redis.Conn, error) {
  173. return dialWithDB(network, address, password, DB)
  174. },
  175. }, keyPairs...)
  176. }
  177. // NewRediStoreWithPool instantiates a RediStore with a *redis.Pool passed in.
  178. func NewRediStoreWithPool(pool *redis.Pool, keyPairs ...[]byte) (*RediStore, error) {
  179. rs := &RediStore{
  180. // http://godoc.org/github.com/gomodule/redigo/redis#Pool
  181. Pool: pool,
  182. Codecs: securecookie.CodecsFromPairs(keyPairs...),
  183. Options: &sessions.Options{
  184. Path: "/",
  185. MaxAge: sessionExpire,
  186. },
  187. DefaultMaxAge: 60 * 20, // 20 minutes seems like a reasonable default
  188. maxLength: 4096,
  189. keyPrefix: "session_",
  190. serializer: GobSerializer{},
  191. }
  192. _, err := rs.ping()
  193. return rs, err
  194. }
// Close closes the underlying *redis.Pool and returns the error reported by
// the pool, if any.
func (s *RediStore) Close() error {
	return s.Pool.Close()
}
  199. // Get returns a session for the given name after adding it to the registry.
  200. //
  201. // See gorilla/sessions FilesystemStore.Get().
  202. func (s *RediStore) Get(r *http.Request, name string) (*sessions.Session, error) {
  203. return sessions.GetRegistry(r).Get(s, name)
  204. }
  205. // New returns a session for the given name without adding it to the registry.
  206. //
  207. // See gorilla/sessions FilesystemStore.New().
  208. func (s *RediStore) New(r *http.Request, name string) (*sessions.Session, error) {
  209. var (
  210. err error
  211. ok bool
  212. )
  213. session := sessions.NewSession(s, name)
  214. // make a copy
  215. options := *s.Options
  216. session.Options = &options
  217. session.IsNew = true
  218. if c, errCookie := r.Cookie(name); errCookie == nil {
  219. err = securecookie.DecodeMulti(name, c.Value, &session.ID, s.Codecs...)
  220. if err == nil {
  221. ok, err = s.load(session)
  222. session.IsNew = !(err == nil && ok) // not new if no error and data available
  223. }
  224. }
  225. return session, err
  226. }
  227. // Save adds a single session to the response.
  228. func (s *RediStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
  229. // Marked for deletion.
  230. if session.Options.MaxAge <= 0 {
  231. if err := s.delete(session); err != nil {
  232. return err
  233. }
  234. http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options))
  235. } else {
  236. // Build an alphanumeric key for the redis store.
  237. if session.ID == "" {
  238. session.ID = strings.TrimRight(base32.StdEncoding.EncodeToString(securecookie.GenerateRandomKey(32)), "=")
  239. }
  240. if err := s.save(session); err != nil {
  241. return err
  242. }
  243. encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, s.Codecs...)
  244. if err != nil {
  245. return err
  246. }
  247. http.SetCookie(w, sessions.NewCookie(session.Name(), encoded, session.Options))
  248. }
  249. return nil
  250. }
  251. // Delete removes the session from redis, and sets the cookie to expire.
  252. //
  253. // WARNING: This method should be considered deprecated since it is not exposed via the gorilla/sessions interface.
  254. // Set session.Options.MaxAge = -1 and call Save instead. - July 18th, 2013
  255. func (s *RediStore) Delete(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
  256. conn := s.Pool.Get()
  257. defer conn.Close()
  258. if _, err := conn.Do("DEL", s.keyPrefix+session.ID); err != nil {
  259. return err
  260. }
  261. // Set cookie to expire.
  262. options := *session.Options
  263. options.MaxAge = -1
  264. http.SetCookie(w, sessions.NewCookie(session.Name(), "", &options))
  265. // Clear session values.
  266. for k := range session.Values {
  267. delete(session.Values, k)
  268. }
  269. return nil
  270. }
  271. // ping does an internal ping against a server to check if it is alive.
  272. func (s *RediStore) ping() (bool, error) {
  273. conn := s.Pool.Get()
  274. defer conn.Close()
  275. data, err := conn.Do("PING")
  276. if err != nil || data == nil {
  277. return false, err
  278. }
  279. return (data == "PONG"), nil
  280. }
  281. // save stores the session in redis.
  282. func (s *RediStore) save(session *sessions.Session) error {
  283. b, err := s.serializer.Serialize(session)
  284. if err != nil {
  285. return err
  286. }
  287. if s.maxLength != 0 && len(b) > s.maxLength {
  288. return errors.New("SessionStore: the value to store is too big")
  289. }
  290. conn := s.Pool.Get()
  291. defer conn.Close()
  292. if err = conn.Err(); err != nil {
  293. return err
  294. }
  295. age := session.Options.MaxAge
  296. if age == 0 {
  297. age = s.DefaultMaxAge
  298. }
  299. _, err = conn.Do("SETEX", s.keyPrefix+session.ID, age, b)
  300. return err
  301. }
  302. // load reads the session from redis.
  303. // returns true if there is a sessoin data in DB
  304. func (s *RediStore) load(session *sessions.Session) (bool, error) {
  305. conn := s.Pool.Get()
  306. defer conn.Close()
  307. if err := conn.Err(); err != nil {
  308. return false, err
  309. }
  310. data, err := conn.Do("GET", s.keyPrefix+session.ID)
  311. if err != nil {
  312. return false, err
  313. }
  314. if data == nil {
  315. return false, nil // no data was associated with this key
  316. }
  317. b, err := redis.Bytes(data, err)
  318. if err != nil {
  319. return false, err
  320. }
  321. return true, s.serializer.Deserialize(b, session)
  322. }
  323. // delete removes keys from redis if MaxAge<0
  324. func (s *RediStore) delete(session *sessions.Session) error {
  325. conn := s.Pool.Get()
  326. defer conn.Close()
  327. if _, err := conn.Do("DEL", s.keyPrefix+session.ID); err != nil {
  328. return err
  329. }
  330. return nil
  331. }