Source code for dcos_test_utils.recordio
"""
Provides facilities for "Record-IO" encoding of data.
"Record-IO" encoding allows one to encode a sequence
of variable-length records by prefixing each record
with its size in bytes:
5\n
hello
6\n
world!
Note that this currently only supports record lengths
encoded as base 10 integer values with newlines as a
delimiter. This is to provide better language
portability: parsing a base 10 integer is simple. Most
other "Record-IO" implementations use a fixed-size
header of 4 bytes to directly encode an unsigned 32 bit
length.
"""
[docs]class Encoder():
"""Encode an arbitray message type into a 'RecordIO' message.
This class encapsulates the process of encoding an
arbitrary message into a 'RecordIO' message. Its
constructor takes a serialization function of the form
'serialize(message)'. This serialization function is
responsible for knowing how to take whatever message type
is passed to 'encode()' and serializing it to a 'UTF-8'
encoded byte array.
Once 'encode(message)' is called, it will use the
serialization function to convert 'message' into a 'UTF-8'
encoded byte array, wrap it in a 'RecordIO' frame,
and return it.
:param serialize: a function to serialize any message
passed to 'encode()' into a 'UTF-8'
encoded byte array
:type serialize: function
"""
def __init__(self, serialize):
self.serialize = serialize
[docs] def encode(self, message):
"""Encode a message into 'RecordIO' format.
:param message: a message to serialize and then wrap in
a 'RecordIO' frame.
:type message: object
:returns: a serialized message wrapped in a 'RecordIO' frame
:rtype: bytes
"""
s = self.serialize(message)
if not isinstance(s, bytes):
raise Exception("Calling 'serialize(message)' must return a 'bytes' object")
return bytes(str(len(s)) + "\n", "UTF-8") + s
[docs]class Decoder():
"""Decode a 'RecordIO' message back to an arbitrary message type.
This class encapsulates the process of decoding a message
previously encoded with 'RecordIO' back to an arbitrary
message type. Its constructor takes a deserialization
function of the form 'deserialize(data)'. This
deserialization function is responsible for knowing how to
take a fully constructed 'RecordIO' message containing a
'UTF-8' encoded byte array and deserialize it back into the
original message type.
The 'decode(data)' message takes a 'UTF-8' encoded byte array
as input and buffers it across subsequent calls to
construct a set of fully constructed 'RecordIO' messages that
are decoded and returned in a list.
:param deserialize: a function to deserialize from 'RecordIO'
messages built up by subsequent calls
to 'decode(data)'
:type deserialize: function
"""
HEADER = 0
RECORD = 1
FAILED = 2
def __init__(self, deserialize):
self.deserialize = deserialize
self.state = self.HEADER
self.buffer = bytes("", "UTF-8")
self.length = 0
[docs] def decode(self, data):
"""Decode a 'RecordIO' formatted message to its original type.
:param data: an array of 'UTF-8' encoded bytes that make up a
partial 'RecordIO' message. Subsequent calls to this
function maintain state to build up a full 'RecordIO'
message and decode it
:type data: bytes
:returns: a list of deserialized messages
:rtype: list
"""
if not isinstance(data, bytes):
raise Exception("Parameter 'data' must of of type 'bytes'")
if self.state == self.FAILED:
raise Exception("Decoder is in a FAILED state")
records = []
for c in data:
if self.state == self.HEADER:
if c != ord('\n'):
self.buffer += bytes([c])
continue
try:
self.length = int(self.buffer.decode("UTF-8"))
assert self.length >= 0, "Negative record length '{length}'".format(length=self.length)
except Exception as exception:
self.state = self.FAILED
raise Exception("Failed to decode length '{buffer}': {error}"
.format(buffer=self.buffer, error=exception)) from exception
self.buffer = bytes("", "UTF-8")
self.state = self.RECORD
# Note that for 0 length records, we immediately decode.
if self.length == 0:
records.append(self.deserialize(self.buffer))
self.state = self.HEADER
elif self.state == self.RECORD:
assert self.length
assert len(self.buffer) < self.length
self.buffer += bytes([c])
if len(self.buffer) == self.length:
records.append(self.deserialize(self.buffer))
self.buffer = bytes("", "UTF-8")
self.state = self.HEADER
return records