You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
95 lines
2.9 KiB
95 lines
2.9 KiB
import os
|
|
import pprint
|
|
import sys
|
|
import unittest
|
|
|
|
from .messages import text_type
|
|
|
|
_PY2K = sys.version_info < (3,)

# _STR_F: callable used to turn formatted values into text.
# _PRIMITIVES: types whose values are shown as-is instead of being pretty-printed.
if not _PY2K:
    _STR_F = str
    _PRIMITIVES = [int, str, bool]
else:
    # ``unicode`` exists only on Python 2, where it is the text type.
    # noinspection PyUnresolvedReferences
    _STR_F = unicode  # noqa
    # On Python 2 text may be either ``str`` (bytes) or ``unicode``; both count as primitive.
    _PRIMITIVES = [int, str, bool, _STR_F]
|
|
|
|
|
|
def _store_equals_failure(first, second, msg, real_exception):
    """
    Best-effort: hand a failed ``assertEqual`` over to the IDE exception store.

    Builds an ``EqualsAssertionError`` from the compared values. If it can be
    serialized, it is passed to ``jb_local_exc_store.store_exception``.

    Any exception raised while doing so is deliberately swallowed. Formatting
    arbitrary user objects can fail, for example a ``__repr__`` that raises, or
    converting non-ASCII byte strings to ``unicode`` on Python 2. Such a failure
    must never replace the user's real assertion failure.

    This is kept as a separate function rather than an inline ``try`` in the
    caller. On Python 2 a bare ``raise`` re-raises the last exception handled in
    the *current* frame. Handling a reporting error here therefore leaves the
    caller's original ``AssertionError`` as the one its bare ``raise`` re-raises.
    """
    try:
        error = EqualsAssertionError(first, second, msg, real_exception=real_exception)
        if error.can_be_serialized():
            from .jb_local_exc_store import store_exception
            store_exception(error)
    except Exception:
        # Diagnostics only: the caller re-raises the original AssertionError.
        pass


def patch_unittest_diff(test_filter=None):
    """
    Patch ``unittest.TestCase.assertEqual`` so equality failures are reported for diffing.

    The patched method first delegates to the original ``assertEqual``. If that
    raises ``AssertionError`` and ``test_filter`` allows it, the compared values
    are wrapped in an ``EqualsAssertionError``. When that error can be
    serialized, it is stored via ``jb_local_exc_store.store_exception``.

    The original ``AssertionError`` is always re-raised unchanged. Problems while
    building or storing the report never mask it.

    Does nothing on Python versions older than 2.7.

    :param test_filter: optional callable that receives the ``TestCase``
        instance. When given, a failure is reported only if it returns a truthy
        value. Failures are re-raised whatever the filter returns.
    """
    import functools

    if sys.version_info < (2, 7):
        return

    old = unittest.TestCase.assertEqual

    # wraps() keeps the name/docstring of the original assertEqual for introspection.
    @functools.wraps(old)
    def _patched_equals(self, first, second, msg=None):
        try:
            old(self, first, second, msg)
        except AssertionError as e:
            if not test_filter or test_filter(self):
                _store_equals_failure(first, second, msg, e)
            raise

    unittest.TestCase.assertEqual = _patched_equals
|
|
|
|
|
|
def _format_and_convert(val):
    """
    Return the display form of ``val`` used in an equality diff.

    Values whose type is one of ``_PRIMITIVES`` are returned unchanged. Anything
    else is rendered with ``pprint.pformat``.

    If the ``_JB_PPRINT_PRIMITIVES`` environment variable is present (with any
    value), ``pprint.pformat`` is used for primitives too.
    """
    if "_JB_PPRINT_PRIMITIVES" not in os.environ and isinstance(val, tuple(_PRIMITIVES)):
        # No need to pretty-print primitives
        return val
    return pprint.pformat(val)
|
|
|
|
|
|
class EqualsAssertionError(AssertionError):
    """
    ``AssertionError`` carrying the display text of both sides of a failed equality check.

    ``str()`` renders it as a single line ``msg :: expected != actual``.
    ``deserialize_error`` parses that form back into an instance.
    """

    # Separates the message from the "expected != actual" part.
    MESSAGE_SEP = " :: "
    # Separates the expected value from the actual value.
    NOT_EQ_SEP = " != "

    # Real exception could be provided, but not serialized
    def __init__(self, expected, actual, msg=None, preformated=False, real_exception=None):
        """
        :param expected: first value of the failed comparison.
        :param actual: second value of the failed comparison.
        :param msg: optional message. A missing or empty message is stored as
            ``""`` rather than the text ``"None"``.
        :param preformated: when true, ``expected``/``actual`` are already display
            strings (as produced by ``deserialize_error``). They are then not
            passed through ``_format_and_convert``.
        :param real_exception: original exception; kept on the instance but not
            part of the serialized text.
        """
        self.real_exception = real_exception
        # Normalised the same way for both paths, so a preformatted error built
        # with msg=None does not serialize as "None :: ...".
        self.msg = text_type(msg) if msg else ""

        if preformated:
            self.expected = expected
            self.actual = actual
        else:
            self.expected = _format_and_convert(expected)
            self.actual = _format_and_convert(actual)

        # Always store text so serialization can concatenate the parts.
        self.expected = _STR_F(self.expected)
        self.actual = _STR_F(self.actual)

    def can_be_serialized(self):
        """
        Return whether this error may be sent in its serialized text form.

        Returns False if any part (expected, actual, msg) contains one of the
        separators. Otherwise returns True only when the combined length of the
        expected and actual text is below 10000 characters.
        """
        parts = (self.expected, self.actual, self.msg)
        if any(self.MESSAGE_SEP in part or self.NOT_EQ_SEP in part for part in parts):
            return False
        return len(self.actual) + len(self.expected) < 10000

    def __str__(self):
        return self._serialize()

    def __unicode__(self):
        return self._serialize()

    def _serialize(self):
        """Return ``msg :: expected != actual`` as a single string."""
        return self.msg + self.MESSAGE_SEP + self.expected + self.NOT_EQ_SEP + self.actual

    @classmethod
    def deserialize_error(cls, serialized_message):
        """
        Rebuild an error from the ``msg :: expected != actual`` text produced by ``_serialize``.

        :raises ValueError: if the text does not split into exactly a message,
            an expected part and an actual part.
        """
        message, diff = serialized_message.split(cls.MESSAGE_SEP)
        exp, act = diff.split(cls.NOT_EQ_SEP)
        # cls (not a hard-coded name) so subclasses deserialize to themselves.
        return cls(exp, act, message, preformated=True)
|
|
|
|
|
|
def deserialize_error(serialized_message):
    """
    Parse ``msg :: expected != actual`` text back into an ``EqualsAssertionError``.

    Module-level shortcut for ``EqualsAssertionError.deserialize_error``.
    """
    parser = EqualsAssertionError.deserialize_error
    return parser(serialized_message)
|