from typing import Callable, Iterator, List, Optional, Set, Union, cast, overload
import pandas as pd
from kartothek.core.factory import DatasetFactory
from kartothek.core.index import ExplicitSecondaryIndex
from kartothek.core.typing import StoreInput
from kartothek.io_components.metapartition import MetaPartition
from kartothek.io_components.utils import normalize_args
from kartothek.serialization import (
PredicatesType,
check_predicates,
columns_in_predicates,
)
from kartothek.utils.migration_helpers import (
deprecate_parameters_if_set,
get_deprecation_warning_remove_parameter_multi_table,
)
# Overload: when ``dispatch_by`` is None (the default), every partition is
# dispatched on its own and the function yields individual ``MetaPartition``
# objects rather than lists of them.
@overload
@deprecate_parameters_if_set(
    get_deprecation_warning_remove_parameter_multi_table(
        deprecated_in="5.3", removed_in="6.0"
    ),
    "label_filter",
    "concat_partitions_on_primary_index",
    "store",
    "dispatch_metadata",
)
def dispatch_metapartitions_from_factory(
    dataset_factory: DatasetFactory,
    label_filter: Optional[Callable] = None,
    concat_partitions_on_primary_index: bool = False,
    predicates: PredicatesType = None,
    store: Optional[StoreInput] = None,
    dispatch_by: None = None,
    dispatch_metadata: bool = False,
) -> Iterator[MetaPartition]:
    ...
# Overload: when ``dispatch_by`` is a list of (indexed) column names, the
# function yields one *list* of ``MetaPartition`` objects per unique value
# combination of those columns.
@overload
@deprecate_parameters_if_set(
    get_deprecation_warning_remove_parameter_multi_table(
        deprecated_in="5.3", removed_in="6.0"
    ),
    "label_filter",
    "concat_partitions_on_primary_index",
    "store",
    "dispatch_metadata",
)
def dispatch_metapartitions_from_factory(
    dataset_factory: DatasetFactory,
    label_filter: Optional[Callable],
    concat_partitions_on_primary_index: bool,
    predicates: PredicatesType,
    store: Optional[StoreInput],
    dispatch_by: List[str],
    dispatch_metadata: bool,
) -> Iterator[List[MetaPartition]]:
    ...
@normalize_args
@deprecate_parameters_if_set(
    get_deprecation_warning_remove_parameter_multi_table(
        deprecated_in="5.3", removed_in="6.0"
    ),
    "label_filter",
    "concat_partitions_on_primary_index",
    "store",
    "dispatch_metadata",
)
def dispatch_metapartitions_from_factory(
    dataset_factory: DatasetFactory,
    label_filter: Optional[Callable] = None,
    concat_partitions_on_primary_index: bool = False,
    predicates: PredicatesType = None,
    store: Optional[StoreInput] = None,
    dispatch_by: Optional[List[str]] = None,
    dispatch_metadata: bool = False,
) -> Union[Iterator[MetaPartition], Iterator[List[MetaPartition]]]:
    """
    Dispatch the partitions of a dataset as :class:`MetaPartition` objects.

    If ``dispatch_by`` is given, one *list* of MetaPartitions is yielded per
    unique value combination of the ``dispatch_by`` index columns; otherwise
    individual MetaPartitions are yielded, one per partition label.

    :meta private:
    """
    # ``concat_partitions_on_primary_index`` is the deprecated spelling of
    # ``dispatch_by=partition_keys`` — providing both is ambiguous.
    if concat_partitions_on_primary_index and dispatch_by is not None:
        raise ValueError(
            "Both `dispatch_by` and `concat_partitions_on_primary_index` are provided, "
            "`concat_partitions_on_primary_index` is deprecated and will be removed in the next major release. "
            "Please only provide the `dispatch_by` argument. "
        )
    if concat_partitions_on_primary_index:
        dispatch_by = dataset_factory.partition_keys

    # Grouping is only possible on columns backed by an index.
    if dispatch_by is not None and not set(dispatch_by).issubset(
        set(dataset_factory.index_columns)
    ):
        raise RuntimeError(
            f"Dispatch columns must be indexed.\nRequested index: {dispatch_by} but available index columns: {sorted(dataset_factory.index_columns)}"
        )
    check_predicates(predicates)

    # Collect every index column we must load: the grouping columns plus any
    # indexed column referenced by the predicates.
    required_index_cols: Set[str] = set()
    if dispatch_by:
        required_index_cols.update(dispatch_by)
    if predicates:
        referenced = set(columns_in_predicates(predicates))
        required_index_cols |= referenced & set(dataset_factory.index_columns)

    for column in required_index_cols:
        dataset_factory.load_index(column)

    # One row per (partition label, index value) combination, already
    # restricted by the predicates.
    base_df = dataset_factory.get_indices_as_dataframe(
        list(required_index_cols), predicates=predicates
    )

    if label_filter:
        base_df = base_df[base_df.index.map(label_filter)]

    # Unloaded copies of the explicit secondary indices, attached to the
    # dispatched MetaPartitions only when ``dispatch_metadata`` is set.
    secondary_indices = {}
    for index_name, index in dataset_factory.indices.items():
        if isinstance(index, ExplicitSecondaryIndex):
            secondary_indices[index_name] = index.unload()

    if dispatch_by is None:
        # Plain dispatch: one MetaPartition per remaining partition label.
        for part_label in base_df.index.unique():
            yield MetaPartition.from_partition(
                partition=dataset_factory.partitions[part_label],
                dataset_metadata=dataset_factory.metadata
                if dispatch_metadata
                else None,
                indices=secondary_indices if dispatch_metadata else None,
                metadata_version=dataset_factory.metadata_version,
                table_meta=dataset_factory.table_meta,
                partition_keys=dataset_factory.partition_keys,
            )
        return

    base_df = cast(pd.DataFrame, base_df)
    if dispatch_by:
        # Group the resulting MetaPartitions by the requested (sub)set of keys.
        groups = base_df.groupby(by=list(dispatch_by), sort=True, as_index=False)
    else:
        # Empty dispatch_by: everything ends up in a single group.
        groups = [("", base_df)]

    for group_key, group_frame in groups:
        # pandas hands back a scalar for a single grouping column; normalize
        # to a tuple so zipping with dispatch_by works uniformly.
        if not isinstance(group_key, tuple):
            group_key = (group_key,)  # type: ignore
        # Predicate conjunction pinning this group's key values, e.g.
        # [("col", "==", value), ...].
        conjunction = list(zip(dispatch_by, ["=="] * len(dispatch_by), group_key))
        batch = [
            MetaPartition.from_partition(
                partition=dataset_factory.partitions[label],
                dataset_metadata=dataset_factory.metadata
                if dispatch_metadata
                else None,
                indices=secondary_indices if dispatch_metadata else None,
                metadata_version=dataset_factory.metadata_version,
                table_meta=dataset_factory.table_meta,
                partition_keys=dataset_factory.partition_keys,
                logical_conjunction=conjunction,
            )
            for label in group_frame.index.unique()
        ]
        yield batch