Skip to content

Commit

Permalink
Add method to remove circular references in data objects and add test (#54930)
Browse files Browse the repository at this point in the history

* Add method to remove circular references in data objects and add test

* remove trailing whitespace

* Blacken changed files

Co-authored-by: xeacott <[email protected]>
Co-authored-by: Frode Gundersen <[email protected]>
Co-authored-by: Daniel A. Wozniak <[email protected]>
  • Loading branch information
4 people authored May 4, 2020
1 parent 1c7ce9d commit c6638f7
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 0 deletions.
59 changes: 59 additions & 0 deletions salt/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,46 @@ def compare_lists(old=None, new=None):
return ret


def _remove_circular_refs(ob, _seen=None):
"""
Generic method to remove circular references from objects.
This has been taken from author Martijn Pieters
https://stackoverflow.com/questions/44777369/
remove-circular-references-in-dicts-lists-tuples/44777477#44777477
:param ob: dict, list, typle, set, and frozenset
Standard python object
:param object _seen:
Object that has circular reference
:returns:
Cleaned Python object
:rtype:
type(ob)
"""
if _seen is None:
_seen = set()
if id(ob) in _seen:
# Here we caught a circular reference.
# Alert user and cleanup to continue.
log.exception(
"Caught a circular reference in data structure below."
"Cleaning and continuing execution.\n%r\n",
ob,
)
return None
_seen.add(id(ob))
res = ob
if isinstance(ob, dict):
res = {
_remove_circular_refs(k, _seen): _remove_circular_refs(v, _seen)
for k, v in ob.items()
}
elif isinstance(ob, (list, tuple, set, frozenset)):
res = type(ob)(_remove_circular_refs(v, _seen) for v in ob)
# remove id again; only *nested* references count
_seen.remove(id(ob))
return res


def decode(
data,
encoding=None,
Expand Down Expand Up @@ -211,7 +251,11 @@ def decode(
two strings above, in which "й" is represented as two code points (i.e. one
for the base character, and one for the breve mark). Normalizing allows for
a more reliable test case.
"""
# Clean data object before decoding to avoid circular references
data = _remove_circular_refs(data)

_decode_func = (
salt.utils.stringutils.to_unicode
if not to_str
Expand Down Expand Up @@ -283,6 +327,9 @@ def decode_dict(
Decode all string values to Unicode. Optionally use to_str=True to ensure
strings are str types and not unicode on Python 2.
"""
# Clean data object before decoding to avoid circular references
data = _remove_circular_refs(data)

_decode_func = (
salt.utils.stringutils.to_unicode
if not to_str
Expand Down Expand Up @@ -395,6 +442,9 @@ def decode_list(
Decode all string values to Unicode. Optionally use to_str=True to ensure
strings are str types and not unicode on Python 2.
"""
# Clean data object before decoding to avoid circular references
data = _remove_circular_refs(data)

_decode_func = (
salt.utils.stringutils.to_unicode
if not to_str
Expand Down Expand Up @@ -493,7 +543,11 @@ def encode(
original value to silently be returned in cases where encoding fails. This
can be useful for cases where the data passed to this function is likely to
contain binary blobs.
"""
# Clean data object before encoding to avoid circular references
data = _remove_circular_refs(data)

if isinstance(data, Mapping):
return encode_dict(
data, encoding, errors, keep, preserve_dict_class, preserve_tuples
Expand Down Expand Up @@ -536,6 +590,8 @@ def encode_dict(
"""
Encode all string values to bytes
"""
# Clean data object before encoding to avoid circular references
data = _remove_circular_refs(data)
ret = data.__class__() if preserve_dict_class else {}
for key, value in six.iteritems(data):
if isinstance(key, tuple):
Expand Down Expand Up @@ -603,6 +659,9 @@ def encode_list(
"""
Encode all string values to bytes
"""
# Clean data object before encoding to avoid circular references
data = _remove_circular_refs(data)

ret = []
for item in data:
if isinstance(item, list):
Expand Down
21 changes: 21 additions & 0 deletions tests/unit/utils/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,27 @@ def test_decode(self):
self.assertEqual(salt.utils.data.decode(BYTES, keep=True), BYTES)
self.assertRaises(UnicodeDecodeError, salt.utils.data.decode, BYTES, keep=False)

def test_circular_refs_dicts(self):
    """
    A dict stored as one of its own values must have that value replaced
    with None while the remaining items are left untouched.
    """
    cyclic = {"key": "value", "type": "test1"}
    cyclic["self"] = cyclic
    expected = {"key": "value", "type": "test1", "self": None}
    cleaned = salt.utils.data._remove_circular_refs(ob=cyclic)
    self.assertDictEqual(cleaned, expected)

def test_circular_refs_lists(self):
    """
    A dict that refers back to itself through a list and a tuple must have
    the innermost back-reference replaced with None.
    """
    container = {"foo": []}
    # container -> list -> tuple -> container closes the cycle
    container["foo"].append((container,))
    expected = {"foo": [(None,)]}
    cleaned = salt.utils.data._remove_circular_refs(ob=container)
    self.assertDictEqual(cleaned, expected)

def test_circular_refs_tuple(self):
    """
    Values that merely repeat in sibling positions (no nesting involved)
    must not be mistaken for circular references; the dict comes back
    unchanged.
    """
    repeated = {"foo": "string 1", "bar": "string 1", "ham": 1, "spam": 1}
    expected = {"foo": "string 1", "bar": "string 1", "ham": 1, "spam": 1}
    cleaned = salt.utils.data._remove_circular_refs(ob=repeated)
    self.assertDictEqual(cleaned, expected)

def test_decode_to_str(self):
"""
Companion to test_decode, they should both be kept up-to-date with one
Expand Down

0 comments on commit c6638f7

Please sign in to comment.