Bug 680324 - ldap connections should have a max lifetime configurable - r=rkelly
authorTarek Ziade <tarek@ziade.org>
Thu, 08 Sep 2011 11:34:07 +0200
changeset 839 58360f363e8b75780183d5e13b262e56b1abadbc
parent 838 8b69140b4d30da24aa1f66984617ed7a61882d3d
child 840 dcf41be8a1f49c0c7a2cfb6c58539e6df0521d9b
push id345
push usertziade@mozilla.com
push dateThu, 08 Sep 2011 09:34:22 +0000
reviewersrkelly
bugs680324
Bug 680324 - ldap connections should have a max lifetime configurable - r=rkelly
services/auth/ldapsql.py
services/ldappool.py
services/tests/test_ldappool.py
--- a/services/auth/ldapsql.py
+++ b/services/auth/ldapsql.py
@@ -87,17 +87,18 @@ class LDAPAuth(ResetCodeManager):
 
     def __init__(self, ldapuri, sqluri, use_tls=False, bind_user='binduser',
                  bind_password='binduser', admin_user='adminuser',
                  admin_password='adminuser', users_root='ou=users,dc=mozilla',
                  users_base_dn=None, pool_size=100, pool_recycle=3600,
                  reset_on_return=True, single_box=False, ldap_timeout=-1,
                  nodes_scheme='https', check_account_state=True,
                  create_tables=False, ldap_pool_size=10, ldap_use_pool=False,
-                 connector_cls=StateConnector, check_node=False, **kw):
+                 connector_cls=StateConnector, check_node=False,
+                 ldap_max_lifetime=600, **kw):
         self.check_account_state = check_account_state
         self.ldapuri = ldapuri
         self.sqluri = sqluri
         self.bind_user = bind_user
         self.bind_password = bind_password
         self.admin_user = admin_user
         self.admin_password = admin_password
         self.use_tls = use_tls
@@ -106,17 +107,18 @@ class LDAPAuth(ResetCodeManager):
         self.single_box = single_box
         self.nodes_scheme = nodes_scheme
         self.ldap_timeout = ldap_timeout
         # by default, the ldap connections use the bind user
         self.conn = ConnectionManager(ldapuri, bind_user, bind_password,
                                       use_tls=use_tls, timeout=ldap_timeout,
                                       size=ldap_pool_size,
                                       use_pool=ldap_use_pool,
-                                      connector_cls=connector_cls)
+                                      connector_cls=connector_cls,
+                                      max_lifetime=ldap_max_lifetime)
         sqlkw = {'pool_size': int(pool_size),
                  'pool_recycle': int(pool_recycle),
                  'logging_name': 'weaveserver'}
 
         if self.sqluri is not None:
             if self.sqluri.startswith('mysql'):
                 sqlkw['reset_on_return'] = reset_on_return
             engine = create_engine(sqluri, **sqlkw)
--- a/services/ldappool.py
+++ b/services/ldappool.py
@@ -48,16 +48,23 @@ from services.exceptions import (Backend
 
 class StateConnector(ReconnectLDAPObject):
     """Just remembers who is connected, and if connected"""
     def __init__(self, *args, **kw):
         ReconnectLDAPObject.__init__(self, *args, **kw)
         self.connected = False
         self.who = ''
         self.cred = ''
+        self._connection_time = None
+
+    def get_lifetime(self):
+        """Returns the lifetime of the connection on the server in seconds."""
+        if self._connection_time is None:
+            return 0
+        return time.time() - self._connection_time
 
     def __str__(self):
         res = 'LDAP Connector'
         if self.connected:
             res += ' (connected)'
         else:
             res += ' (disconnected)'
 
@@ -71,26 +78,28 @@ class StateConnector(ReconnectLDAPObject
 
     def simple_bind_s(self, who='', cred='', serverctrls=None,
                       clientctrls=None):
         res = ReconnectLDAPObject.simple_bind_s(self, who, cred, serverctrls,
                                                 clientctrls)
         self.connected = True
         self.who = who
         self.cred = cred
+        self._connection_time = time.time()
         return res
 
     def unbind_ext_s(self, serverctrls=None, clientctrls=None):
         try:
             return ReconnectLDAPObject.unbind_ext_s(self, serverctrls,
                                                     clientctrls)
         finally:
             self.connected = False
             self.who = None
             self.cred = None
+            self._connection_time = None
 
     def add_s(self, *args, **kwargs):
         return self._apply_method_s(ReconnectLDAPObject.add_s, *args,
                                     **kwargs)
 
     def modify_s(self, *args, **kwargs):
         return self._apply_method_s(ReconnectLDAPObject.modify_s, *args,
                                     **kwargs)
@@ -98,43 +107,53 @@ class StateConnector(ReconnectLDAPObject
 
 class ConnectionManager(object):
     """LDAP Connection Manager.
 
     Provides a context manager for LDAP connectors.
     """
     def __init__(self, uri, bind=None, passwd=None, size=10, retry_max=3,
                  retry_delay=.1, use_tls=False, single_box=False, timeout=-1,
-                 connector_cls=StateConnector, use_pool=False):
+                 connector_cls=StateConnector, use_pool=False,
+                 max_lifetime=600):
         self._pool = []
         self.size = size
         self.retry_max = retry_max
         self.retry_delay = retry_delay
         self.uri = uri
         self.bind = bind
         self.passwd = passwd
         self._pool_lock = RLock()
         self.use_tls = False
         self.timeout = timeout
         self.connector_cls = connector_cls
         self.use_pool = use_pool
+        self.max_lifetime = max_lifetime
 
     def __len__(self):
         return len(self._pool)
 
     def _match(self, bind, passwd):
         passwd = passwd.encode('utf8')
         with self._pool_lock:
             inactives = []
 
             for conn in reversed(self._pool):
                 # already in usage
                 if conn.active:
                     continue
 
+                # let's check the lifetime
+                if conn.get_lifetime() > self.max_lifetime:
+                    # this connector has lived for too long,
+                    # we want to unbind it and remove it from the pool
+                    conn.unbind_s()
+                    self._pool.remove(conn)
+                    continue
+
                 # we found a connector for this bind
                 if conn.who == bind and conn.cred == passwd:
                     conn.active = True
                     return conn
 
                 inactives.append(conn)
 
             # no connector was available, let's rebind the latest inactive one
@@ -226,31 +245,36 @@ class ConnectionManager(object):
 
         # we need to create a new connector
         conn = self._create_connector(bind, passwd)
 
         # adding it to the pool
         if self.use_pool:
             with self._pool_lock:
                 self._pool.append(conn)
+        else:
+            # with no pool, the connector is always active
+            conn.active = True
 
         return conn
 
     def _release_connection(self, connection):
         if self.use_pool:
             with self._pool_lock:
                 if not connection.connected:
                     # unconnected connector, let's drop it
                     self._pool.remove(connection)
                 else:
                     # can be reused - let's mark is as not active
                     connection.active = False
 
                     # done.
                     return
+        else:
+            connection.active = False
 
         # let's try to unbind it
         try:
             connection.unbind_ext_s()
         except ldap.LDAPError:
             # avoid error on invalid state
             pass
 
--- a/services/tests/test_ldappool.py
+++ b/services/tests/test_ldappool.py
@@ -33,18 +33,18 @@
 # the terms of any one of the MPL, the GPL or the LGPL.
 #
 # ***** END LICENSE BLOCK *****
 import unittest
 import threading
 import time
 try:
     import ldap
-    from services.auth.ldappool import (ConnectionPool, StateConnector,
-                                       MaxConnectionReachedError)
+    from services.ldappool import ConnectionManager, StateConnector
+    from services.exceptions import MaxConnectionReachedError
     LDAP = True
 except ImportError:
     LDAP = False
 
 if LDAP:
     # patching StateConnector
     StateConnector.users = {'uid=tarek,ou=users,dc=mozilla':
                                         {'uidNumber': ['1'],
@@ -54,16 +54,17 @@ if LDAP:
                             'cn=admin,dc=mozilla': {'cn': ['admin'],
                                                     'mail': ['admin'],
                                                     'uidNumber': ['100']}}
 
     def _simple_bind(self, who='', cred='', *args):
         self.connected = True
         self.who = who
         self.cred = cred
+        self._connection_time = time.time()
 
     StateConnector.simple_bind_s = _simple_bind
 
     def _search(self, dn, *args, **kw):
         if dn in self.users:
             return [(dn, self.users[dn])]
         elif dn == 'ou=users,dc=mozilla':
             uid = kw['filterstr'].split('=')[-1][:-1]
@@ -123,35 +124,37 @@ class LDAPWorker(threading.Thread):
 
 class TestLDAPSQLAuth(unittest.TestCase):
 
     def test_pool(self):
         if not LDAP:
             return
         dn = 'uid=adminuser,ou=logins,dc=mozilla'
         passwd = 'adminuser'
-        pool = ConnectionPool('ldap://localhost', dn, passwd)
+        pool = ConnectionManager('ldap://localhost', dn, passwd,
+                                 use_pool=True)
+
         workers = [LDAPWorker(pool) for i in range(10)]
 
         for worker in workers:
             worker.start()
 
         for worker in workers:
             worker.join()
             self.assertEquals(len(worker.results), 10)
             cn = worker.results[0][0][1]['cn']
             self.assertEquals(cn, ['admin'])
 
     def test_pool_full(self):
         if not LDAP:
             return
         dn = 'uid=adminuser,ou=logins,dc=mozilla'
         passwd = 'adminuser'
-        pool = ConnectionPool('ldap://localhost', dn, passwd, size=1,
-                              retry_delay=1., retry_max=5)
+        pool = ConnectionManager('ldap://localhost', dn, passwd, size=1,
+                              retry_delay=1., retry_max=5, use_pool=True)
 
         class Worker(threading.Thread):
 
             def __init__(self, pool, duration):
                 threading.Thread.__init__(self)
                 self.pool = pool
                 self.duration = duration
 
@@ -194,32 +197,34 @@ class TestLDAPSQLAuth(unittest.TestCase)
         # we still have one active connector
         self.assertEqual(len(pool), 1)
 
     def test_pool_cleanup(self):
         if not LDAP:
             return
         dn = 'uid=adminuser,ou=logins,dc=mozilla'
         passwd = 'adminuser'
-        pool = ConnectionPool('ldap://localhost', dn, passwd, size=1)
+        pool = ConnectionManager('ldap://localhost', dn, passwd, size=1,
+                                 use_pool=True)
         with pool.connection('bind1') as conn:  # NOQA
             pass
 
         with pool.connection('bind2') as conn:  # NOQA
             pass
 
         # the second call should have removed the first conn
         self.assertEqual(len(pool), 1)
 
     def test_pool_reuse(self):
         if not LDAP:
             return
         dn = 'uid=adminuser,ou=logins,dc=mozilla'
         passwd = 'adminuser'
-        pool = ConnectionPool('ldap://localhost', dn, passwd)
+        pool = ConnectionManager('ldap://localhost', dn, passwd,
+                                 use_pool=True)
 
         with pool.connection() as conn:
             self.assertTrue(conn.active)
 
         self.assertFalse(conn.active)
         self.assertTrue(conn.connected)
 
         with pool.connection() as conn2:
@@ -242,19 +247,49 @@ class TestLDAPSQLAuth(unittest.TestCase)
         self.assertFalse(conn.active)
         self.assertTrue(conn.connected)
 
         with pool.connection('bind', 'passwd') as conn2:
             pass
 
         self.assertTrue(conn is conn2)
 
-        # same bind different password: discard
+        # same bind different password, inactive: rebind
         with pool.connection('bind', 'passwd') as conn:
             self.assertTrue(conn.active)
 
         self.assertFalse(conn.active)
         self.assertTrue(conn.connected)
 
         with pool.connection('bind', 'passwd2') as conn2:
             pass
 
-        self.assertTrue(conn is not conn2)
+        self.assertTrue(conn is conn2)
+
+    def test_max_lifetime(self):
+        if not LDAP:
+            return
+
+        dn = 'uid=adminuser,ou=logins,dc=mozilla'
+        passwd = 'adminuser'
+        pool = ConnectionManager('ldap://localhost', dn, passwd,
+                                 max_lifetime=0.5, use_pool=True)
+
+        with pool.connection('bind', 'password') as conn:
+            self.assertTrue(conn.active)
+
+        self.assertFalse(conn.active)
+        self.assertTrue(conn.connected)
+
+        # same bind and password: reuse
+        with pool.connection('bind', 'passwd') as conn2:
+            self.assertTrue(conn2.active)
+
+        self.assertTrue(conn is conn2)
+
+        time.sleep(0.6)
+
+        # same bind and password, but max lifetime reached: new one
+        with pool.connection('bind', 'passwd') as conn3:
+            self.assertTrue(conn3.active)
+
+        self.assertTrue(conn3 is not conn2)
+        self.assertTrue(conn3 is not conn)