from functools import reduce
from itertools import product
from operator import mul
import numpy as np
from .core import Array
from .utils import meta_from_array
from ..base import tokenize
from ..core import flatten
from ..highlevelgraph import HighLevelGraph
from ..utils import M
def reshape_rechunk(inshape, outshape, inchunks):
    """Plan the chunking needed to reshape an array block-by-block.

    The shapes are walked from the last axis to the first. A group of input
    axes whose product equals one output axis is merged (``(4, 4, 4) ->
    (64,)``), and one input axis equal to the product of a group of output
    axes is split (``(64,) -> (4, 4, 4)``). Axes of size 1 on either side are
    dropped or inserted.

    Parameters
    ----------
    inshape : tuple of int
        Shape of the input array.
    outshape : tuple of int
        Requested output shape. It is expected (not checked here) to have the
        same total size as ``inshape``.
    inchunks : tuple of tuple of int
        Current chunks of the input array, one tuple per axis.

    Returns
    -------
    result_inchunks : tuple of tuple of int
        Chunks the input must be rechunked to before reshaping each block.
    result_outchunks : tuple of tuple of int
        Chunks of the reshaped output. ``reshape`` pairs input and output
        blocks in row-major (``itertools.product``) order.

    Raises
    ------
    ValueError
        If an input/output axis group cannot be merged or split into the
        other shape.
    """
    assert all(isinstance(c, tuple) for c in inchunks)
    ii = len(inshape) - 1
    oi = len(outshape) - 1
    result_inchunks = [None for i in range(len(inshape))]
    result_outchunks = [None for i in range(len(outshape))]

    # NOTE(review): once one side is exhausted (index < 0) Python's negative
    # indexing wraps around; this relies on the remaining axes of the other
    # side all having size 1 so the branches below assign size-1 chunks.
    while ii >= 0 or oi >= 0:
        if inshape[ii] == outshape[oi]:
            # Axis passes through unchanged.
            result_inchunks[ii] = inchunks[ii]
            result_outchunks[oi] = inchunks[ii]
            ii -= 1
            oi -= 1
            continue
        din = inshape[ii]
        dout = outshape[oi]
        if din == 1:
            # Size-1 input axis with no output counterpart.
            result_inchunks[ii] = (1,)
            ii -= 1
        elif dout == 1:
            # Size-1 output axis with no input counterpart.
            result_outchunks[oi] = (1,)
            oi -= 1
        elif din < dout:  # merge input axes: (4, 4, 4) -> (64,)
            ileft = ii - 1
            while (
                ileft >= 0 and reduce(mul, inshape[ileft : ii + 1]) < dout
            ):  # 4 < 64, 4*4 < 64, 4*4*4 == 64
                ileft -= 1
            if reduce(mul, inshape[ileft : ii + 1]) != dout:
                raise ValueError("Shapes not compatible")
            # Special case to avoid intermediate rechunking:
            # When all the lower axis are completely chunked (chunksize=1) then
            # we're simply moving around blocks.
            if all(len(inchunks[i]) == inshape[i] for i in range(ii)):
                for i in range(ileft, ii + 1):
                    result_inchunks[i] = inchunks[i]
                # Only the merged axes ileft..ii-1 multiply the number of
                # blocks along this output axis; axes left of ``ileft`` map to
                # other output axes. ``reduce`` (rather than ``np.prod``)
                # yields a Python int, so tuple repetition is well defined.
                result_outchunks[oi] = inchunks[ii] * reduce(
                    mul, map(len, inchunks[ileft:ii]), 1
                )
            else:
                # Inner merged axes become single-chunk ...
                for i in range(ileft + 1, ii + 1):  # need single-shape dimensions
                    result_inchunks[i] = (inshape[i],)  # chunks[i] = (4,)
                # ... and the outermost merged axis is split more finely,
                # presumably to keep the block count roughly unchanged.
                chunk_reduction = reduce(mul, map(len, inchunks[ileft + 1 : ii + 1]))
                result_inchunks[ileft] = expand_tuple(inchunks[ileft], chunk_reduction)
                inner_size = reduce(mul, inshape[ileft + 1 : ii + 1])  # 16
                result_outchunks[oi] = tuple(
                    inner_size * c for c in result_inchunks[ileft]
                )  # (1, 1, 1, 1) .* 16
            oi -= 1
            ii = ileft - 1
        elif din > dout:  # split one input axis: (64,) -> (4, 4, 4)
            oleft = oi - 1
            while oleft >= 0 and reduce(mul, outshape[oleft : oi + 1]) < din:
                oleft -= 1
            if reduce(mul, outshape[oleft : oi + 1]) != din:
                raise ValueError("Shapes not compatible")
            # TODO: don't coalesce shapes unnecessarily
            # Inner output axes oleft+1..oi are kept whole, so every input
            # chunk must be a multiple of their combined size ``cs``.
            cs = reduce(mul, outshape[oleft + 1 : oi + 1])
            result_inchunks[ii] = contract_tuple(inchunks[ii], cs)  # (16, 16, 16, 16)
            for i in range(oleft + 1, oi + 1):
                result_outchunks[i] = (outshape[i],)
            result_outchunks[oleft] = tuple(c // cs for c in result_inchunks[ii])
            oi = oleft - 1
            ii -= 1
    return tuple(result_inchunks), tuple(result_outchunks)
def expand_tuple(chunks, factor):
    """Split every chunk into roughly ``factor`` smaller pieces.

    Each chunk is cut into pieces of about ``chunk / factor`` elements (never
    less than 1); the final piece absorbs any remainder. The total length is
    preserved. When ``factor`` is 1 the input is returned as-is.

    >>> expand_tuple((2, 4), 2)
    (1, 1, 2, 2)
    >>> expand_tuple((2, 4), 3)
    (1, 1, 1, 1, 2)
    >>> expand_tuple((3, 4), 2)
    (1, 2, 2, 2)
    >>> expand_tuple((7, 4), 3)
    (2, 2, 3, 1, 1, 2)
    """
    if factor == 1:
        return chunks

    def split(size):
        # Target piece length as a float; the loop test uses it unrounded.
        target = max(size / factor, 1)
        step = int(target)
        while size >= 2 * target:
            yield step
            size -= step
        if size:
            yield size

    expanded = tuple(piece for size in chunks for piece in split(size))
    assert sum(chunks) == sum(expanded)
    return expanded
def contract_tuple(chunks, factor):
    """Return simple chunks tuple such that factor divides all elements

    Chunks are scanned left to right. Any remainder that is not a multiple of
    ``factor`` is carried into the next chunk, and chunks that end up empty
    are dropped.

    Examples
    --------
    >>> contract_tuple((2, 2, 8, 4), 4)
    (4, 8, 4)
    """
    assert sum(chunks) % factor == 0
    merged = []
    carry = 0
    for size in chunks:
        whole, carry = divmod(size + carry, factor)
        if whole:
            merged.append(factor * whole)
    return tuple(merged)
def reshape(x, shape, merge_chunks=True):
    """Reshape array to new shape

    Parameters
    ----------
    x : Array
        Dask array to reshape.
    shape : int or tuple of ints
        The new shape should be compatible with the original shape. If
        an integer, then the result will be a 1-D array of that length.
        One shape dimension can be -1. In this case, the value is
        inferred from the length of the array and remaining dimensions.
    merge_chunks : bool, default True
        Whether to merge chunks using the logic in :meth:`dask.array.rechunk`
        when communication is necessary given the input array chunking and
        the output shape. With ``merge_chunks==False``, when the reshape
        reduces the number of dimensions, the first ``x.ndim - len(shape)``
        axes of the input array are rechunked to a chunksize of 1, which can
        create very many tasks.

    Returns
    -------
    Array
        The reshaped array. ``x`` itself is returned when its shape already
        equals ``shape``, or for ``reshape(-1)`` of a 1-D array.

    Raises
    ------
    ValueError
        If more than one dimension is -1, if the shape of ``x`` is unknown
        (NaN), if the -1 dimension cannot be inferred, or if the total size
        would change.

    Notes
    -----
    This is a parallelized version of the ``np.reshape`` function with the
    following limitations:

    1. It assumes that the array is stored in `row-major order`_
    2. It only allows for reshapings that collapse or merge dimensions like
       ``(1, 2, 3, 4) -> (1, 6, 4)`` or ``(64,) -> (4, 4, 4)``

    .. _`row-major order`: https://en.wikipedia.org/wiki/Row-_and_column-major_order

    When communication is necessary this algorithm depends on the logic within
    rechunk. It endeavors to keep chunk sizes roughly the same when possible.

    See :ref:`array-chunks.reshaping` for a discussion of the tradeoffs of
    ``merge_chunks``.

    See Also
    --------
    dask.array.rechunk
    numpy.reshape
    """
    # Sanitize inputs, look for -1 in shape
    from .slicing import sanitize_index

    # An integer shape is documented as meaning a 1-D result; wrap it so the
    # ``map`` below receives an iterable.
    if isinstance(shape, (int, np.integer)):
        shape = (shape,)
    shape = tuple(map(sanitize_index, shape))

    known_sizes = [s for s in shape if s != -1]
    n_unknown = len(shape) - len(known_sizes)
    if n_unknown > 1:
        raise ValueError("can only specify one unknown dimension")
    # Fastpath for x.reshape(-1) on 1D arrays, allows unknown shape in x
    # for this case only.
    if n_unknown and len(shape) == 1 and x.ndim == 1:
        return x

    # Every remaining path needs concrete sizes. Check before inferring -1 so
    # an unknown shape reports this message rather than a NaN conversion
    # error. ``x.shape`` is wrapped in a 1-tuple so %-formatting works for
    # any number of dimensions.
    if np.isnan(sum(x.shape)):
        raise ValueError(
            "Array chunk size or shape is unknown. shape: %s\n\n"
            "Possible solution with x.compute_chunk_sizes()" % (x.shape,)
        )

    if n_unknown:
        known_product = reduce(mul, known_sizes, 1)
        if known_product == 0:
            raise ValueError(
                "cannot infer the size of the -1 dimension when another "
                "dimension has size 0"
            )
        # Integer division stays exact for very large sizes; a size that does
        # not divide evenly is rejected by the total-size check below.
        missing_size = sanitize_index(x.size // known_product)
        shape = tuple(missing_size if s == -1 else s for s in shape)

    if reduce(mul, shape, 1) != x.size:
        raise ValueError("total size of new array must be unchanged")

    if x.shape == shape:
        return x

    meta = meta_from_array(x, len(shape))
    name = "reshape-" + tokenize(x, shape)

    if x.npartitions == 1:
        # A single block is reshaped directly; no rechunking needed.
        key = next(flatten(x.__dask_keys__()))
        dsk = {(name,) + (0,) * len(shape): (M.reshape, key, shape)}
        chunks = tuple((d,) for d in shape)
        graph = HighLevelGraph.from_collections(name, dsk, dependencies=[x])
        return Array(graph, name, chunks, meta=meta)

    # Logic for how to rechunk
    din = len(x.shape)
    dout = len(shape)
    if not merge_chunks and din > dout:
        # Chunksize 1 on the leading axes lets reshape_rechunk move whole
        # blocks around instead of merging chunks.
        x = x.rechunk({i: 1 for i in range(din - dout)})

    inchunks, outchunks = reshape_rechunk(x.shape, shape, x.chunks)
    x2 = x.rechunk(inchunks)

    # Construct graph: the k-th input block (row-major order) is reshaped
    # into the k-th output block.
    in_keys = list(product([x2.name], *[range(len(c)) for c in inchunks]))
    out_keys = list(product([name], *[range(len(c)) for c in outchunks]))
    block_shapes = list(product(*outchunks))
    dsk = {
        out_key: (M.reshape, in_key, block_shape)
        for out_key, in_key, block_shape in zip(out_keys, in_keys, block_shapes)
    }
    graph = HighLevelGraph.from_collections(name, dsk, dependencies=[x2])
    return Array(graph, name, outchunks, meta=meta)