# -*- coding: utf-8 -*-
from ..utils.common import lls, ld2tk, bstr
from .serializable import Serializable
import logging
from collections.abc import Mapping, Sequence, Set
from itertools import chain
from collections import OrderedDict
from functools import lru_cache
import array
import decimal
import fractions
import pprint
import sys
import hashlib
# Width of a native hash value in bytes; used to truncate digests so they
# fit the range of values returned by the built-in ``hash()``.
HASH_WIDTH = sys.hash_info.width // 8

# True when running on Python 3.6 or newer (plain ``dict`` keeps insertion
# order from there on).  A tuple comparison is used instead of float
# arithmetic on the version numbers, which is fragile for two-digit minors.
PY36 = sys.version_info >= (3, 6)

# create logger
logger = logging.getLogger(__name__)
# logger.debug('level %d' % (logger.getEffectiveLevel()))
class CircularCallError(RuntimeError):
    """Raised by `deepcmp` when a pair of objects that is already being
    compared is met again, i.e. a cyclic reference was found.
    """
# Result of the most recent top-level `deepcmp` call.
DEEPCMP_RESULT = None


def deepcmp(obj1, obj2, seenlist=None, verbose=False, eqcmp=False):
    """ Recursively descends into obj1's every component and
    compares with its counterpart in obj2.

    Factors includes;

    * if they are the same object
    * type
    * quick string
    * ```__eq__``` or ```__cmp__``` if requested
    * state from ```__getstate__```
    * quick length
    * members if is ```Mapping```, ```Sequence``` (
      set, list, dict, ordereddict, UserDict ... )
    * properties/attributes in ```__dict__```

    Detects cyclic references.

    Parameters
    ----------
    obj1, obj2 :
        The two objects to compare.
    seenlist : list, optional
        List of id-pairs currently being compared. It is used in place
        (pushed/popped) so a caller can share it; a fresh list is made
        when ``None``.
    verbose : bool
        Print a trace of the comparison.
    eqcmp : bool
        If True, use ``__eq__`` or ``__cmp__`` if the objs have them.
        If False only use them as the last resort. Default False.

    Returns
    -------
    ``None`` if finds no difference, a string of explanation
    otherwise. The result is also stored in the module-level
    ``DEEPCMP_RESULT``.

    Raises
    ------
    CircularCallError
        If a pair of objects under comparison is met again.
    """
    global DEEPCMP_RESULT

    seen = [] if seenlist is None else seenlist
    level = 0
    # Since Python 3.11 ``object`` itself provides ``__getstate__`` (which
    # returns None for built-ins such as int or list).  Only a
    # ``__getstate__`` that is not that default says anything about state.
    default_getstate = getattr(object, '__getstate__', None)

    def has_own_getstate(o):
        """ True if `o` has a ``__getstate__`` other than ``object``'s. """
        if not hasattr(o, '__getstate__'):
            return False
        if default_getstate is None:
            # before 3.11: any ``__getstate__`` is user-provided
            return True
        own = getattr(o.__class__, '__getstate__', default_getstate)
        return own is not default_getstate

    def run(o1, o2, v=False, eqcmp=True):
        """ Compare one pair, keeping `seen` and `level` balanced.

        Returns ``None`` if no difference is found, else an explanation.
        """
        nonlocal level

        id1, id2 = id(o1), id(o2)
        if id1 == id2:
            if v:
                print('These are the same object o1=%s ||| o2=%s.' %
                      (bstr(o1, 20), bstr(o2, 20)))
            return None
        pair = (id1, id2) if id1 < id2 else (id2, id1)
        level += 1
        try:
            if v:
                print('deepcmp level %d seenlist length %d' %
                      (level, len(seen)))
                print('1 ' + str(o1.__class__) + lls(o1, 45))
                print('2 ' + str(o2.__class__) + lls(o2, 45))
            if pair in seen:
                msg = 'deja vue %s' % str(pair)
                raise CircularCallError(msg)
            seen.append(pair)
            try:
                return compare(o1, o2, v, eqcmp)
            finally:
                # our pair is the last one: nested calls are balanced
                del seen[-1]
        finally:
            level -= 1

    def compare(o1, o2, v, eqcmp):
        """ The comparison proper; recursion goes through `run`. """
        c = o1.__class__
        c2 = o2.__class__

        if c != c2:
            if v:
                print('type diff')
            return ' due to diff types: ' + c.__name__ + ' and ' + c2.__name__

        if c == str:
            # strings first so they are not treated as sequences
            if v:
                print('find strings')
            if o1 != o2:
                return ' due to difference: "%s" ||| "%s"' % (o1, o2)
            return None

        has_eqcmp = (hasattr(o1, '__eq__') or hasattr(
            o1, '__cmp__')) and not issubclass(c, DeepEqual)
        if eqcmp and has_eqcmp:
            if v:
                print('obj1 has __eq__ or __cmp__ and not using deepcmp')
            # checked in-seen to ensure whst follows will not cause RecursionError
            try:
                t = o1 == o2
            except CircularCallError as e:
                # fall through to the structural comparison below
                if v:
                    print('Get circular call using eq/cmp: '+str(e))
            else:
                if t:
                    return None
                return ' due to "%s" != "%s"' % (lls(o1, 155), lls(o2, 155))

        if has_own_getstate(o1):
            if v:
                print('Find __getstate__')
            try:
                state1 = o1.__getstate__()
                state2 = o2.__getstate__()
            except TypeError:
                logger.error('__getstate__ trouble')
                raise
            r = run(state1, state2, v=v, eqcmp=eqcmp)
            if r:
                return ' due to o1.__getstate__ != o2.__getstate__' + r
            return None

        try:
            # this is not good if len() is delegated
            if hasattr(o1, '__len__') and len(o1) != len(o2):
                return ' due to diff %s lengths: %d and %d (%s, %s)' %\
                    (c.__name__, len(o1), len(o2), lls(
                        list(o1), 115), lls(list(o2), 115))
        except AttributeError:
            pass

        if issubclass(c, Mapping):
            if v:
                print('Find Mapping')
                print('check keys')
            from .odict import ODict
            if issubclass(c, (OrderedDict, ODict)) or PY36:
                # ordered mappings: key order matters
                r = run(list(o1.keys()), list(o2.keys()), v=v, eqcmp=eqcmp)
            else:
                # old dict or UserDict: compare the key sets of BOTH objects
                r = run(tuple(sorted(o1.keys(), key=hash)),
                        tuple(sorted(o2.keys(), key=hash)),
                        v=v, eqcmp=eqcmp)
            if r is not None:
                return " due to diff " + c.__name__ + " keys" + r
            if v:
                print('check values')
            for k in o1.keys():
                if k not in o2:
                    return ' due to o2 has no key=%s' % (lls(k, 155))
                r = run(o1[k], o2[k], v=v, eqcmp=eqcmp)
                if r is not None:
                    s = ' due to diff values for key=%s' % (lls(k, 155))
                    return s + r
            return None

        if issubclass(c, (Set, Sequence)):
            if v:
                print('Find Set, Sequence.')
            if issubclass(c, Sequence):
                if v:
                    print('Check Sequence.')
                # lengths are already known to be equal
                for i in range(len(o1)):
                    r = run(o1[i], o2[i], v=v, eqcmp=eqcmp)
                    if r is not None:
                        return ' due to diff at index=%d (%s %s)' % \
                            (i,
                             lls(o1[i], 10),
                             lls(o2[i], 10)) + r
                return None
            if v:
                print('Check Set.')
            # Same length, so a one-way difference is enough.  The ``-``
            # operator is part of the Set ABC; ``.difference`` is not.
            if o1 - o2:
                return ' due to at least one in the foremer not in the latter'
            return None

        if hasattr(o1, '__dict__'):
            if v:
                print('obj1 has __dict__')
            attrs1 = sorted(vars(o1).items())
            attrs2 = sorted(vars(o2).items())
            r = run(attrs1, attrs2, v=v, eqcmp=eqcmp)
            if r:
                return ' due to o1.__dict__ != o2.__dict__' + r
            return None

        if has_eqcmp:
            # last resort
            if o1 == o2:
                return None
            return ' according to __eq__ or __cmp__'

        if v:
            print('no way')
        return ' due to no reason found for "%s" == "%s"' % (
            lls(o1, 155), lls(o2, 155))

    res = run(obj1, obj2, v=verbose, eqcmp=eqcmp)
    DEEPCMP_RESULT = res
    return res
# Default for the ``verbose`` parameter of `xhash`.
XHASH_VERBOSE = False


def xhash(hash_list=None, seenlist=None, verbose=None):
    """ get the hash of a tuple of hashes of all members of given sequence.

    Parameters
    ----------
    hash_list :
        The object to hash. ``int`` is returned as is; floats, decimals,
        fractions, ``str``, ``bytes`` and anything unrecognised go through
        ``hash()``; ``array.array`` is digested with sha256; objects with
        their own ``__getstate__``, sets, sequences and mappings are
        hashed member by member.
    seenlist : list, optional
        List of ids of the containers currently being hashed. Used in
        place; a fresh list is made when ``None``.
    verbose : bool, optional
        Set to trace. ``None`` means use ``XHASH_VERBOSE``.

    Returns
    -------
    int
        The hash. A container met again while it is being hashed
        (cyclic reference) contributes ``0``.
    """
    if verbose is None:
        verbose = XHASH_VERBOSE

    seen = [] if seenlist is None else seenlist
    level = 0
    # Since Python 3.11 ``object`` itself provides ``__getstate__`` (it
    # returns None for built-in containers).  Only a ``__getstate__`` that
    # is not that default is treated as a state provider.
    default_getstate = getattr(object, '__getstate__', None)

    def has_own_getstate(obj):
        """ True if `obj` has a ``__getstate__`` other than ``object``'s. """
        if not hasattr(obj, '__getstate__'):
            return False
        if default_getstate is None:
            # before 3.11: any ``__getstate__`` is user-provided
            return True
        own = getattr(obj.__class__, '__getstate__', default_getstate)
        return own is not default_getstate

    def run(hash_list=None):
        """ Hash one object, keeping `seen` and `level` balanced. """
        nonlocal level

        level += 1
        ind = ' ' * level
        try:
            hlid = id(hash_list)
            if hlid in seen:
                # Cyclic reference.  Nothing was pushed for this call, so
                # nothing must be popped here.
                if verbose:
                    print(ind + 'seen it')
                return 0
            seen.append(hlid)
            try:
                return digest(hash_list, ind)
            finally:
                del seen[-1]
        finally:
            level -= 1

    def digest(obj, ind):
        """ Compute the hash of `obj`; members are hashed through `run`. """
        cls = obj.__class__

        if issubclass(cls, int):
            res = obj
            if verbose:
                print(ind + 'int "%s" -- %s' % (lls(obj, 20), res))
            return res

        if issubclass(cls, (float, decimal.Decimal, fractions.Fraction)):
            res = hash(obj)
            if verbose:
                print(ind + '%s "%s" -- %s' %
                      (cls.__name__, lls(obj, 20), res))
            return res

        if issubclass(cls, (str, bytes)):
            # put str first so it is not treated as a sequence
            res = hash(obj)
            if verbose:
                print(ind + 'str/bytes "%s" -- %s' %
                      (lls(obj, 20), res))
            return res

        if issubclass(cls, array.array):
            # digest type code and raw buffer, truncated to hash width
            hasher = hashlib.new('sha256', obj.typecode.encode('utf-8'))
            hasher.update(obj)
            res = int.from_bytes(
                hasher.digest()[:HASH_WIDTH], byteorder=sys.byteorder)
            if verbose:
                print(ind + '%s %s %s' % (cls.__name__,
                                          lls(obj, 20), res))
            return res

        if has_own_getstate(obj):
            try:
                state = obj.__getstate__()
            except TypeError:
                logger.error('__getstate__ trouble')
                raise
            # keys and values of the state, interleaved
            source = chain.from_iterable(state.items())
            if verbose:
                print(ind + '%s %s %s' % (cls.__name__,
                                          lls(source, 20), 'has __getstate__'))
        elif issubclass(cls, (Set, Sequence)):
            source = obj
            if verbose:
                print(ind + '%s %s %s' % (cls.__name__,
                                          lls(source, 20), 'is Sequence'))
        elif issubclass(cls, Mapping):
            source = chain.from_iterable(obj.items())
            if verbose:
                print(ind + '%s %s %s' % (cls.__name__,
                                          lls(source, 20), 'is Mapping'))
        else:
            res = hash(obj)
            if verbose:
                print(ind + '%s %s -- %s' % (cls.__name__,
                                             lls(obj, 20), res))
            return res

        hashes = []
        for t in source:
            # members that know how to hash themselves are asked to
            h = t.hash() if hasattr(t, 'hash') else run(t)
            if verbose:
                print(ind + '> %s %s -- %s' % (h.__class__.__name__,
                                               lls(h, 20), h))
            hashes.append(h)
        # if there is only one element only hash the element
        res = hash(hashes[0] if len(hashes) == 1 else tuple(hashes))
        if verbose:
            print(ind + '%s %s -- %s' % ('RET', str(len(hashes)), res))
        return res

    return run(hash_list=hash_list)
class DeepcmpEqual(object):
    """ mh: Can compare key-val pairs of another object
    with self. False if compare with None
    or exceptions raised, e.g. obj does not have items()

    NOTE: a later class definition in this module uses the same name
    and therefore rebinds ``DeepcmpEqual`` at import time.
    """

    def __init__(self, **kwds):
        super().__init__(**kwds)

    def equals(self, obj, verbose=False):
        """ Tell if `obj` is found equal to self by `diff`.

        Parameters
        ----------
        obj :
            The object to compare with.
        verbose : bool
            Passed on to `diff`.

        Returns
        -------
        bool
            True if `diff` reports no difference.
        """
        r = self.diff(obj, [], verbose=verbose)
        # logging.debug(r)
        return r is None

    def __eq__(self, obj):
        """ Same as ``self.equals(obj)``. """
        return self.equals(obj)

    def __ne__(self, obj):
        """ Negation of ``__eq__``. """
        return not self.__eq__(obj)

    def diff(self, obj, seenlist=None, verbose=False):
        """ recursively compare components of list and dict.
        until meeting equality.

        Parameters
        ----------
        obj :
            The object to compare with.
        seenlist : list, optional
            Handed to `deepcmp` as its ``seenlist``.
        verbose : bool
            Handed to `deepcmp`.

        Returns
        -------
        ``None`` if no difference is found, a string of explanation
        otherwise.
        """
        if not issubclass(self.__class__, Serializable):
            return deepcmp(self, obj, seenlist=seenlist, verbose=verbose)
        if not issubclass(obj.__class__, Serializable):
            return 'different classes'
        # both are Serializable: compare their states
        return deepcmp(self.__getstate__(), obj.__getstate__(),
                       seenlist=seenlist, verbose=verbose)
class EqualDict(object):
    """ mh: Can compare key-val pairs of another object
    with self. False if compare with None
    or exceptions raised, e.g. obj does not have items()
    """

    def __init__(self, **kwds):
        super().__init__(**kwds)

    def equals(self, obj, verbose=False):
        """ Compare the ``__dict__`` of self with that of `obj`.

        Parameters
        ----------
        obj :
            The object to compare with.
        verbose : bool
            Print both ``__dict__`` when they differ.

        Returns
        -------
        bool
            True if the two ``__dict__`` are equal. False if they are
            not, if `obj` is ``None``, or if anything raises during the
            comparison.
        """
        if obj is None:
            return False
        try:
            if self.__dict__ != obj.__dict__:
                if verbose:
                    print('@@ diff ' + lls(self.__dict__) +
                          '\n>>diff \n' + lls(obj.__dict__))
                return False
        except Exception:
            # best effort by design: anything that cannot be compared
            # (e.g. obj has no __dict__) counts as not equal.
            return False
        return True

    def __eq__(self, obj):
        """ Same as ``self.equals(obj)``. """
        return self.equals(obj)

    def __ne__(self, obj):
        """ Negation of ``__eq__``. """
        return not self.__eq__(obj)
class EqualODict(object):
    """ mh: Can compare order and key-val pairs of another object
    with self. False if compare with None
    or exceptions raised, e.g. obj does not have items()
    """

    def __init__(self, **kwds):
        super().__init__(**kwds)

    def equals(self, obj, verbose=False):
        """ Compare the ``items()`` of self and `obj`, order included.

        Parameters
        ----------
        obj :
            The object to compare with.
        verbose : bool
            Not used.

        Returns
        -------
        bool
            True if both give the same list of items. False otherwise,
            also if `obj` is ``None`` or anything raises during the
            comparison.
        """
        if obj is None:
            return False
        try:
            return list(self.items()) == list(obj.items())
        except Exception:
            # best effort by design, e.g. obj does not have items()
            return False

    def __eq__(self, obj):
        """ Same as ``self.equals(obj)``. """
        return self.equals(obj)

    def __ne__(self, obj):
        """ Negation of ``__eq__``. """
        return not self.__eq__(obj)
class StateEqual():
    """ Equality tested by hashed state.
    """

    def __init__(self, *args, **kwds):
        """ Must pass *args* so `DataWrapper` in `Composite` can get `data`.
        """
        super().__init__(*args, **kwds)  # StateEqual

    def hash(self, **kwds):
        """ Return `xhash` of ``self.__getstate__()``.

        ``kwds`` are passed on to `xhash`.
        """
        return xhash(self.__getstate__(), **kwds)

    def __eq__(self, obj, **kwds):
        """ compares hash.

        Returns False for ``None``, for an object of another type, or
        if either object cannot be hashed with ``hash()`` because of a
        missing attribute. ``kwds`` is not used.
        """
        if obj is None:
            return False
        if self is obj:
            return True
        if type(self) != type(obj):
            return False
        try:
            h1 = self.hash()
            h2 = obj.hash()
        except AttributeError:
            return False
        # print('hashes ', h1, h2)
        return h1 == h2

    equals = __eq__

    def __xne__(self, obj):
        return not self.__eq__(obj)

    # Inside the class body ``hash`` is the method defined above, so
    # ``hash(instance)`` is answered by ``StateEqual.hash``.
    __hash__ = hash
class DeepcmpEqual():
    """ Equality tested by `deepcmp`.
    """

    def __init__(self, *args, **kwds):
        """ Must pass *args* so `DataWrapper` in `Composite` can get `data`.
        """
        super().__init__(*args, **kwds)  # DeepcmpEqual

    def __eq__(self, obj, **kwds):
        """ compares using `deepcmp`.

        Returns False for ``None`` and for an object of another type;
        otherwise True if `deepcmp` finds no difference. ``kwds`` are
        passed on to `deepcmp`.
        """
        if obj is None:
            return False
        if self is obj:
            return True
        if type(self) != type(obj):
            return False
        res = deepcmp(self, obj, **kwds)
        return res is None

    equals = __eq__

    def __xne__(self, obj):
        return not self.__eq__(obj)

    def __hash__(self, **kwds):
        """ Delegate to the ``hash()`` method of the instance.

        This class defines no ``hash()`` itself; a subclass or a sibling
        base class has to provide it. Binding ``__hash__`` to the name
        ``hash`` in this class body would pick up the built-in function
        and make ``hash(instance)`` recurse without end.

        Raises
        ------
        TypeError
            If the instance has no ``hash()`` method.
        """
        hasher = getattr(self, 'hash', None)
        if hasher is None:
            raise TypeError('unhashable type: %r' % type(self).__name__)
        return hasher(**kwds)


DeepEqual = DeepcmpEqual