class Parent(object):
    """Base class whose cached ``test`` property evaluates to 1."""

    @cached_property
    def test(self):
        return 1
class Child(Parent):
    """Subclass that overrides the cached ``test`` property and always fails.

    Reading ``test`` first evaluates the parent's cached ``test`` through
    ``super()`` and then divides that value by zero, so it raises
    ZeroDivisionError after the parent's property has already run.
    """

    @cached_property
    def test(self):
        parent_value = super(Child, self).test
        return parent_value / 0
child_instance = Child()

# First access: Child.test evaluates Parent.test through super(); the
# unpatched cached_property stores that result (1) in
# child_instance.__dict__["test"] before Child.test raises ZeroDivisionError.
# The error is caught so that, when this file is run top to bottom, execution
# reaches the second access below instead of stopping here.
try:
    print(child_instance.test)
except ZeroDivisionError as exc:
    print("first access raised %r" % (exc,))

# Second access: without the fix this prints "1", because the parent's value
# left in the instance __dict__ shadows Child.test. With the fix it raises
# ZeroDivisionError again.
print(child_instance.test)
diff --git a/cached_property.py b/cached_property.py
index 67fa01a..6122eab 100644
--- a/cached_property.py
+++ b/cached_property.py
@@ -32,7 +32,15 @@ class cached_property(object):
if asyncio and asyncio.iscoroutinefunction(self.func):
return self._wrap_in_coroutine(obj)
- value = obj.__dict__[self.func.__name__] = self.func(obj)
+ try:
+ value = self.func(obj)
+ except Exception:
+ # A same-named parent property may have cached its value via super()
+ # while self.func ran; drop it so it cannot mask this failure next time.
+ obj.__dict__.pop(self.func.__name__, None)
+ raise
+
+ obj.__dict__[self.func.__name__] = value
return value
def _wrap_in_coroutine(self, obj):
@@ -69,7 +77,16 @@ class threaded_cached_property(object):
except KeyError:
# if not, do the calculation and release the lock
- return obj_dict.setdefault(name, self.func(obj))
+ try:
+ value = self.func(obj)
+ except Exception:
+ # Drop any value a same-named parent property cached via super().
+ obj_dict.pop(name, None)
+ raise
+ # Assign instead of setdefault(): a same-named parent property may already
+ # have stored its own value here via super(), and this result must win.
+ obj_dict[name] = value
+ return value
class cached_property_with_ttl(object):
@@ -108,7 +125,13 @@ class cached_property_with_ttl(object):
if not ttl_expired:
return value
- value = self.func(obj)
+ try:
+ value = self.func(obj)
+ except Exception:
+ # Drop any stale entry: an expired one, or one a same-named parent
+ # property cached via super() while self.func ran.
+ obj_dict.pop(name, None)
+ raise
obj_dict[name] = (value, now)
return value
diff --git a/tests/test_cached_property.py b/tests/test_cached_property.py
index 5082416..e3499b2 100644
--- a/tests/test_cached_property.py
+++ b/tests/test_cached_property.py
@@ -88,6 +88,42 @@ class TestCachedProperty(unittest.TestCase):
# rather than through an instance.
self.assertTrue(isinstance(Check.add_cached, self.cached_property_factory))
+ def test_error_on_inheritance(self):
+ class CheckWithInheritance(CheckFactory(self.cached_property_factory)):
+ @self.cached_property_factory
+ def add_cached(inst):
+ return super(CheckWithInheritance, inst).add_cached / 0
+
+ check = CheckWithInheritance()
+ stub_to_check = object()
+
+ result = stub_to_check
+ with self.assertRaises(ZeroDivisionError):
+ result = check.add_cached
+ with self.assertRaises(ZeroDivisionError):
+ result = check.add_cached
+
+ self.assertEqual(result, stub_to_check)
+
+ def test_inheritance(self):
+ factory = self.cached_property_factory
+
+ class Parent(object):
+ @factory
+ def value(inst):
+ return 1
+
+ class Child(Parent):
+ @factory
+ def value(inst):
+ return super(Child, inst).value + 10
+
+ child = Child()
+ # The parent's property caches 1 under the same name during super();
+ # the child's own result must replace it, not be shadowed by it.
+ self.assertEqual(child.value, 11)
+ self.assertEqual(child.value, 11)
+
def test_reset_cached_property(self):
Check = CheckFactory(self.cached_property_factory)
check = Check()