Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 35 additions & 8 deletions objpool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@
log = logging.getLogger(__name__)


class DummySemaphore(object):
_Semaphore__value = 0

def acquire(self, *args, **kwargs):
return True

def release(self):
pass


class ObjectPoolError(Exception):
pass

Expand Down Expand Up @@ -114,16 +124,28 @@ class ObjectPool(object):
relies upon will be copied when the new process is being created.

"""
def __init__(self, size=None):
def __init__(self, size=None, create=None, verify=None, cleanup=None):
self._pool_pid = getpid()
if size is None:
size = 0

try:
self.size = int(size)
assert size >= 1
int(size)
assert size >= 0
except:
raise ValueError("Invalid size for pool (positive integer "
raise ValueError("Invalid size for pool (non-negative integer "
"required): %r" % (size,))

self._semaphore = Semaphore(size) # Pool grows up to size limit
self._create_func = create
self._verify_func = verify
self._cleanup_func = cleanup

self.size = size
if self.size > 0:
self._semaphore = Semaphore(self.size)
else:
self._semaphore = DummySemaphore()

self._mutex = Lock() # Protect shared _set oject
self._set = set()
log.debug("Initialized pool %r", self)
Expand Down Expand Up @@ -237,7 +259,9 @@ def _pool_create(self):
Must be thread-safe.

"""
raise NotImplementedError
if self._create_func is None:
raise NotImplementedError
return self._create_func()

def _pool_verify(self, obj):
"""Verify an object after getting it from the pool.
Expand All @@ -249,7 +273,9 @@ def _pool_verify(self, obj):
Must be thread-safe.

"""
raise NotImplementedError
if self._verify_func is None:
return True
return self._verify_func(obj)

def _pool_cleanup(self, obj):
"""Cleanup an object before being put back into the pool.
Expand All @@ -259,7 +285,8 @@ def _pool_cleanup(self, obj):
Must be thread-safe.

"""
raise NotImplementedError
if self._cleanup_func is not None:
return self._cleanup_func(obj)


class PooledObject(object):
Expand Down
41 changes: 32 additions & 9 deletions objpool/test/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,20 @@ def _pool_cleanup(self, obj):


class ObjectPoolTestCase(unittest.TestCase):
def test_create_pool_requires_size(self):
def test_create_pool_invalid_sizes(self):
"""Test __init__() requires valid size argument"""
self.assertRaises(ValueError, ObjectPool)
self.assertRaises(ValueError, ObjectPool, size="size10")
self.assertRaises(ValueError, ObjectPool, size=0)
self.assertRaises(ValueError, ObjectPool, size=-1)
self.assertRaises(ValueError, ObjectPool, size="size10")

def test_create_pool_valid_sizes(self):
ObjectPool(size=0)
ObjectPool(size=None)

def test_unbounded_pool(self):
pool = ObjectPool(size=0, create=[1,2,3].pop)
self.assertEqual(pool.pool_get(), 3)
self.assertEqual(pool.pool_get(), 2)
self.assertEqual(pool.pool_get(), 1)

def test_create_pool(self):
"""Test pool creation works"""
Expand All @@ -123,12 +131,27 @@ def test_get_not_implemented(self):
"""Test pool_get() method not implemented in abstract class"""
pool = ObjectPool(100)
self.assertRaises(NotImplementedError, pool._pool_create)
self.assertRaises(NotImplementedError, pool._pool_verify, None)

def test_put_not_implemented(self):
"""Test pool_put() method not implemented in abstract class"""
pool = ObjectPool(100)
self.assertRaises(NotImplementedError, pool._pool_cleanup, None)
def test_get_with_factory(self):
obj_generator = iter(range(10)).next
pool = ObjectPool(3, create=obj_generator)
self.assertEqual(pool.pool_get(), 0)
self.assertEqual(pool.pool_get(), 1)
self.assertEqual(pool.pool_get(), 2)

def test_put_with_factory(self):
cleaned_objects = []
pool = ObjectPool(3,
create=[2, 1, 0].pop,
verify=lambda o: o % 2 == 0,
cleanup=cleaned_objects.append,
)
self.assertEqual(pool.pool_get(), 0)
pool.pool_put(0)
self.assertEqual(pool.pool_get(), 0)
self.assertRaises(PoolVerificationError, pool.pool_get)
self.assertEqual(pool.pool_get(), 2)
self.assertEqual(cleaned_objects, [0])


class NumbersPoolTestCase(unittest.TestCase):
Expand Down