# -*- coding: utf-8 -*-
import difflib
import logging
import pprint
from copy import copy, deepcopy
from functools import reduce
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import simplejson
from simplekv import KeyValueStore
from kartothek.core import naming
from kartothek.core._compat import load_json
from kartothek.core.utils import ensure_string_type
_logger = logging.getLogger()
__all__ = (
"SchemaWrapper",
"make_meta",
"validate_compatible",
"empty_dataframe_from_schema",
"normalize_type",
"normalize_column_order",
)
[docs]class SchemaWrapper:
"""
Wrapper object for pyarrow.Schema to handle forwards and backwards compatibility.
"""
def __init__(self, schema, origin: Union[str, Set[str]]):
if isinstance(origin, str):
origin = {origin}
elif isinstance(origin, set):
origin = copy(origin)
if not all(isinstance(s, str) for s in origin):
raise TypeError("Schema origin elements must be strings.")
self.__schema = schema
self.__origin = origin
self._schema_compat()
[docs] def with_origin(self, origin: Union[str, Set[str]]) -> "SchemaWrapper":
"""
Create new SchemaWrapper with given origin.
Parameters
----------
origin:
New origin.
"""
return SchemaWrapper(self.__schema, origin)
def _schema_compat(self):
# https://issues.apache.org/jira/browse/ARROW-5104
schema = self.__schema
if self.__schema is not None and self.__schema.pandas_metadata is not None:
pandas_metadata = schema.pandas_metadata
index_cols = pandas_metadata["index_columns"]
if len(index_cols) > 1:
raise NotImplementedError("Treatement of MultiIndex not implemented.")
for ix, col in enumerate(index_cols):
# Range index is now serialized using start/end information. This special treatment
# removes it from the columns which is fine
if isinstance(col, dict):
pass
# other indices are still tracked as a column
else:
index_level_ix = schema.get_field_index(col)
# this may happen for the schema of an empty df
if index_level_ix >= 0:
schema = schema.remove(index_level_ix)
schema = schema.remove_metadata()
md = {b"pandas": _dict_to_binary(pandas_metadata)}
schema = schema.with_metadata(md)
self.__schema = schema
[docs] def internal(self):
return self.__schema
@property
def origin(self) -> Set[str]:
return copy(self.__origin)
def __repr__(self):
return self.__schema.__repr__()
def __eq__(self, other):
return self.equals(other)
def __ne__(self, other):
return not self.equals(other)
def __getstate__(self):
return (_schema2bytes(self.__schema), self.__origin)
def __setstate__(self, state):
self.__schema = _bytes2schema(state[0])
self.__origin = state[1]
def __getattr__(self, attr):
return getattr(self.__schema, attr)
def __hash__(self):
# see https://issues.apache.org/jira/browse/ARROW-2719
return hash(_schema2bytes(self.__schema))
def __getitem__(self, i):
return self.__schema[i]
def __len__(self):
return len(self.__schema)
[docs] def equals(self, other, check_metadata=False):
if isinstance(other, SchemaWrapper):
return self.__schema.equals(other.__schema, check_metadata)
else:
return self.__schema.equals(other, check_metadata)
equals.__doc__ = pa.Schema.equals.__doc__
[docs] def remove(self, i):
return SchemaWrapper(self.__schema.remove(i), self.__origin)
remove.__doc__ = pa.Schema.set.__doc__
remove_metadata.__doc__ = pa.Schema.remove_metadata.__doc__
[docs] def set(self, i, field):
return SchemaWrapper(self.__schema.set(i, field), self.__origin)
set.__doc__ = pa.Schema.set.__doc__
[docs]def normalize_column_order(schema, partition_keys=None):
"""
Normalize column order in schema.
Columns are sorted in the following way:
1. Partition keys (as provided by ``partition_keys``)
2. DataFrame columns in alphabetic order
3. Remaining fields as generated by pyarrow, mostly index columns
Parameters
----------
schema: SchemaWrapper
Schema information for DataFrame.
partition_keys: Union[None, List[str]]
Partition keys used to split the dataset.
Returns
-------
schema: SchemaWrapper
Schema information for DataFrame.
"""
if not isinstance(schema, SchemaWrapper):
schema = SchemaWrapper(schema, "__unknown__")
if partition_keys is None:
partition_keys = []
else:
partition_keys = list(partition_keys)
pandas_metadata = schema.pandas_metadata
origin = schema.origin
cols_partition = {}
cols_payload = []
cols_misc = []
for cmd in pandas_metadata["columns"]:
name = cmd.get("name")
field_name = cmd["field_name"]
field_idx = schema.get_field_index(field_name)
if field_idx >= 0:
field = schema[field_idx]
else:
field = None
if name is None:
cols_misc.append((cmd, field))
elif name in partition_keys:
cols_partition[name] = (cmd, field)
else:
cols_payload.append((name, cmd, field))
ordered = []
for k in partition_keys:
if k in cols_partition:
ordered.append(cols_partition[k])
ordered += [(cmd, f) for _name, cmd, f in sorted(cols_payload, key=lambda x: x[0])]
ordered += cols_misc
pandas_metadata["columns"] = [cmd for cmd, _ in ordered]
fields = [f for _, f in ordered if f is not None]
metadata = schema.metadata
metadata[b"pandas"] = _dict_to_binary(pandas_metadata)
schema = pa.schema(fields, metadata)
return SchemaWrapper(schema, origin)
[docs]def normalize_type(
t_pa: pa.DataType,
t_pd: Optional[str],
t_np: Optional[str],
metadata: Optional[Dict[str, Any]],
) -> Tuple[pa.DataType, Optional[str], Optional[str], Optional[Dict[str, Any]]]:
"""
This will normalize types as followed:
- all signed integers (``int8``, ``int16``, ``int32``, ``int64``) will be converted to ``int64``
- all unsigned integers (``uint8``, ``uint16``, ``uint32``, ``uint64``) will be converted to ``uint64``
- all floats (``float32``, ``float64``) will be converted to ``float64``
- all list value types will be normalized (e.g. ``list[int16]`` to ``list[int64]``, ``list[list[uint8]]`` to
``list[list[uint64]]``)
- all dict value types will be normalized (e.g. ``dictionary<values=float32, indices=int16, ordered=0>`` to
``float64``)
Parameters
----------
t_pa
pyarrow type object, e.g. ``pa.list_(pa.int8())``.
t_pd
pandas type identifier, e.g. ``"list[int8]"``.
t_np
numpy type identifier, e.g. ``"object"``.
metadata
metadata associated with the type, e.g. information about categorials.
"""
if pa.types.is_signed_integer(t_pa):
return pa.int64(), "int64", "int64", None
elif pa.types.is_unsigned_integer(t_pa):
return pa.uint64(), "uint64", "uint64", None
elif pa.types.is_floating(t_pa):
return pa.float64(), "float64", "float64", None
elif pa.types.is_list(t_pa):
assert t_pd is not None
t_pa2, t_pd2, t_np2, metadata2 = normalize_type(
t_pa.value_type, t_pd[len("list[") : -1], None, None
)
return pa.list_(t_pa2), "list[{}]".format(t_pd2), "object", None
elif pa.types.is_dictionary(t_pa):
# downcast to dictionary content, `t_pd` is useless in that case
return normalize_type(t_pa.value_type, t_np, t_np, None)
else:
return t_pa, t_pd, t_np, metadata
def _get_common_metadata_key(dataset_uuid, table):
return "{}/{}/{}".format(dataset_uuid, table, naming.TABLE_METADATA_FILE)
def read_schema_metadata(
dataset_uuid: str, store: KeyValueStore, table: str
) -> SchemaWrapper:
"""
Read schema and metadata from store.
Parameters
----------
dataset_uuid
Unique ID of the dataset in question.
store
Object that implements `.get(key)` to read data.
table
Table to read metadata for.
Returns
-------
schema: Schema
Schema information for DataFrame/table.
"""
key = _get_common_metadata_key(dataset_uuid=dataset_uuid, table=table)
return SchemaWrapper(_bytes2schema(store.get(key)), key)
def store_schema_metadata(
schema: SchemaWrapper, dataset_uuid: str, store: KeyValueStore, table: str
) -> str:
"""
Store schema and metadata to store.
Parameters
----------
schema
Schema information for DataFrame/table.
dataset_uuid
Unique ID of the dataset in question.
store
Object that implements `.put(key, data)` to write data.
table
Table to write metadata for.
Returns
-------
key: str
Key to which the metadata was written to.
"""
key = _get_common_metadata_key(dataset_uuid=dataset_uuid, table=table)
return store.put(key, _schema2bytes(schema.internal()))
def _schema2bytes(schema: SchemaWrapper) -> bytes:
buf = pa.BufferOutputStream()
pq.write_metadata(schema, buf, version="2.0", coerce_timestamps="us")
return buf.getvalue().to_pybytes()
def _bytes2schema(data: bytes) -> SchemaWrapper:
reader = pa.BufferReader(data)
schema = pq.read_schema(reader)
fields = []
for idx in range(len(schema)):
f = schema[idx]
# schema data recovered from parquet always contains timestamp data in us-granularity, but pandas will use
# ns-granularity, so we re-align the two different worlds here
if f.type == pa.timestamp("us"):
f = pa.field(f.name, pa.timestamp("ns"))
fields.append(f)
return pa.schema(fields, schema.metadata)
def _pandas_in_schemas(schemas):
"""
Check if any schema contains pandas metadata
"""
has_pandas = False
for schema in schemas:
if schema.metadata and b"pandas" in schema.metadata:
has_pandas = True
return has_pandas
def _determine_schemas_to_compare(
schemas: Sequence[SchemaWrapper], ignore_pandas: bool
) -> Tuple[Optional[SchemaWrapper], List[Tuple[SchemaWrapper, List[str]]]]:
"""
Iterate over a list of `pyarrow.Schema` objects and prepares them for comparison by picking a reference
and determining all null columns.
.. note::
If pandas metadata exists, the version stored in the metadata is overwritten with the currently
installed version since we expect to stay backwards compatible
Returns
-------
reference: Schema
A reference schema which is picked from the input list. The reference schema is guaranteed
to be a schema having the least number of null columns of all input columns. The set of null
columns is guaranteed to be a true subset of all null columns of all input schemas. If no such
schema can be found, an Exception is raised
list_of_schemas: List[Tuple[Schema, List]]
A list holding pairs of (Schema, null_columns) where the null_columns are all columns which are null and
must be removed before comparing the schemas
"""
has_pandas = _pandas_in_schemas(schemas) and not ignore_pandas
schemas_to_evaluate: List[Tuple[SchemaWrapper, List[str]]] = []
reference = None
null_cols_in_reference = set()
for schema in schemas:
if not isinstance(schema, SchemaWrapper):
schema = SchemaWrapper(schema, "__unknown__")
if has_pandas:
metadata = schema.metadata
if metadata is None or b"pandas" not in metadata:
raise ValueError(
"Pandas and non-Pandas schemas are not comparable. "
"Use ignore_pandas=True if you only want to compare "
"on Arrow level."
)
pandas_metadata = load_json(metadata[b"pandas"].decode("utf8"))
# we don't care about the pandas version, since we assume it's safe
# to read datasets that were written by older or newer versions.
pandas_metadata["pandas_version"] = "{}".format(pd.__version__)
metadata_clean = deepcopy(metadata)
metadata_clean[b"pandas"] = _dict_to_binary(pandas_metadata)
current = SchemaWrapper(pa.schema(schema, metadata_clean), schema.origin)
else:
current = schema
# If a field is null we cannot compare it and must therefore reject it
null_columns = {field.name for field in current if field.type == pa.null()} # type: ignore
# Determine a valid reference schema. A valid reference schema is considered to be the schema
# of all input schemas with the least empty columns.
# The reference schema ought to be a schema whose empty columns are a true subset for all sets
# of empty columns. This ensures that the actual reference schema is the schema with the most
# information possible. A schema which doesn't fulfil this requirement would weaken the
# comparison and would allow for false positives
# Trivial case
if reference is None:
reference = current
null_cols_in_reference = null_columns
# The reference has enough information to validate against current schema.
# Append it to the list of schemas to be verified
elif null_cols_in_reference.issubset(null_columns):
schemas_to_evaluate.append((current, null_columns))
# current schema includes all information of reference and more.
# Add reference to schemas_to_evaluate and update reference
elif null_columns.issubset(null_cols_in_reference):
schemas_to_evaluate.append((reference, null_cols_in_reference))
reference = current
null_cols_in_reference = null_columns
# If there is no clear subset available elect the schema with the least null columns as `reference`.
# Iterate over the null columns of `reference` and replace it with a non-null field of the `current`
# schema which recovers the loop invariant (null columns of `reference` is subset of `current`)
else:
if len(null_columns) < len(null_cols_in_reference):
reference, current = current, reference
null_cols_in_reference, null_columns = (
null_columns,
null_cols_in_reference,
)
for col in null_cols_in_reference - null_columns:
# Enrich the information in the reference by grabbing the missing fields
# from the current iteration. This assumes that we only check for global validity and
# isn't relevant where the reference comes from.
reference = _swap_fields_by_name(reference, current, col)
null_cols_in_reference.remove(col)
schemas_to_evaluate.append((current, null_columns))
assert (reference is not None) or (not schemas_to_evaluate)
return reference, schemas_to_evaluate
def _swap_fields_by_name(reference, current, field_name):
current_field = current.field(field_name)
reference_index = reference.get_field_index(field_name)
return reference.set(reference_index, current_field)
def _strip_columns_from_schema(schema, field_names):
stripped_schema = schema
for name in field_names:
ix = stripped_schema.get_field_index(name)
if ix >= 0:
stripped_schema = stripped_schema.remove(ix)
else:
# If the returned index is negative, the field doesn't exist in the schema.
# This is most likely an indicator for incompatible schemas and we refuse to strip the schema
# to not obfurscate the validation result
_logger.warning(
"Unexpected field `%s` encountered while trying to strip `null` columns.\n"
"Schema was:\n\n`%s`" % (name, schema)
)
return schema
return stripped_schema
def _remove_diff_header(diff):
diff = list(diff)
for ix, el in enumerate(diff):
# This marks the first actual entry of the diff
# e.g. @@ -1,5 + 2,5 @@
if el.startswith("@"):
return diff[ix:]
return diff
def _diff_schemas(first, second):
# see https://issues.apache.org/jira/browse/ARROW-4176
first_pyarrow_info = str(first.remove_metadata())
second_pyarrow_info = str(second.remove_metadata())
pyarrow_diff = _remove_diff_header(
difflib.unified_diff(
str(first_pyarrow_info).splitlines(), str(second_pyarrow_info).splitlines()
)
)
first_pandas_info = first.pandas_metadata
second_pandas_info = second.pandas_metadata
pandas_meta_diff = _remove_diff_header(
difflib.unified_diff(
pprint.pformat(first_pandas_info).splitlines(),
pprint.pformat(second_pandas_info).splitlines(),
)
)
diff_string = (
"Arrow schema:\n"
+ "\n".join(pyarrow_diff)
+ "\n\nPandas_metadata:\n"
+ "\n".join(pandas_meta_diff)
)
return diff_string
[docs]def validate_compatible(schemas, ignore_pandas=False):
"""
Validate that all schemas in a given list are compatible.
Apart from the pandas version preserved in the schema metadata, schemas must be completely identical. That includes
a perfect match of the whole metadata (except the pandas version) and pyarrow types.
Use :meth:`make_meta` and :meth:`normalize_column_order` for type and column order normalization.
In the case that all schemas don't contain any pandas metadata, we will check the Arrow
schemas directly for compatibility.
Parameters
----------
schemas: List[Schema]
Schema information from multiple sources, e.g. multiple partitions. List may be empty.
ignore_pandas: bool
Ignore the schema information given by Pandas an always use the Arrow schema.
Returns
-------
schema: SchemaWrapper
The reference schema which was tested against
Raises
------
ValueError
At least two schemas are incompatible.
"""
reference, schemas_to_evaluate = _determine_schemas_to_compare(
schemas, ignore_pandas
)
for current, null_columns in schemas_to_evaluate:
# Compare each schema to the reference but ignore the null_cols and the Pandas schema information.
reference_to_compare = _strip_columns_from_schema(
reference, null_columns
).remove_metadata()
current_to_compare = _strip_columns_from_schema(
current, null_columns
).remove_metadata()
def _fmt_origin(origin):
origin = sorted(origin)
# dask cuts of exception messages at 1k chars:
# https://github.com/dask/distributed/blob/6e0c0a6b90b1d3c/distributed/core.py#L964
# therefore, we cut the the maximum length
max_len = 200
inner_msg = ", ".join(origin)
ellipsis = "..."
if len(inner_msg) > max_len + len(ellipsis):
inner_msg = inner_msg[:max_len] + ellipsis
return "{{{}}}".format(inner_msg)
if reference_to_compare != current_to_compare:
schema_diff = _diff_schemas(reference, current)
exception_message = """Schema violation
Origin schema: {origin_schema}
Origin reference: {origin_reference}
Diff:
{schema_diff}
Reference schema:
{reference}""".format(
schema_diff=schema_diff,
reference=str(reference),
origin_schema=_fmt_origin(current.origin),
origin_reference=_fmt_origin(reference.origin),
)
raise ValueError(exception_message)
# add all origins to result AFTER error checking, otherwise the error message would be pretty misleading due to the
# reference containing all origins.
if reference is None:
return None
else:
return reference.with_origin(
reduce(
set.union,
(schema.origin for schema, _null_columns in schemas_to_evaluate),
reference.origin,
)
)
def validate_shared_columns(schemas, ignore_pandas=False):
"""
Validate that columns that are shared amongst schemas are compatible.
Only DataFrame columns are taken into account, other fields (like index data) are ignored. The following data must
be an exact match:
- metadata (as stored in the ``"columns"`` list of the ``b'pandas'`` schema metadata)
- pyarrow type (that means that e.g. ``int8`` and ``int64`` are NOT compatible)
Columns that are only present in a subset of the provided schemas must only be compatible for that subset, i.e.
non-existing columns are ignored. The order of the columns in the provided schemas is irrelevant.
Type normalization should be handled by :meth:`make_meta`.
In the case that all schemas don't contain any pandas metadata, we will check the Arrow
schemas directly for compatibility. Then the metadata information will not be checked
(as it is non-existent).
Parameters
----------
schemas: List[Schema]
Schema information from multiple sources, e.g. multiple tables. List may be empty.
ignore_pandas: bool
Ignore the schema information given by Pandas an always use the Arrow schema.
Raises
------
ValueError
Incompatible columns were found.
"""
seen = {}
has_pandas = _pandas_in_schemas(schemas) and not ignore_pandas
for schema in schemas:
if has_pandas:
metadata = schema.metadata
if metadata is None or b"pandas" not in metadata:
raise ValueError(
"Pandas and non-Pandas schemas are not comparable. "
"Use ignore_pandas=True if you only want to compare "
"on Arrow level."
)
pandas_metadata = load_json(metadata[b"pandas"].decode("utf8"))
columns = []
for cmd in pandas_metadata["columns"]:
name = cmd.get("name")
if name is None:
continue
columns.append(cmd["field_name"])
else:
columns = schema.names
for col in columns:
field_idx = schema.get_field_index(col)
field = schema[field_idx]
obj = (field, col)
if col in seen:
ref = seen[col]
if pa.types.is_null(ref[0].type) or pa.types.is_null(field.type):
continue
if ref != obj:
raise ValueError(
'Found incompatible entries for column "{}"\n{}\n{}'.format(
col, ref, obj
)
)
else:
seen[col] = obj
def _dict_to_binary(dct):
return simplejson.dumps(dct, sort_keys=True).encode("utf8")
[docs]def empty_dataframe_from_schema(schema, columns=None, date_as_object=False):
"""
Create an empty DataFrame from provided schema.
Parameters
----------
schema: Schema
Schema information of the new empty DataFrame.
columns: Union[None, List[str]]
Optional list of columns that should be part of the resulting DataFrame. All columns in that list must also be
part of the provided schema.
Returns
-------
DataFrame
Empty DataFrame with requested columns and types.
"""
df = schema.internal().empty_table().to_pandas(date_as_object=date_as_object)
df.columns = df.columns.map(ensure_string_type)
if columns is not None:
df = df[columns]
return df