Skip to content

Commit

Permalink
We depend on msgpack >= 1.0, simplify logic
Browse files Browse the repository at this point in the history
  • Loading branch information
s0undt3ch committed Feb 26, 2024
1 parent c990077 commit ef099c3
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 96 deletions.
2 changes: 1 addition & 1 deletion requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

Jinja2
jmespath
msgpack>=0.5,!=0.5.5
msgpack>=1.0.0
PyYAML
MarkupSafe
requests>=1.0.0
Expand Down
27 changes: 8 additions & 19 deletions salt/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,26 +84,15 @@ def ext_type_decoder(code, data):

gc.disable() # performance optimization for msgpack
loads_kwargs = {"use_list": True, "ext_hook": ext_type_decoder}
if salt.utils.msgpack.version >= (0, 4, 0):
# msgpack only supports 'encoding' starting in 0.4.0.
# Due to this, if we don't need it, don't pass it at all so
# that under Python 2 we can still work with older versions
# of msgpack.
if salt.utils.msgpack.version >= (0, 5, 2):
if encoding is None:
loads_kwargs["raw"] = True
else:
loads_kwargs["raw"] = False
else:
loads_kwargs["encoding"] = encoding
try:
ret = salt.utils.msgpack.unpackb(msg, **loads_kwargs)
except UnicodeDecodeError:
# msg contains binary data
loads_kwargs.pop("raw", None)
loads_kwargs.pop("encoding", None)
ret = salt.utils.msgpack.loads(msg, **loads_kwargs)
if encoding is None:
loads_kwargs["raw"] = True
else:
loads_kwargs["raw"] = False
try:
ret = salt.utils.msgpack.unpackb(msg, **loads_kwargs)
except UnicodeDecodeError:
# msg contains binary data
loads_kwargs.pop("raw", None)
ret = salt.utils.msgpack.loads(msg, **loads_kwargs)
if encoding is None and not raw:
ret = salt.transport.frame.decode_embedded_strs(ret)
Expand Down
16 changes: 2 additions & 14 deletions salt/transport/ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,7 @@ def return_message(msg):
else:
return _null

# msgpack deprecated `encoding` starting with version 0.5.2
if salt.utils.msgpack.version >= (0, 5, 2):
# Under Py2 we still want raw to be set to True
msgpack_kwargs = {"raw": False}
else:
msgpack_kwargs = {"encoding": "utf-8"}
unpacker = salt.utils.msgpack.Unpacker(**msgpack_kwargs)
unpacker = salt.utils.msgpack.Unpacker(raw=False)
while not stream.closed():
try:
wire_bytes = yield stream.read_bytes(4096, partial=True)
Expand Down Expand Up @@ -280,13 +274,7 @@ def __init__(self, socket_path, io_loop=None):
self.socket_path = socket_path
self._closing = False
self.stream = None
# msgpack deprecated `encoding` starting with version 0.5.2
if salt.utils.msgpack.version >= (0, 5, 2):
# Under Py2 we still want raw to be set to True
msgpack_kwargs = {"raw": False}
else:
msgpack_kwargs = {"encoding": "utf-8"}
self.unpacker = salt.utils.msgpack.Unpacker(**msgpack_kwargs)
self.unpacker = salt.utils.msgpack.Unpacker(raw=False)
self._connecting_future = None

def connected(self):
Expand Down
35 changes: 7 additions & 28 deletions salt/utils/msgpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,10 @@

# There is a serialization issue on ARM and potentially other platforms for some msgpack bindings, check for it
if (
msgpack.version >= (0, 4, 0)
and msgpack.loads(msgpack.dumps([1, 2, 3], use_bin_type=False), use_list=True)
msgpack.loads(msgpack.dumps([1, 2, 3], use_bin_type=False), use_list=True)
is None
):
raise ImportError
elif msgpack.loads(msgpack.dumps([1, 2, 3]), use_list=True) is None:
raise ImportError
HAS_MSGPACK = True
except ImportError:
try:
Expand Down Expand Up @@ -59,13 +56,7 @@ def _sanitize_msgpack_kwargs(kwargs):
https://github.com/msgpack/msgpack-python/blob/master/ChangeLog.rst
"""
assert isinstance(kwargs, dict)
if version < (0, 6, 0) and kwargs.pop("strict_map_key", None) is not None:
log.info("removing unsupported `strict_map_key` argument from msgpack call")
if version < (0, 5, 2) and kwargs.pop("raw", None) is not None:
log.info("removing unsupported `raw` argument from msgpack call")
if version < (0, 4, 0) and kwargs.pop("use_bin_type", None) is not None:
log.info("removing unsupported `use_bin_type` argument from msgpack call")
if version >= (1, 0, 0) and kwargs.pop("encoding", None) is not None:
if kwargs.pop("encoding", None) is not None:
log.debug("removing unsupported `encoding` argument from msgpack call")

return kwargs
Expand All @@ -78,32 +69,20 @@ def _sanitize_msgpack_unpack_kwargs(kwargs):
https://github.com/msgpack/msgpack-python/blob/master/ChangeLog.rst
"""
assert isinstance(kwargs, dict)
if version >= (1, 0, 0):
kwargs.setdefault("raw", True)
kwargs.setdefault("strict_map_key", False)
kwargs.setdefault("raw", True)
kwargs.setdefault("strict_map_key", False)
return _sanitize_msgpack_kwargs(kwargs)


def _add_msgpack_unpack_kwargs(kwargs):
"""
Add any msgpack unpack kwargs here.
max_buffer_size: will make sure the buffer is set to a minimum
of 100MiB in versions >=6 and <1.0
"""
assert isinstance(kwargs, dict)
if version >= (0, 6, 0) and version < (1, 0, 0):
kwargs["max_buffer_size"] = 100 * 1024 * 1024
return _sanitize_msgpack_unpack_kwargs(kwargs)


class Unpacker(msgpack.Unpacker):
"""
Wraps the msgpack.Unpacker and removes non-relevant arguments
"""

def __init__(self, *args, **kwargs):
msgpack.Unpacker.__init__(self, *args, **_add_msgpack_unpack_kwargs(kwargs))
msgpack.Unpacker.__init__(
self, *args, **_sanitize_msgpack_unpack_kwargs(kwargs)
)


def pack(o, stream, **kwargs):
Expand Down
50 changes: 16 additions & 34 deletions tests/unit/utils/test_msgpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,7 @@
from salt.utils.odict import OrderedDict
from tests.support.unit import TestCase

try:
import msgpack
except ImportError:
import msgpack_pure as msgpack # pylint: disable=import-error


# A keyword to pass to tests that use `raw`, which was added in msgpack 0.5.2
raw = {"raw": False} if msgpack.version > (0, 5, 2) else {}
msgpack = pytest.importorskip("msgpack")


@pytest.mark.skipif(
Expand Down Expand Up @@ -156,10 +149,7 @@ def test_map_size(self):
bio.write(packer.pack(i * 2)) # value

bio.seek(0)
if salt.utils.msgpack.version > (0, 6, 0):
unpacker = salt.utils.msgpack.Unpacker(bio, strict_map_key=False)
else:
unpacker = salt.utils.msgpack.Unpacker(bio)
unpacker = salt.utils.msgpack.Unpacker(bio, strict_map_key=False)
for size in sizes:
self.assertEqual(unpacker.unpack(), {i: i * 2 for i in range(size)})

Expand Down Expand Up @@ -293,7 +283,7 @@ def _test_unpacker_ext_hook(self, pack_func, **kwargs):
class MyUnpacker(salt.utils.msgpack.Unpacker):
def __init__(self):
my_kwargs = {}
super().__init__(ext_hook=self._hook, **raw)
super().__init__(ext_hook=self._hook, raw=False)

def _hook(self, code, data):
if code == 1:
Expand All @@ -314,21 +304,20 @@ def _hook(self, code, data):
def _check(
self, data, pack_func, unpack_func, use_list=False, strict_map_key=False
):
my_kwargs = {}
if salt.utils.msgpack.version >= (0, 6, 0):
my_kwargs["strict_map_key"] = strict_map_key
ret = unpack_func(pack_func(data), use_list=use_list, **my_kwargs)
ret = unpack_func(
pack_func(data), use_list=use_list, strict_map_key=strict_map_key
)
self.assertEqual(ret, data)

def _test_pack_unicode(self, pack_func, unpack_func):
test_data = ["", "abcd", ["defgh"], "Русский текст"]
for td in test_data:
ret = unpack_func(pack_func(td), use_list=True, **raw)
ret = unpack_func(pack_func(td), use_list=True, raw=False)
self.assertEqual(ret, td)
packer = salt.utils.msgpack.Packer()
data = packer.pack(td)
ret = salt.utils.msgpack.Unpacker(
BytesIO(data), use_list=True, **raw
BytesIO(data), use_list=True, raw=False
).unpack()
self.assertEqual(ret, td)

Expand All @@ -352,30 +341,30 @@ def _test_pack_byte_arrays(self, pack_func, unpack_func):

def _test_ignore_unicode_errors(self, pack_func, unpack_func):
ret = unpack_func(
pack_func(b"abc\xeddef", use_bin_type=False), unicode_errors="ignore", **raw
pack_func(b"abc\xeddef", use_bin_type=False),
unicode_errors="ignore",
raw=False,
)
self.assertEqual("abcdef", ret)

def _test_strict_unicode_unpack(self, pack_func, unpack_func):
packed = pack_func(b"abc\xeddef", use_bin_type=False)
self.assertRaises(UnicodeDecodeError, unpack_func, packed, use_list=True, **raw)
self.assertRaises(
UnicodeDecodeError, unpack_func, packed, use_list=True, raw=False
)

def _test_ignore_errors_pack(self, pack_func, unpack_func):
ret = unpack_func(
pack_func("abc\uDC80\uDCFFdef", use_bin_type=True, unicode_errors="ignore"),
use_list=True,
**raw
raw=False,
)
self.assertEqual("abcdef", ret)

def _test_decode_binary(self, pack_func, unpack_func):
ret = unpack_func(pack_func(b"abc"), use_list=True)
self.assertEqual(b"abc", ret)

@pytest.mark.skipif(
salt.utils.msgpack.version < (0, 2, 2),
"use_single_float was added in msgpack==0.2.2",
)
def _test_pack_float(self, pack_func, **kwargs):
self.assertEqual(
b"\xca" + struct.pack(">f", 1.0), pack_func(1.0, use_single_float=True)
Expand All @@ -402,16 +391,9 @@ def _test_pair_list(self, unpack_func, **kwargs):
pairlist = [(b"a", 1), (2, b"b"), (b"foo", b"bar")]
packer = salt.utils.msgpack.Packer()
packed = packer.pack_map_pairs(pairlist)
if salt.utils.msgpack.version > (0, 6, 0):
unpacked = unpack_func(packed, object_pairs_hook=list, strict_map_key=False)
else:
unpacked = unpack_func(packed, object_pairs_hook=list)
unpacked = unpack_func(packed, object_pairs_hook=list, strict_map_key=False)
self.assertEqual(pairlist, unpacked)

@pytest.mark.skipif(
salt.utils.msgpack.version < (0, 6, 0),
"getbuffer() was added to Packer in msgpack 0.6.0",
)
def _test_get_buffer(self, pack_func, **kwargs):
packer = msgpack.Packer(autoreset=False, use_bin_type=True)
packer.pack([1, 2])
Expand Down

0 comments on commit ef099c3

Please sign in to comment.