Indexes
The mechanism powering xarray label-based lookups is called an “index”. The most flexible way to extend indexing and selection is to create a custom index. For simpler cases, e.g. “aliasing” dimension coordinates, there are more straightforward approaches.
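Here is a minimal illustration in plain xarray, with no xattree involved and made-up coordinate labels. A dimension coordinate gets a default PandasIndex, and `.sel()` uses that index to turn labels into positions:

```python
import numpy as np
import xarray as xr

# A plain dataset with dimension coordinates "rows" and "cols".
ds = xr.Dataset(
    {"arr": (("rows", "cols"), np.arange(9.0).reshape(3, 3))},
    coords={"rows": [10, 20, 30], "cols": [0, 1, 2]},
)

# xarray attaches a PandasIndex to each dimension coordinate by default.
print({name: type(index).__name__ for name, index in ds.xindexes.items()})

# .sel() resolves the label 20 to position 1 through that index.
row = ds["arr"].sel(rows=20)
print(row.values)  # [3. 4. 5.]
```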
Consider an attrs class describing a 2D structured grid. It has integer fields for the size of each dimension and an array field holding a data variable defined on grid cells. This is a canonical use case for xattree and its dim() and array() field specifiers.
[1]:
import numpy as np
import pandas as pd
import xarray as xr
from numpy.typing import NDArray
[2]:
from xattree import array, dim, xattree
[3]:
@xattree
class Grid:
    rows: int = dim(default=3)
    cols: int = dim(default=3)
    arr: NDArray[np.float64] = array(default=0.0, dims=("rows", "cols"))
[4]:
grid = Grid()
grid.data
[4]:
<xarray.DataTree 'grid'>
Group: /
Dimensions: (rows: 3, cols: 3)
Coordinates:
* rows (rows) int64 24B 0 1 2
* cols (cols) int64 24B 0 1 2
Data variables:
arr (rows, cols) float64 72B 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
Attributes:
host: Grid(rows=3, cols=3, arr=<xarray.DataArray 'arr' (rows: 3, col...
metadata: {'arr': {}, 'rows': {}, 'cols': {}}
The field names are not quite right as coordinate names, though. It is natural for rows and cols to name the dimensions, but we want the coordinates along them to be called i and j.
We might think first to try Dataset.rename():
[5]:
grid.data.dataset.rename({"rows": "i", "cols": "j"})
[5]:
<xarray.Dataset> Size: 120B
Dimensions: (i: 3, j: 3)
Coordinates:
* i (i) int64 24B 0 1 2
* j (j) int64 24B 0 1 2
Data variables:
arr (i, j) float64 72B 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
Attributes:
host: Grid(rows=3, cols=3, arr=<xarray.DataArray 'arr' (rows: 3, col...
metadata: {'arr': {}, 'rows': {}, 'cols': {}}
But this renames not only the coordinates but also the dimensions. Ideally, we want dimensions rows/cols, coordinates i/j.
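For comparison, plain xarray can produce this layout with Dataset.rename_vars(), which renames variables but leaves dimensions alone. The sketch below builds a fresh dataset rather than using grid.data. It assumes a recent xarray release in which renaming an indexed coordinate carries its PandasIndex along.

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"arr": (("rows", "cols"), np.zeros((3, 3)))},
    coords={"rows": [0, 1, 2], "cols": [0, 1, 2]},
)

# rename() renames variables *and* dimensions ...
renamed = ds.rename({"rows": "i", "cols": "j"})
print(dict(renamed.sizes))  # {'i': 3, 'j': 3}

# ... whereas rename_vars() renames only the variables. The existing
# PandasIndex is carried over, now backing coordinate "i" along "rows".
aliased = ds.rename_vars({"rows": "i", "cols": "j"})
print(dict(aliased.sizes))  # {'rows': 3, 'cols': 3}
print(aliased["arr"].sel(i=0).shape)  # (3,)
```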
xattree provides a simple way to achieve this: pass a new name to the coord parameter of dim(). This creates a coordinate variable with the given name but leaves the dimension name unchanged. As with a dimension coordinate, the new coordinate variable gets a PandasIndex, so it supports label-based selection.
[6]:
@xattree
class Grid:
    rows: int = dim(default=3, coord="i")
    cols: int = dim(default=3, coord="j")
    arr: NDArray[np.float64] = array(default=0.0, dims=("rows", "cols"))
[7]:
grid = Grid()
grid.data
[7]:
<xarray.DataTree 'grid'>
Group: /
Dimensions: (rows: 3, cols: 3)
Coordinates:
* i (rows) int64 24B 0 1 2
* j (cols) int64 24B 0 1 2
Dimensions without coordinates: rows, cols
Data variables:
arr (rows, cols) float64 72B 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
Attributes:
host: Grid(rows=3, cols=3, arr=<xarray.DataArray 'arr' (rows: 3, col...
metadata: {'arr': {}, 'rows': {}, 'cols': {}}
Try out the label-based indexing.
[8]:
grid.data.arr.sel(i=0)
[8]:
<xarray.DataArray 'arr' (cols: 3)> Size: 24B
array([0., 0., 0.])
Coordinates:
i int64 8B 0
* j (cols) int64 24B 0 1 2
Dimensions without coordinates: cols
A more general approach is to create a custom index.
As a first step, we can set coord=False on the dim() call, which will prevent xattree from creating a coordinate variable for a dimension field.
[9]:
@xattree
class Grid:
    rows: int = dim(default=3, coord=False)
    cols: int = dim(default=3, coord=False)
    arr: NDArray[np.float64] = array(default=0.0, dims=("rows", "cols"))
Now rows and cols are no longer coordinates, and the dataset has no coordinate variables at all. Their sizes show up among the dataset attributes instead.
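The following plain-xarray sketch (no xattree) shows what this means for selection. Without any index, a label such as i=0 has nothing to resolve against. In contrast, sel() on a bare dimension silently falls back to positional selection. Older xarray versions raise ValueError instead of KeyError for the unknown name, so both are caught.

```python
import numpy as np
import xarray as xr

# Dimensions "rows" and "cols" with no coordinate variables at all,
# mirroring what coord=False produces.
ds = xr.Dataset({"arr": (("rows", "cols"), np.arange(9.0).reshape(3, 3))})
print(len(ds.xindexes))  # 0

# sel() on a dimension without an index falls back to positional selection.
print(ds["arr"].sel(rows=1).values)  # [3. 4. 5.]

# There is no coordinate or index named "i", so the lookup fails.
try:
    ds["arr"].sel(i=0)
    label_lookup_failed = False
except (KeyError, ValueError):
    label_lookup_failed = True
print(label_lookup_failed)  # True
```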
[10]:
grid = Grid()
grid.data
[10]:
<xarray.DataTree 'grid'>
Group: /
Dimensions: (rows: 3, cols: 3)
Dimensions without coordinates: rows, cols
Data variables:
arr (rows, cols) float64 72B 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
Attributes:
host: Grid(rows=3, cols=3, arr=<xarray.DataArray 'arr' (rows: 3, col...
rows: 3
cols: 3
metadata: {'arr': {}, 'rows': {}, 'cols': {}}
To support i and j labels, we can create a custom index class and register it with xattree.
A custom index is a subclass of xarray's Index base class. Its methods let you create, manipulate, and query the index.
For our purposes, an index which simply aliases the rows and cols fields to i and j, respectively, will suffice.
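First, create an aliasing function. Given a dimension name, it builds a PandasIndex over that dimension whose coordinate carries a new name. The size is read from dataset.sizes when some variable spans the dimension. Otherwise it comes from the dataset attributes, which is where rows and cols appear in the output above.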
[11]:
from xarray.core.indexes import Index, PandasIndex
[12]:
def alias(dataset: xr.Dataset, old_name: str, new_name: str) -> PandasIndex:
    """Build a PandasIndex labeling dimension `old_name` with a coordinate named `new_name`."""
    try:
        size = dataset.sizes[old_name]
    except KeyError:
        # No variable spans the dimension; fall back to the size stored in attrs.
        size = dataset.attrs[old_name]
    return PandasIndex(pd.RangeIndex(size, name=new_name), dim=old_name)
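To see what alias() hands to xarray, the sketch below repeats the function so that it runs standalone. It then exercises both branches on two small datasets made up for the example. create_variables() yields the coordinate variable that the index backs, and sel() maps a label on the new coordinate to a position on the original dimension.

```python
import numpy as np
import pandas as pd
import xarray as xr
from xarray.core.indexes import PandasIndex


def alias(dataset: xr.Dataset, old_name: str, new_name: str) -> PandasIndex:
    """Same logic as alias() above, repeated so this block runs on its own."""
    try:
        size = dataset.sizes[old_name]
    except KeyError:
        size = dataset.attrs[old_name]
    return PandasIndex(pd.RangeIndex(size, name=new_name), dim=old_name)


# Case 1: a data variable spans "cols", so the size comes from dataset.sizes.
with_dim = xr.Dataset({"arr": (("rows", "cols"), np.zeros((3, 4)))})
# Case 2: nothing spans "cols"; the size is only recorded in the attributes.
attrs_only = xr.Dataset(attrs={"rows": 3, "cols": 4})

for ds in (with_dim, attrs_only):
    index = alias(ds, "cols", "j")
    var = index.create_variables()["j"]
    print(var.dims, var.values)  # ('cols',) [0 1 2 3]
    # Label 2 on coordinate "j" maps to position 2 on dimension "cols".
    print(index.sel({"j": 2}).dim_indexers)  # {'cols': 2}
```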
Now create a meta-index that combines the two “aliased” 1D indexes into a single 2D index.
[13]:
from xarray.core.indexing import merge_sel_results
[14]:
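class MetaIndex(Index):
    """Wrap several 1D indexes so that xarray treats them as a single index."""

    def __init__(self, indices):
        # Mapping of coordinate name -> the 1D index backing it.
        self._indices = indices

    @classmethod
    def from_variables(cls, variables, *, options):
        # Build one PandasIndex per coordinate and wrap them all in a MetaIndex.
        return cls(
            {
                k: PandasIndex.from_variables({k: v}, options=options)
                for k, v in variables.items()
            }
        )

    def create_variables(self, variables=None):
        # Collect the coordinate variables produced by each wrapped index.
        idx_vars = {}
        for index in self._indices.values():
            idx_vars.update(index.create_variables(variables))
        return idx_vars

    def sel(self, labels, method=None, tolerance=None):
        # Query only the wrapped indexes whose coordinate appears in `labels`,
        # then merge their per-dimension results.
        results = []
        for k, index in self._indices.items():
            if k in labels:
                results.append(index.sel({k: labels[k]}, method=method, tolerance=tolerance))
        return merge_sel_results(results)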
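The merge step in MetaIndex.sel() can be seen in isolation. Each aliased PandasIndex returns a selection result keyed by its own dimension, and merge_sel_results() unions them. It refuses to merge two results that index the same dimension. Note that merge_sel_results lives in xarray's internal xarray.core.indexing module, so its import path may change between releases.

```python
import pandas as pd
from xarray.core.indexes import PandasIndex
from xarray.core.indexing import merge_sel_results

# Two aliased 1D indexes, shaped like the ones alias() builds.
i_index = PandasIndex(pd.RangeIndex(3, name="i"), dim="rows")
j_index = PandasIndex(pd.RangeIndex(4, name="j"), dim="cols")

# Query each sub-index for its own label, then merge the positional indexers.
merged = merge_sel_results([i_index.sel({"i": 2}), j_index.sel({"j": 1})])
print(merged.dim_indexers)  # {'rows': 2, 'cols': 1}

# Two results indexing the same dimension cannot be merged.
try:
    merge_sel_results([i_index.sel({"i": 0}), i_index.sel({"i": 1})])
    conflict_detected = False
except ValueError:
    conflict_detected = True
print(conflict_detected)  # True
```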
Define a function to build an instance of the index from a dataset.
[15]:
def meta_index(dataset: xr.Dataset) -> MetaIndex:
    return MetaIndex(
        {
            "i": alias(dataset, "rows", "i"),
            "j": alias(dataset, "cols", "j"),
        }
    )
Finally, register it with the xattree decorator. The index parameter is a callable that takes a dataset and returns an index.
[16]:
@xattree(index=meta_index)
class Grid:
    rows: int = dim(default=3, coord=False)
    cols: int = dim(default=3, coord=False)
    arr: NDArray[np.float64] = array(default=0.0, dims=("rows", "cols"))
[17]:
grid = Grid()
grid.data
[17]:
<xarray.DataTree 'grid'>
Group: /
Dimensions: (rows: 3, cols: 3)
Coordinates:
* i (rows) int64 24B 0 1 2
* j (cols) int64 24B 0 1 2
Dimensions without coordinates: rows, cols
Data variables:
arr (rows, cols) float64 72B 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
Indexes:
┌ i MetaIndex
└ j
Attributes:
host: Grid(rows=3, cols=3, arr=<xarray.DataArray 'arr' (rows: 3, col...
rows: 3
cols: 3
metadata: {'arr': {}, 'rows': {}, 'cols': {}}
Check the label-based indexing.
[18]:
grid.data.arr.sel(i=0)
[18]:
<xarray.DataArray 'arr' (cols: 3)> Size: 24B
array([0., 0., 0.])
Coordinates:
i int64 8B 0
j (cols) int64 24B 0 1 2
Dimensions without coordinates: cols
Make sure the index is working as expected.
[19]:
assert grid.rows == 3
assert grid.cols == 3
assert grid.data.i.shape == (3,)
assert grid.data.j.shape == (3,)
assert "i" in grid.data.coords
assert "j" in grid.data.coords
assert "rows" not in grid.data.coords
assert "cols" not in grid.data.coords
assert grid.data.arr.shape == (3, 3)
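Outside xattree, the same kind of index can be attached with xarray's own Dataset.set_xindex(). That method builds the index by calling the class's from_variables() classmethod with a keyword-only options argument. The condensed copy of MetaIndex below implements it that way so the block runs standalone. The dataset is made up for the example, and this is a sketch of one alternative, not how xattree attaches its index internally.

```python
import numpy as np
import xarray as xr
from xarray.core.indexes import Index, PandasIndex
from xarray.core.indexing import merge_sel_results


class MetaIndex(Index):
    """Condensed copy of MetaIndex: several 1D indexes exposed as one."""

    def __init__(self, indices):
        self._indices = indices

    @classmethod
    def from_variables(cls, variables, *, options):
        # One PandasIndex per coordinate, wrapped in a single MetaIndex.
        return cls(
            {k: PandasIndex.from_variables({k: v}, options=options) for k, v in variables.items()}
        )

    def create_variables(self, variables=None):
        idx_vars = {}
        for index in self._indices.values():
            idx_vars.update(index.create_variables(variables))
        return idx_vars

    def sel(self, labels, method=None, tolerance=None):
        results = [
            index.sel({k: labels[k]}, method=method, tolerance=tolerance)
            for k, index in self._indices.items()
            if k in labels
        ]
        return merge_sel_results(results)


# Coordinates "i" and "j" live on dimensions "rows" and "cols" but start un-indexed.
ds = xr.Dataset(
    {"arr": (("rows", "cols"), np.arange(12).reshape(3, 4))},
    coords={"i": ("rows", np.arange(3)), "j": ("cols", np.arange(4))},
)
# set_xindex() calls MetaIndex.from_variables(..., options={}) and attaches the result.
ds = ds.set_xindex(["i", "j"], MetaIndex)
print(type(ds.xindexes["i"]).__name__)  # MetaIndex
print(int(ds["arr"].sel(i=1, j=2)))  # 6
```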
Like dimensions and coordinates, an index can be “lifted” to a context above the current class with the xattree decorator’s index_scope parameter.
[20]:
from xattree import ROOT, field
[21]:
@xattree(index=meta_index, index_scope=ROOT)
class Grid:
    rows: int = dim(default=3, coord=False, scope=ROOT)
    cols: int = dim(default=3, coord=False, scope=ROOT)
[22]:
@xattree
class Arrs:
    arr: NDArray[np.float64] = array(default=0.0, dims=("rows", "cols"))
[23]:
@xattree
class Root:
    grid: Grid = field()
    arrs: Arrs = field()
[24]:
grid = Grid()
root = Root(grid=grid)
arrs = Arrs(parent=root)
[25]:
root.data
[25]:
<xarray.DataTree 'root'>
Group: /
│ Dimensions: (rows: 3, cols: 3)
│ Coordinates:
│ * i (rows) int64 24B 0 1 2
│ * j (cols) int64 24B 0 1 2
│ Dimensions without coordinates: rows, cols
│ Indexes:
│ ┌ i MetaIndex
│ └ j
│ Attributes:
│ host: Root(grid=Grid(rows=3, cols=3, name='grid', strict=True), arrs...
│ rows: 3
│ cols: 3
│ metadata: {'grid': {}, 'arrs': {}, 'rows': {}, 'cols': {}}
├── Group: /grid
│ Indexes:
│ ┌ i MetaIndex
│ └ j
│ Attributes:
│ host: Grid(rows=3, cols=3, name='grid', strict=True)
│ rows: 3
│ cols: 3
│ metadata: {'rows': {}, 'cols': {}}
└── Group: /arrs
Dimensions: (rows: 3, cols: 3)
Dimensions without coordinates: rows, cols
Data variables:
arr (rows, cols) float64 72B 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
Indexes:
┌ i MetaIndex
└ j
Attributes:
host: Arrs(arr=<xarray.DataArray 'arr' (rows: 3, cols: 3)> Size: 72B...
metadata: {'arr': {}}
Indexing still works as expected.
[26]:
assert arrs.arr.sel(i=0).shape == (3,)
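Derived dimensions are supported too. Below, nodes is not an __init__ argument. Its size is computed from rows and cols in __attrs_post_init__.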
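The derived-size pattern is not specific to attrs. The sketch below is a stdlib dataclasses analogue: dataclasses uses __post_init__ where attrs uses __attrs_post_init__, and field(init=False) keeps the derived field out of the constructor. GridSizes is a hypothetical name used only for illustration. The sketch then builds the aliased index over the derived dimension the way alias(dataset, "nodes", "n") would.

```python
from dataclasses import dataclass, field

import pandas as pd
from xarray.core.indexes import PandasIndex


@dataclass
class GridSizes:
    """Hypothetical stdlib stand-in for Grid: nodes is derived, not passed in."""

    rows: int = 3
    cols: int = 3
    # init=False keeps "nodes" out of the generated __init__ signature.
    nodes: int = field(init=False)

    def __post_init__(self):
        # dataclasses' hook; attrs' equivalent is __attrs_post_init__.
        self.nodes = self.rows * self.cols


sizes = GridSizes(rows=2, cols=4)
print(sizes.nodes)  # 8

# An aliased index over the derived dimension, named "n" along "nodes".
n_index = PandasIndex(pd.RangeIndex(sizes.nodes, name="n"), dim="nodes")
print(n_index.create_variables()["n"].values)  # [0 1 2 3 4 5 6 7]
```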
[27]:
def meta_index(dataset: xr.Dataset) -> MetaIndex:
    return MetaIndex(
        {
            "i": alias(dataset, "rows", "i"),
            "j": alias(dataset, "cols", "j"),
            "n": alias(dataset, "nodes", "n"),
        }
    )
[28]:
@xattree(index=meta_index, index_scope=ROOT)
class Grid:
    rows: int = dim(default=3, coord=False, scope=ROOT)
    cols: int = dim(default=3, coord=False, scope=ROOT)
    nodes: int = dim(init=False, coord=False, scope=ROOT)

    def __attrs_post_init__(self):
        self.nodes = self.rows * self.cols
[29]:
@xattree
class Arrs:
    arr_a: NDArray[np.float64] = array(default=0.0, dims=("rows", "cols"))
    arr_b: NDArray[np.float64] = array(default=0.0, dims=("nodes",))
[30]:
@xattree
class Root:
    grid: Grid = field()
    arrs: Arrs = field()
[31]:
grid = Grid()
root = Root(grid=grid)
arrs = Arrs(parent=root)
[32]:
root.data
[32]:
<xarray.DataTree 'root'>
Group: /
│ Dimensions: (rows: 3, cols: 3, nodes: 9)
│ Coordinates:
│ * i (rows) int64 24B 0 1 2
│ * j (cols) int64 24B 0 1 2
│ * n (nodes) int64 72B 0 1 2 3 4 5 6 7 8
│ Dimensions without coordinates: rows, cols, nodes
│ Indexes:
│ ┌ i MetaIndex
│ │ j
│ └ n
│ Attributes:
│ host: Root(grid=Grid(rows=3, cols=3, nodes=9, name='grid', strict=Tr...
│ rows: 3
│ cols: 3
│ nodes: 9
│ metadata: {'grid': {}, 'arrs': {}, 'rows': {}, 'cols': {}, 'nodes': {}}
├── Group: /grid
│ Indexes:
│ ┌ i MetaIndex
│ │ j
│ └ n
│ Attributes:
│ host: Grid(rows=3, cols=3, nodes=9, name='grid', strict=True)
│ rows: 3
│ cols: 3
│ nodes: 9
│ metadata: {'rows': {}, 'cols': {}, 'nodes': {}}
└── Group: /arrs
Dimensions: (rows: 3, cols: 3, nodes: 9)
Dimensions without coordinates: rows, cols, nodes
Data variables:
arr_a (rows, cols) float64 72B 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
arr_b (nodes) float64 72B 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
Indexes:
┌ i MetaIndex
│ j
└ n
Attributes:
host: Arrs(arr_a=<xarray.DataArray 'arr_a' (rows: 3, cols: 3)> Size:...
metadata: {'arr_a': {}, 'arr_b': {}}