import math
import multiprocessing
import os
import shutil
import tempfile
import threading
import uuid
import warnings
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as _importlib_version
from pathlib import Path
import numpy as np
import packaging.version
import large_image
from large_image.cache_util import LruCacheMetaclass, methodcache
from large_image.constants import NEW_IMAGE_PATH_FLAG, TILE_FORMAT_NUMPY, SourcePriority
from large_image.exceptions import TileSourceError, TileSourceFileNotFoundError
from large_image.tilesource import FileTileSource
from large_image.tilesource.resample import ResampleMethod, downsampleTileHalfRes
from large_image.tilesource.utilities import _imageToNumpy, nearPowerOfTwo
try:
__version__ = _importlib_version(__name__)
except PackageNotFoundError:
# package is not installed
pass
zarr = None
def _lazyImport():
"""
Import the zarr module. This is done when needed rather than in the module
initialization because it is slow.
"""
global zarr
if zarr is None:
try:
import zarr
warnings.filterwarnings('ignore', category=FutureWarning, module='.*zarr.*')
except ImportError:
msg = 'zarr module not found.'
raise TileSourceError(msg)
[docs]
class ZarrFileTileSource(FileTileSource, metaclass=LruCacheMetaclass):
"""
Provides tile access to files that the zarr library can read.
"""
cacheName = 'tilesource'
name = 'zarr'
extensions = {
None: SourcePriority.LOW,
'zarr': SourcePriority.PREFERRED,
'zgroup': SourcePriority.PREFERRED,
'zattrs': SourcePriority.PREFERRED,
'zarray': SourcePriority.PREFERRED,
'db': SourcePriority.MEDIUM,
'zip': SourcePriority.LOWER,
}
mimeTypes = {
None: SourcePriority.FALLBACK,
'application/zip+zarr': SourcePriority.PREFERRED,
'application/vnd+zarr': SourcePriority.PREFERRED,
'application/x-zarr': SourcePriority.PREFERRED,
}
newPriority = SourcePriority.HIGH
_tileSize = 512
_minTileSize = 128
_maxTileSize = 1024
_minAssociatedImageSize = 64
_maxAssociatedImageSize = 8192
def __init__(self, path, **kwargs):
"""
Initialize the tile class. See the base class for other available
parameters.
:param path: a filesystem path for the tile source.
"""
super().__init__(path, **kwargs)
_lazyImport()
if str(path).startswith(NEW_IMAGE_PATH_FLAG):
self._initNew(path, **kwargs)
else:
self._initOpen(**kwargs)
self._tileLock = threading.RLock()
def _initOpen(self, **kwargs):
self._largeImagePath = str(self._getLargeImagePath())
self._zarr = None
self._editable = False
self._frameValues = None
self._frameAxes = None
self._frameUnits = None
if not os.path.isfile(self._largeImagePath) and '//:' not in self._largeImagePath:
raise TileSourceFileNotFoundError(self._largeImagePath) from None
try:
self._zarr = zarr.open(zarr.SQLiteStore(self._largeImagePath), mode='r')
except Exception:
try:
self._zarr = zarr.open(self._largeImagePath, mode='r')
except Exception:
if os.path.basename(self._largeImagePath) in {'.zgroup', '.zattrs', '.zarray'}:
try:
self._zarr = zarr.open(os.path.dirname(self._largeImagePath), mode='r')
except Exception:
pass
if self._zarr is None:
if not os.path.isfile(self._largeImagePath):
raise TileSourceFileNotFoundError(self._largeImagePath) from None
msg = 'File cannot be opened via zarr.'
raise TileSourceError(msg)
try:
self._validateZarr()
except TileSourceError:
raise
except Exception:
msg = 'File cannot be opened -- not an OME NGFF file or understandable zarr file.'
raise TileSourceError(msg)
def _initNew(self, path, **kwargs):
"""
Initialize the tile class for creating a new image.
"""
self._tempdir = Path(tempfile.gettempdir(), path)
self._created = False
if not self._tempdir.exists():
self._created = True
self._zarr_store = zarr.DirectoryStore(str(self._tempdir))
self._zarr = zarr.open(self._zarr_store, mode='a')
self._largeImagePath = None
self._dims = {}
self.sizeX = self.sizeY = self.levels = 0
self.tileWidth = self.tileHeight = self._tileSize
self._cacheValue = str(uuid.uuid4())
self._output = None
self._editable = True
self._bandRanges = None
self._threadLock = threading.RLock()
self._processLock = multiprocessing.Lock()
self._framecount = 0
self._mm_x = 0
self._mm_y = 0
self._channelNames = []
self._channelColors = []
self._imageDescription = None
self._levels = []
self._associatedImages = {}
self._frameValues = None
self._frameAxes = None
self._frameUnits = None
def __del__(self):
if not hasattr(self, '_derivedSource'):
try:
self._zarr.close()
except Exception:
pass
try:
if self._created:
shutil.rmtree(self._tempdir)
except Exception:
pass
def _checkEditable(self):
"""
Raise an exception if this is not an editable image.
"""
if not self._editable:
msg = 'Not an editable image'
raise TileSourceError(msg)
def _getGeneralAxes(self, arr):
"""
Examine a zarr array an guess what the axes are. We assume the two
maximal dimensions are y, x. Then, if there is a dimension that is 3
or 4 in length, it is channels. If there is more than one other that
is not 1, we don't know how it is sorted, so we will fail.
:param arr: a zarr array.
:return: a dictionary of axes with the axis as the key and the index
within the array axes as the value.
"""
shape = arr.shape
maxIndex = shape.index(max(shape))
secondMaxIndex = shape.index(max(x for idx, x in enumerate(shape) if idx != maxIndex))
axes = {
'x': max(maxIndex, secondMaxIndex),
'y': min(maxIndex, secondMaxIndex),
}
for idx, val in enumerate(shape):
if idx not in axes.values() and val == 4 and 'c' not in axes:
axes['c'] = idx
if idx not in axes.values() and val == 3:
axes['c'] = idx
for idx, val in enumerate(shape):
if idx not in axes.values() and val > 1:
if 'f' in axes:
msg = 'Too many large axes'
raise TileSourceError(msg)
axes['f'] = idx
return axes
def _scanZarrArray(self, group, arr, results):
"""
Scan a zarr array and determine if is the maximal dimension array we
can read. If so, update a results dictionary with the information. If
it is the shape as a previous maximum, append it as a multi-series set.
If not, and it is small enough, add it to a list of possible associated
images.
:param group: a zarr group; the parent of the array.
:param arr: a zarr array.
:param results: a dictionary to store the results. 'best' contains a
tuple that is used to find the maximum size array, preferring ome
arrays, then total pixels, then channels. 'is_ome' is a boolean.
'series' is a list of the found groups and arrays that match the
best criteria. 'axes', 'axes_values', 'axes_units', and 'channels'
are from the best array. 'associated' is a list of all groups and
arrays that might be associated images. These have to be culled
for the actual groups used in the series.
"""
attrs = group.attrs.asdict() if group is not None else {}
min_version = packaging.version.Version('0.4')
is_ome = (
isinstance(attrs.get('multiscales', None), list) and
'omero' in attrs and
isinstance(attrs['omero'], dict) and
all(isinstance(m, dict) for m in attrs['multiscales']) and
all(packaging.version.Version(m['version']) >= min_version
for m in attrs['multiscales'] if 'version' in m))
channels = None
axes_values = None
axes_units = None
if is_ome:
axes = {axis['name']: idx for idx, axis in enumerate(
attrs['multiscales'][0]['axes'])}
axes_values = {
axis['name']: axis.get('values')
for axis in attrs['multiscales'][0]['axes']
}
axes_units = {
axis['name']: axis.get('unit')
for axis in attrs['multiscales'][0]['axes']
}
if isinstance(attrs['omero'].get('channels'), list):
channels = [channel['label'] for channel in attrs['omero']['channels']]
if all(channel.startswith('Channel ') for channel in channels):
channels = None
else:
try:
axes = self._getGeneralAxes(arr)
except TileSourceError:
return
if 'x' not in axes or 'y' not in axes:
return
check = (is_ome, math.prod(arr.shape), channels is not None,
tuple(axes.keys()), tuple(channels) if channels else ())
if results['best'] is None or check > results['best']:
results['best'] = check
results['series'] = [(group, arr)]
results['is_ome'] = is_ome
results['axes'] = axes
results['axes_values'] = axes_values
results['axes_units'] = axes_units
results['channels'] = channels
elif check == results['best']:
results['series'].append((group, arr))
if not any(group is g for g, _ in results['associated']):
axes = {k: v for k, v in axes.items() if arr.shape[axes[k]] > 1}
if (len(axes) <= 3 and
self._minAssociatedImageSize <= arr.shape[axes['x']] <=
self._maxAssociatedImageSize and
self._minAssociatedImageSize <= arr.shape[axes['y']] <=
self._maxAssociatedImageSize and
(len(axes) == 2 or ('c' in axes and arr.shape[axes['c']] in {1, 3, 4}))):
results['associated'].append((group, arr))
def _scanZarrGroup(self, group, results=None):
"""
Scan a zarr group for usable arrays.
:param group: a zarr group
:param results: a results dicitionary, updated.
:returns: the results dictionary.
"""
if results is None:
results = {'best': None, 'series': [], 'associated': []}
if isinstance(group, zarr.core.Array):
self._scanZarrArray(None, group, results)
return results
for val in group.values():
if isinstance(val, zarr.core.Array):
self._scanZarrArray(group, val, results)
elif isinstance(val, zarr.hierarchy.Group):
results = self._scanZarrGroup(val, results)
return results
def _zarrFindLevels(self):
"""
Find usable multi-level images. This checks that arrays are nearly a
power of two. This updates self._levels and self._populatedLevels.
self._levels is an array the same length as the number of series, each
entry of which is an array of the number of conceptual tile levels
where each entry of that is either the zarr array that can be used to
get pixels or None if it is not populated.
"""
levels = [[None] * self.levels for _ in self._series]
baseGroup, baseArray = self._series[0]
for idx, (_, arr) in enumerate(self._series):
levels[idx][0] = arr
arrs = [[arr for _, arr in s.arrays()] if s is not None else [a] for s, a in self._series]
for idx, arr in enumerate(arrs[0]):
if any(idx >= len(sarrs) for sarrs in arrs[1:]):
break
if any(arr.shape != sarrs[idx].shape for sarrs in arrs[1:]):
continue
if (nearPowerOfTwo(self.sizeX, arr.shape[self._axes['x']]) and
nearPowerOfTwo(self.sizeY, arr.shape[self._axes['y']])):
level = int(round(math.log(self.sizeX / arr.shape[self._axes['x']]) / math.log(2)))
if level < self.levels and levels[0][level] is None:
for sidx in range(len(self._series)):
levels[sidx][level] = arrs[sidx][idx]
self._levels = levels
self._populatedLevels = len([l for l in self._levels[0] if l is not None])
# TODO: check for inefficient file and raise warning
def _getScale(self):
"""
Get the scale values from the ome metadata and populate the class
values for _mm_x and _mm_y.
"""
unit = {'micrometer': 1e-3, 'millimeter': 1, 'meter': 1e3}
self._mm_x = self._mm_y = None
baseGroup, baseArray = self._series[0]
try:
ms = baseGroup.attrs.asdict()['multiscales'][0]
self._mm_x = ms['datasets'][0]['coordinateTransformations'][0][
'scale'][self._axes['x']] * unit[ms['axes'][self._axes['x']]['unit']]
self._mm_y = ms['datasets'][0]['coordinateTransformations'][0][
'scale'][self._axes['y']] * unit[ms['axes'][self._axes['y']]['unit']]
except Exception:
pass
def _readFrameValues(self, found, baseArray):
"""
Populate frame values and frame units from image metadata.
"""
axes_values = found.get('axes_values')
axes_units = found.get('axes_units')
if isinstance(axes_values, dict):
self._frameAxes = [
a for a, i in sorted(self._axes.items(), key=lambda x: x[1])
if axes_values.get(a) is not None
]
self._frameUnits = {k: axes_units.get(k) for k in self.frameAxes if k in axes_units}
frame_values_shape = [baseArray.shape[self._axes[a]] for a in self.frameAxes]
frame_values_shape.append(len(frame_values_shape))
frame_values = np.empty(frame_values_shape, dtype=object)
all_frame_specs = self.getMetadata().get('frames')
for axis, values in axes_values.items():
if axis in self.frameAxes:
slicing = [slice(None) for i in range(len(frame_values_shape))]
axis_index = self.frameAxes.index(axis)
slicing[-1] = axis_index
if len(values) == frame_values_shape[axis_index]:
# uniform values have same length as axis
for i, value in enumerate(values):
slicing[axis_index] = i
frame_values[tuple(slicing)] = value
elif len(values) == self._framecount:
# non-uniform values have a value for every frame
for i, value in enumerate(values):
if all_frame_specs and len(all_frame_specs) > i:
for k, j in all_frame_specs[i].items():
if 'Index' in k:
name = k.replace('Index', '').lower()
if name:
slicing[self._frameAxes.index(name)] = j
frame_values[tuple(slicing)] = value
self._frameValues = frame_values
def _validateZarr(self):
"""
Validate that we can read tiles from the zarr parent group in
self._zarr. Set up the appropriate class variables.
"""
if self._editable and hasattr(self, '_axes'):
self._writeInternalMetadata()
found = self._scanZarrGroup(self._zarr)
if found['best'] is None:
msg = 'No data array that can be used.'
raise TileSourceError(msg)
self._series = found['series']
baseGroup, baseArray = self._series[0]
self._is_ome = found['is_ome']
self._axes = {
k.lower(): v
for k, v in found['axes'].items()
if baseArray.shape[v] > 1 or k in found.get('axes_values', {})
}
if len(self._series) > 1 and 'xy' in self._axes:
msg = 'Conflicting xy axis data.'
raise TileSourceError(msg)
self._channels = found['channels']
self._associatedImages = {
g.name.replace('/', ''): (g, a)
for g, a in found['associated'] if not any(g is gb for gb, _ in self._series)
}
self.sizeX = baseArray.shape[self._axes['x']]
self.sizeY = baseArray.shape[self._axes['y']]
self.tileWidth = (
baseArray.chunks[self._axes['x']]
if self._minTileSize <= baseArray.chunks[self._axes['x']] <= self._maxTileSize else
self._tileSize)
self.tileHeight = (
baseArray.chunks[self._axes['y']]
if self._minTileSize <= baseArray.chunks[self._axes['y']] <= self._maxTileSize else
self._tileSize)
# If we wanted to require equal tile width and height:
# self.tileWidth = self.tileHeight = self._tileSize
# if (baseArray.chunks[self._axes['x']] == baseArray.chunks[self._axes['y']] and
# self._minTileSize <= baseArray.chunks[self._axes['x']] <= self._maxTileSize):
# self.tileWidth = self.tileHeight = baseArray.chunks[self._axes['x']]
self.levels = int(max(1, math.ceil(math.log(max(
self.sizeX / self.tileWidth, self.sizeY / self.tileHeight)) / math.log(2)) + 1))
self._dtype = np.dtype(baseArray.dtype)
self._bandCount = 1
if ('c' in self._axes and 's' not in self._axes and not self._channels and
baseArray.shape[self._axes.get('c')] in {1, 3, 4}):
self._bandCount = baseArray.shape[self._axes['c']]
self._axes['s'] = self._axes.pop('c')
elif 's' in self._axes:
self._bandCount = baseArray.shape[self._axes['s']]
self._zarrFindLevels()
self._getScale()
stride = 1
self._strides = {}
self._axisCounts = {}
for _, k in sorted(
(-self._axes.get(k, 'tzc'.index(k) if k in 'tzc' else -1), k)
for k in self._axes if k not in 'xys'
):
self._strides[k] = stride
self._axisCounts[k] = baseArray.shape[self._axes[k]]
stride *= baseArray.shape[self._axes[k]]
if len(self._series) > 1:
self._strides['xy'] = stride
self._axisCounts['xy'] = len(self._series)
stride *= len(self._series)
self._framecount = stride
self._readFrameValues(found, baseArray)
def _nonemptyLevelsList(self, frame=0):
"""
Return a list of one value per level where the value is None if the
level does not exist in the file and any other value if it does.
:param frame: the frame number.
:returns: a list of levels length.
"""
return [True if l is not None else None for l in self._levels[0][::-1]]
[docs]
def getNativeMagnification(self):
"""
Get the magnification at a particular level.
:return: magnification, width of a pixel in mm, height of a pixel in mm.
"""
mm_x = self._mm_x
mm_y = self._mm_y
# Estimate the magnification; we don't have a direct value
mag = 0.01 / mm_x if mm_x else None
return {
'magnification': getattr(self, '_magnification', mag),
'mm_x': mm_x,
'mm_y': mm_y,
}
[docs]
def getState(self):
# Use the _cacheValue to avoid caching the source and tiles if we are
# creating something new.
if not hasattr(self, '_cacheValue'):
return super().getState()
return super().getState() + ',%s' % (self._cacheValue, )
[docs]
def getAssociatedImagesList(self):
"""
Get a list of all associated images.
:return: the list of image keys.
"""
return list(self._associatedImages.keys())
def _getAssociatedImage(self, imageKey):
"""
Get an associated image in PIL format.
:param imageKey: the key of the associated image.
:return: the image in PIL format or None.
"""
if imageKey not in self._associatedImages:
return
group, arr = self._associatedImages[imageKey]
axes = self._getGeneralAxes(arr)
trans = [idx for idx in range(len(arr.shape))
if idx not in axes.values()] + [axes['y'], axes['x']]
if 'c' in axes or 's' in axes:
trans.append(axes.get('c', axes.get('s')))
with self._tileLock:
img = np.transpose(arr, trans).squeeze()
if len(img.shape) == 2:
img.expand_dims(axis=2)
return large_image.tilesource.base._imageToPIL(img)
[docs]
@methodcache()
def getTile(self, x, y, z, pilImageAllowed=False, numpyAllowed=False, **kwargs):
if self._levels is None:
self._validateZarr()
frame = self._getFrame(**kwargs)
self._xyzInRange(x, y, z, frame, self._framecount)
x0, y0, x1, y1, step = self._xyzToCorners(x, y, z)
sidx = 0 if len(self._series) <= 1 else frame // self._strides['xy']
targlevel = self.levels - 1 - z
while targlevel and self._levels[sidx][targlevel] is None:
targlevel -= 1
arr = self._levels[sidx][targlevel]
scale = int(2 ** targlevel)
x0 //= scale
y0 //= scale
x1 //= scale
y1 //= scale
step //= scale
if step > 2 ** self._maxSkippedLevels:
tile, _format = self._getTileFromEmptyLevel(x, y, z, **kwargs)
tile = large_image.tilesource.base._imageToNumpy(tile)[0]
else:
idx = [slice(None) for _ in arr.shape]
idx[self._axes['x']] = slice(x0, x1, step)
idx[self._axes['y']] = slice(y0, y1, step)
for key in self._axes:
if key in self._strides:
pos = (frame // self._strides[key]) % self._axisCounts[key]
idx[self._axes[key]] = slice(pos, pos + 1)
trans = [idx for idx in range(len(arr.shape))
if idx not in {self._axes['x'], self._axes['y'],
self._axes.get('s', self._axes['x'])}]
squeezeCount = len(trans)
trans += [self._axes['y'], self._axes['x']]
if 's' in self._axes:
trans.append(self._axes['s'])
with self._tileLock:
tile = arr[tuple(idx)]
tile = np.transpose(tile, trans)
for _ in range(squeezeCount):
tile = tile.squeeze(0)
if len(tile.shape) == 2:
tile = np.expand_dims(tile, axis=2)
return self._outputTile(tile, TILE_FORMAT_NUMPY, x, y, z,
pilImageAllowed, numpyAllowed, **kwargs)
def _validateNewTile(self, tile, mask, placement, axes):
if not isinstance(tile, np.ndarray) or axes is None:
axes = 'yxs'
tile, mode = _imageToNumpy(tile)
elif not isinstance(axes, str) and not isinstance(axes, list):
err = 'Invalid type for axes. Must be str or list[str].'
raise ValueError(err)
axes = [x.lower() for x in axes]
if axes[-1] != 's':
axes.append('s')
if mask is not None and len(axes) - 1 == len(mask.shape):
mask = mask[:, :, np.newaxis]
if 'x' not in axes or 'y' not in axes:
err = 'Invalid value for axes. Must contain "y" and "x".'
raise ValueError(err)
for k in placement:
if k not in axes:
axes[0:0] = [k]
while len(tile.shape) < len(axes):
tile = np.expand_dims(tile, axis=0)
while mask is not None and len(mask.shape) < len(axes):
mask = np.expand_dims(mask, axis=0)
return tile, mask, placement, axes
[docs]
def addTile(self, tile, x=0, y=0, mask=None, axes=None, **kwargs):
"""
Add a numpy or image tile to the image, expanding the image as needed
to accommodate it. Note that x and y can be negative. If so, the
output image (and internal memory access of the image) will act as if
the 0, 0 point is the most negative position. Cropping is applied
after this offset.
:param tile: a numpy array, PIL Image, or a binary string
with an image. The numpy array can have 2 or 3 dimensions.
:param x: location in destination for upper-left corner.
:param y: location in destination for upper-left corner.
:param mask: a 2-d numpy array (or 3-d if the last dimension is 1).
If specified, areas where the mask is false will not be altered.
:param axes: a string or list of strings specifying the names of axes
in the same order as the tile dimensions.
:param kwargs: start locations for any additional axes. Note that
``level`` is a reserved word and not permitted for an axis name.
"""
self._checkEditable()
store_path = str(kwargs.pop('level', 0))
placement = {
'x': x,
'y': y,
**{k: v for k, v in kwargs.items() if not k.endswith('_value')},
}
frame_values = {
k.replace('_value', ''): v for k, v in kwargs.items()
if k not in placement
}
tile, mask, placement, axes = self._validateNewTile(tile, mask, placement, axes)
with self._threadLock and self._processLock:
self._axes = {k: i for i, k in enumerate(axes)}
new_dims = {
a: max(
self._dims.get(store_path, {}).get(a, 0),
placement.get(a, 0) + tile.shape[i],
)
for a, i in self._axes.items()
}
self._dims[store_path] = new_dims
placement_slices = tuple([
slice(placement.get(a, 0), placement.get(a, 0) + tile.shape[i], 1)
for i, a in enumerate(axes)
])
if len(frame_values.keys()) > 0:
# update self.frameValues
self.frameAxes = [
a for a in axes
if a in frame_values or
(self.frameAxes is not None and a in self.frameAxes)
]
frames_shape = [new_dims[a] for a in self.frameAxes]
frames_shape.append(len(frames_shape))
if self.frameValues is None:
self.frameValues = np.empty(frames_shape, dtype=object)
elif self.frameValues.shape != frames_shape:
self.frameValues = np.pad(
self.frameValues,
[(0, s - self.frameValues.shape[i]) for i, s in enumerate(frames_shape)],
)
current_frame_slice = tuple(placement.get(a) for a in self.frameAxes)
for i, k in enumerate(self.frameAxes):
self.frameValues[(*current_frame_slice, i)] = frame_values.get(k)
current_arrays = dict(self._zarr.arrays())
if store_path == '0':
# if writing to base data, invalidate generated levels
for path in current_arrays:
if path != store_path:
self._zarr_store.rmdir(path)
chunking = None
if store_path not in current_arrays:
arr = np.empty(tuple(new_dims.values()), dtype=tile.dtype)
chunking = tuple([
self._tileSize if a in ['x', 'y'] else
new_dims.get('s') if a == 's' else 1
for a in axes
])
else:
arr = current_arrays[store_path]
new_shape = tuple(max(v, arr.shape[i]) for i, v in enumerate(new_dims.values()))
if new_shape != arr.shape:
arr.resize(*new_shape)
if arr.chunks[-1] != new_dims.get('s'):
# rechunk if length of samples axis changes
chunking = tuple([
self._tileSize if a in ['x', 'y'] else
new_dims.get('s') if a == 's' else 1
for a in axes
])
if mask is not None:
arr[placement_slices] = np.where(mask, tile, arr[placement_slices])
else:
arr[placement_slices] = tile
if chunking:
zarr.array(
arr,
chunks=chunking,
overwrite=True,
store=self._zarr_store,
path=store_path,
)
# If base data changed, update large_image attributes
if store_path == '0':
self._dtype = np.dtype(tile.dtype)
self._bandCount = new_dims.get(axes[-1]) # last axis is assumed to be bands
self.sizeX = new_dims.get('x')
self.sizeY = new_dims.get('y')
self._framecount = math.prod([
length
for axis, length in new_dims.items()
if axis in axes[:-3]
])
self._cacheValue = str(uuid.uuid4())
self._levels = None
self.levels = int(max(1, math.ceil(math.log(max(
self.sizeX / self.tileWidth, self.sizeY / self.tileHeight)) / math.log(2)) + 1))
[docs]
def addAssociatedImage(self, image, imageKey=None):
"""
Add an associated image to this source.
:param image: a numpy array, PIL Image, or a binary string
with an image. The numpy array can have 2 or 3 dimensions.
"""
data, _ = _imageToNumpy(image)
with self._threadLock and self._processLock:
if imageKey is None:
# Each associated image should be in its own group
num_existing = len(self.getAssociatedImagesList())
imageKey = f'image_{num_existing + 1}'
group = self._zarr.require_group(imageKey)
arr = zarr.array(
data,
store=self._zarr_store,
path=f'{imageKey}/image',
)
self._associatedImages[imageKey] = (group, arr)
def _getAxisInternalMetadata(self, axis_name):
"""
Get the metadata structure describing an axis in the image.
This will be written to the image metadata.
:param axis_name: a string corresponding to an axis in the image
:returns: a dictionary describing the target axis
"""
axis_metadata = {'name': axis_name}
if axis_name in ['x', 'y']:
axis_metadata['type'] = 'space'
axis_metadata['unit'] = 'millimeter'
elif axis_name in ['s', 'c']:
axis_metadata['type'] = 'channel'
if self.frameAxes is not None:
axis_index = self.frameAxes.index(axis_name) if axis_name in self.frameAxes else None
if axis_index is not None and self.frameValues is not None:
all_frame_values = self.frameValues[..., axis_index]
split = np.split(
all_frame_values,
all_frame_values.shape[axis_index],
axis=axis_index,
)
uniform = all(len(np.unique(
# convert all values to strings for mixed type comparison with np.unique
(a[np.not_equal(a, None)]).astype(str),
)) == 1 for a in split)
if uniform:
values = [a.flat[0] for a in split]
else:
values = all_frame_values.flatten().tolist()
axis_metadata['values'] = values
unit = self.frameUnits.get(axis_name) if self.frameUnits is not None else None
if unit is not None:
axis_metadata['unit'] = unit
return axis_metadata
def _writeInternalMetadata(self):
"""
Write metadata to Zarr attributes.
Metadata structure adheres to OME schema from https://ngff.openmicroscopy.org/latest/
"""
self._checkEditable()
with self._threadLock and self._processLock:
name = str(self._tempdir.name).split('/')[-1]
arrays = dict(self._zarr.arrays())
channel_axis = self._axes.get('c', self._axes.get('s'))
datasets = []
axes = []
channels = []
rdefs = {'model': 'color' if len(self._channelColors) else 'greyscale'}
sorted_axes = [a[0] for a in sorted(self._axes.items(), key=lambda item: item[1])]
for arr_name in arrays:
level = int(arr_name)
scale = [1.0 for a in sorted_axes]
scale[self._axes.get('x')] = (self._mm_x or 0) * (2 ** level)
scale[self._axes.get('y')] = (self._mm_y or 0) * (2 ** level)
dataset_metadata = {
'path': arr_name,
'coordinateTransformations': [{
'type': 'scale',
'scale': scale,
}],
}
datasets.append(dataset_metadata)
for a in sorted_axes:
if a == 't':
rdefs['defaultT'] = 0
elif a == 'z':
rdefs['defaultZ'] = 0
axes.append(self._getAxisInternalMetadata(a))
if channel_axis is not None and len(arrays) > 0:
base_array = list(arrays.values())[0]
base_shape = base_array.shape
for c in range(base_shape[channel_axis]):
channel_metadata = {
'active': True,
'coefficient': 1,
'color': 'FFFFFF',
'family': 'linear',
'inverted': False,
'label': f'Band {c + 1}',
}
# slicing = tuple(
# slice(None)
# if k != ('c' if 'c' in self._axes else 's')
# else c
# for k, v in self._axes.items()
# )
# channel_data = base_array[slicing]
# channel_min = np.min(channel_data)
# channel_max = np.max(channel_data)
# channel_metadata['window'] = {
# 'end': channel_max,
# 'max': channel_max,
# 'min': channel_min,
# 'start': channel_min,
# }
if len(self._channelNames) > c:
channel_metadata['label'] = self._channelNames[c]
if len(self._channelColors) > c:
channel_metadata['color'] = self._channelColors[c]
channels.append(channel_metadata)
# Guidelines from https://ngff.openmicroscopy.org/latest/
self._zarr.attrs.update({
'multiscales': [{
'version': '0.5',
'name': name,
'axes': axes,
'datasets': datasets,
'metadata': {
'description': self._imageDescription or '',
'kwargs': {
'multichannel': (channel_axis is not None),
},
},
}],
'omero': {
'id': 1,
'version': '0.5',
'name': name,
'channels': channels,
'rdefs': rdefs,
},
'bioformats2raw.layout': 3,
})
@property
def crop(self):
"""
Crop only applies to the output file, not the internal data access.
It consists of x, y, w, h in pixels.
"""
return getattr(self, '_crop', None)
@crop.setter
def crop(self, value):
self._checkEditable()
if value is None:
self._crop = None
return
x, y, w, h = value
x = int(x)
y = int(y)
w = int(w)
h = int(h)
if x < 0 or y < 0 or w <= 0 or h <= 0:
msg = 'Crop must have non-negative x, y and positive w, h'
raise TileSourceError(msg)
self._crop = (x, y, w, h)
@property
def mm_x(self):
return self._mm_x
@mm_x.setter
def mm_x(self, value):
self._checkEditable()
value = float(value) if value is not None else None
if value is not None and value <= 0:
msg = 'mm_x must be positive or None'
raise TileSourceError(msg)
self._mm_x = value
@property
def mm_y(self):
return self._mm_y
@mm_y.setter
def mm_y(self, value):
self._checkEditable()
value = float(value) if value is not None else None
if value is not None and value <= 0:
msg = 'mm_y must be positive or None'
raise TileSourceError(msg)
self._mm_y = value
@property
def imageDescription(self):
return self._imageDescription
@imageDescription.setter
def imageDescription(self, description):
self._checkEditable()
self._imageDescription = description
@property
def channelNames(self):
return self._channelNames
@channelNames.setter
def channelNames(self, names):
self._checkEditable()
self._channelNames = names
@property
def channelColors(self):
return self._channelColors
@channelColors.setter
def channelColors(self, colors):
self._checkEditable()
self._channelColors = colors
@property
def frameAxes(self):
return self._frameAxes
@frameAxes.setter
def frameAxes(self, axes):
self._checkEditable()
self._frameAxes = axes
@property
def frameUnits(self):
return self._frameUnits
@frameUnits.setter
def frameUnits(self, units):
self._checkEditable()
if self.frameAxes is None:
err = 'frameAxes must be set first with a list of frame axis names.'
raise ValueError(err)
if not isinstance(units, dict) or not all(
k in self.frameAxes for k in units.keys()
):
err = 'frameUnits must be a dictionary with keys that exist in frameAxes.'
raise ValueError(err)
self._frameUnits = units
@property
def frameValues(self):
return self._frameValues
@frameValues.setter
def frameValues(self, a):
self._checkEditable()
if self.frameAxes is None:
err = 'frameAxes must be set first with a list of frame axis names.'
raise ValueError(err)
if len(a.shape) != len(self.frameAxes) + 1:
err = f'frameValues must have {len(self.frameAxes) + 1} dimensions.'
raise ValueError(err)
self._frameValues = a
def _generateDownsampledLevels(self, resample_method):
self._checkEditable()
current_arrays = dict(self._zarr.arrays())
if '0' not in current_arrays:
msg = 'No root data found, cannot generate lower resolution levels.'
raise TileSourceError(msg)
if 'x' not in self._axes or 'y' not in self._axes:
msg = 'Data must have an X axis and Y axis to generate lower resolution levels.'
raise TileSourceError(msg)
metadata = self.getMetadata()
if (
resample_method.value < ResampleMethod.PIL_MAX_ENUM.value and
resample_method != ResampleMethod.PIL_NEAREST
):
tile_overlap = dict(x=4, y=4, edges=True)
else:
tile_overlap = dict(x=0, y=0)
tile_size = dict(
width=4096 + tile_overlap['x'],
height=4096 + tile_overlap['y'],
)
sorted_axes = [a[0] for a in sorted(self._axes.items(), key=lambda item: item[1])]
for level in range(1, self.levels):
scale_factor = 2 ** level
iterator_output = dict(
maxWidth=self.sizeX // scale_factor,
maxHeight=self.sizeY // scale_factor,
)
for frame in metadata.get('frames', [{'Index': 0}]):
frame_position = {
k.replace('Index', '').lower(): v
for k, v in frame.items()
if k.replace('Index', '').lower() in self._axes
}
for tile in self.tileIterator(
tile_size=tile_size,
tile_overlap=tile_overlap,
frame=frame['Index'],
output=iterator_output,
resample=False, # TODO: incorporate resampling in core
):
new_tile = downsampleTileHalfRes(tile['tile'], resample_method)
overlap = {k: int(v / 2) for k, v in tile['tile_overlap'].items()}
new_tile = new_tile[
slice(overlap['top'], new_tile.shape[0] - overlap['bottom']),
slice(overlap['left'], new_tile.shape[1] - overlap['right']),
]
x = int(tile['x'] / 2 + overlap['left'])
y = int(tile['y'] / 2 + overlap['top'])
self.addTile(
new_tile,
x=x,
y=y,
**frame_position,
axes=sorted_axes,
level=level,
)
self._validateZarr() # refresh self._levels before continuing
[docs]
def write(
self,
path,
lossy=True,
alpha=True,
overwriteAllowed=True,
resample=None,
**converterParams,
):
"""
Output the current image to a file.
:param path: output path.
:param lossy: if false, emit a lossless file.
:param alpha: True if an alpha channel is allowed.
:param overwriteAllowed: if False, raise an exception if the output
path exists.
:param resample: one of the ``ResampleMethod`` enum values. Defaults
to ``NP_NEAREST`` for lossless and non-uint8 data and to
``PIL_LANCZOS`` for lossy uint8 data.
:param converterParams: options to pass to the large_image_converter if
the output is not a zarr variant.
"""
if os.path.exists(path):
if overwriteAllowed:
if os.path.isdir(path):
shutil.rmtree(path)
else:
os.remove(path)
else:
raise TileSourceError('Output path exists (%s).' % str(path))
suffix = Path(path).suffix
source = self
if self.crop:
left, top, width, height = self.crop
source = new()
source._zarr.attrs.update(self._zarr.attrs)
for frame in self.getMetadata().get('frames', [{'Index': 0}]):
frame_position = {
k.replace('Index', '').lower(): v
for k, v in frame.items()
if k.replace('Index', '').lower() in self._axes
}
for tile in self.tileIterator(
frame=frame['Index'],
region=dict(top=top, left=left, width=width, height=height),
resample=False,
):
source.addTile(
tile['tile'],
x=tile['x'] - left,
y=tile['y'] - top,
axes=list(self._axes.keys()),
**frame_position,
)
source._validateZarr()
if suffix in ['.zarr', '.db', '.sqlite', '.zip']:
if resample is None:
resample = (
ResampleMethod.PIL_LANCZOS
if lossy and source.dtype == np.uint8
else ResampleMethod.NP_NEAREST
)
source._generateDownsampledLevels(resample)
source._writeInternalMetadata() # rewrite with new level datasets
if suffix == '.zarr':
shutil.copytree(str(source._tempdir), path)
elif suffix in ['.db', '.sqlite']:
sqlite_store = zarr.SQLiteStore(path)
zarr.copy_store(source._zarr_store, sqlite_store, if_exists='replace')
sqlite_store.close()
elif suffix == '.zip':
zip_store = zarr.ZipStore(path)
zarr.copy_store(source._zarr_store, zip_store, if_exists='replace')
zip_store.close()
else:
from large_image_converter import convert
attrs_path = source._tempdir / '.zattrs'
params = {}
if lossy and self.dtype == np.uint8:
params['compression'] = 'jpeg'
params.update(converterParams)
convert(str(attrs_path), path, overwrite=overwriteAllowed, **params)
[docs]
def open(*args, **kwargs):
"""
Create an instance of the module class.
"""
return ZarrFileTileSource(*args, **kwargs)
[docs]
def canRead(*args, **kwargs):
"""
Check if an input can be read by the module class.
"""
return ZarrFileTileSource.canRead(*args, **kwargs)
[docs]
def new(*args, **kwargs):
"""
Create a new image, collecting the results from patches of numpy arrays or
smaller images.
"""
return ZarrFileTileSource(NEW_IMAGE_PATH_FLAG + str(uuid.uuid4()), *args, **kwargs)