diff --git a/objpool/__init__.py b/objpool/__init__.py index d1c18e3..1316dbe 100644 --- a/objpool/__init__.py +++ b/objpool/__init__.py @@ -57,6 +57,16 @@ log = logging.getLogger(__name__) +class DummySemaphore(object): + _Semaphore__value = 0 + + def acquire(self, *args, **kwargs): + return True + + def release(self): + pass + + class ObjectPoolError(Exception): pass @@ -114,16 +124,28 @@ class ObjectPool(object): relies upon will be copied when the new process is being created. """ - def __init__(self, size=None): + def __init__(self, size=None, create=None, verify=None, cleanup=None): self._pool_pid = getpid() + if size is None: + size = 0 + try: - self.size = int(size) - assert size >= 1 + int(size) + assert size >= 0 except: - raise ValueError("Invalid size for pool (positive integer " + raise ValueError("Invalid size for pool (non-negative integer " "required): %r" % (size,)) - self._semaphore = Semaphore(size) # Pool grows up to size limit + self._create_func = create + self._verify_func = verify + self._cleanup_func = cleanup + + self.size = size + if self.size > 0: + self._semaphore = Semaphore(self.size) + else: + self._semaphore = DummySemaphore() + self._mutex = Lock() # Protect shared _set oject self._set = set() log.debug("Initialized pool %r", self) @@ -237,7 +259,9 @@ def _pool_create(self): Must be thread-safe. """ - raise NotImplementedError + if self._create_func is None: + raise NotImplementedError + return self._create_func() def _pool_verify(self, obj): """Verify an object after getting it from the pool. @@ -249,7 +273,9 @@ def _pool_verify(self, obj): Must be thread-safe. """ - raise NotImplementedError + if self._verify_func is None: + return True + return self._verify_func(obj) def _pool_cleanup(self, obj): """Cleanup an object before being put back into the pool. @@ -259,7 +285,8 @@ def _pool_cleanup(self, obj): Must be thread-safe. """ - raise NotImplementedError + if self._cleanup_func is not None: + return self._cleanup_func(obj) class PooledObject(object): diff --git a/objpool/test/tests.py b/objpool/test/tests.py index 86fe745..6893c63 100644 --- a/objpool/test/tests.py +++ b/objpool/test/tests.py @@ -107,12 +107,20 @@ def _pool_cleanup(self, obj): class ObjectPoolTestCase(unittest.TestCase): - def test_create_pool_requires_size(self): + def test_create_pool_invalid_sizes(self): """Test __init__() requires valid size argument""" - self.assertRaises(ValueError, ObjectPool) - self.assertRaises(ValueError, ObjectPool, size="size10") - self.assertRaises(ValueError, ObjectPool, size=0) self.assertRaises(ValueError, ObjectPool, size=-1) + self.assertRaises(ValueError, ObjectPool, size="size10") + + def test_create_pool_valid_sizes(self): + ObjectPool(size=0) + ObjectPool(size=None) + + def test_unbounded_pool(self): + pool = ObjectPool(size=0, create=[1,2,3].pop) + self.assertEqual(pool.pool_get(), 3) + self.assertEqual(pool.pool_get(), 2) + self.assertEqual(pool.pool_get(), 1) def test_create_pool(self): """Test pool creation works""" @@ -123,12 +131,27 @@ def test_get_not_implemented(self): """Test pool_get() method not implemented in abstract class""" pool = ObjectPool(100) self.assertRaises(NotImplementedError, pool._pool_create) - self.assertRaises(NotImplementedError, pool._pool_verify, None) - def test_put_not_implemented(self): - """Test pool_put() method not implemented in abstract class""" - pool = ObjectPool(100) - self.assertRaises(NotImplementedError, pool._pool_cleanup, None) + def test_get_with_factory(self): + obj_generator = iter(range(10)).next + pool = ObjectPool(3, create=obj_generator) + self.assertEqual(pool.pool_get(), 0) + self.assertEqual(pool.pool_get(), 1) + self.assertEqual(pool.pool_get(), 2) + + def test_put_with_factory(self): + cleaned_objects = [] + pool = ObjectPool(3, + create=[2, 1, 0].pop, + verify=lambda o: o % 2 == 0, + cleanup=cleaned_objects.append, + ) + self.assertEqual(pool.pool_get(), 0) + pool.pool_put(0) + self.assertEqual(pool.pool_get(), 0) + self.assertRaises(PoolVerificationError, pool.pool_get) + self.assertEqual(pool.pool_get(), 2) + self.assertEqual(cleaned_objects, [0]) class NumbersPoolTestCase(unittest.TestCase):