import logging
from functools import partial
from typing import List, Union
import dask.dataframe as dd
import pandas as pd
_logger = logging.getLogger()
_PAYLOAD_COL = "__ktk_shuffle_payload"
try:
# Technically distributed is an optional dependency
from distributed.protocol import serialize_bytes
from distributed.protocol import deserialize_bytes
HAS_DISTRIBUTED = True
except ImportError:
HAS_DISTRIBUTED = False
serialize_bytes = None
deserialize_bytes = None
__all__ = (
"pack_payload_pandas",
"pack_payload",
"unpack_payload_pandas",
"unpack_payload",
)
[docs]def pack_payload_pandas(partition: pd.DataFrame, group_key: List[str]) -> pd.DataFrame:
if not HAS_DISTRIBUTED:
_logger.warning(
"Shuffle payload columns cannot be compressed since distributed is not installed."
)
return partition
if partition.empty:
res = partition[group_key]
res[_PAYLOAD_COL] = b""
else:
res = partition.groupby(
group_key,
sort=False,
observed=True,
# Keep the as_index s.t. the group values are not dropped. With this
# the behaviour seems to be consistent along pandas versions
as_index=True,
).apply(lambda x: pd.Series({_PAYLOAD_COL: serialize_bytes(x)}))
res = res.reset_index()
return res
[docs]def pack_payload(df: dd.DataFrame, group_key: Union[List[str], str]) -> dd.DataFrame:
"""
Pack all payload columns (everything except of group_key) into a single
columns. This column will contain a single byte string containing the
serialized and compressed payload data. The payload data is just dead weight
when reshuffling. By compressing it once before the shuffle starts, this
saves a lot of memory and network/disk IO.
Example::
>>> import pandas as pd
... import dask.dataframe as dd
... from dask.dataframe.shuffle import pack_payload
...
... df = pd.DataFrame({"A": [1, 1] * 2 + [2, 2] * 2 + [3, 3] * 2, "B": range(12)})
... ddf = dd.from_pandas(df, npartitions=2)
>>> ddf.partitions[0].compute()
A B
0 1 0
1 1 1
2 1 2
3 1 3
4 2 4
5 2 5
>>> pack_payload(ddf, "A").partitions[0].compute()
A __dask_payload_bytes
0 1 b'\x03\x00\x00\x00\x00\x00\x00\x00)\x00\x00\x03...
1 2 b'\x03\x00\x00\x00\x00\x00\x00\x00)\x00\x00\x03...
See also https://github.com/dask/dask/pull/6259
"""
if (
# https://github.com/pandas-dev/pandas/issues/34455
isinstance(df._meta.index, pd.Float64Index)
# TODO: Try to find out what's going on an file a bug report
# For datetime indices the apply seems to be corrupt
# s.t. apply(lambda x:x) returns different values
or isinstance(df._meta.index, pd.DatetimeIndex)
):
return df
if not HAS_DISTRIBUTED:
_logger.warning(
"Shuffle payload columns cannot be compressed since distributed is not installed."
)
return df
if not isinstance(group_key, list):
group_key = [group_key]
packed_meta = df._meta[group_key]
packed_meta[_PAYLOAD_COL] = b""
_pack_payload = partial(pack_payload_pandas, group_key=group_key)
return df.map_partitions(_pack_payload, meta=packed_meta)
[docs]def unpack_payload_pandas(
partition: pd.DataFrame, unpack_meta: pd.DataFrame
) -> pd.DataFrame:
"""
Revert ``pack_payload_pandas`` and restore packed payload
unpack_meta:
A dataframe indicating the schema of the unpacked data. This will be returned in case the input is empty
"""
if not HAS_DISTRIBUTED:
_logger.warning(
"Shuffle payload columns cannot be compressed since distributed is not installed."
)
return partition
if partition.empty:
return unpack_meta.iloc[:0]
mapped = partition[_PAYLOAD_COL].map(deserialize_bytes)
return pd.concat(mapped.values, copy=False, ignore_index=True)
[docs]def unpack_payload(df: dd.DataFrame, unpack_meta: pd.DataFrame) -> dd.DataFrame:
"""Revert payload packing of ``pack_payload`` and restores full dataframe."""
if (
# https://github.com/pandas-dev/pandas/issues/34455
isinstance(df._meta.index, pd.Float64Index)
# TODO: Try to find out what's going on an file a bug report
# For datetime indices the apply seems to be corrupt
# s.t. apply(lambda x:x) returns different values
or isinstance(df._meta.index, pd.DatetimeIndex)
):
return df
if not HAS_DISTRIBUTED:
_logger.warning(
"Shuffle payload columns cannot be compressed since distributed is not installed."
)
return df
return df.map_partitions(
unpack_payload_pandas, unpack_meta=unpack_meta, meta=unpack_meta
)