Indexes

The mechanism powering xarray label-based lookups is called an “index”. The most flexible way to extend indexing and selection is to create a custom index. For simpler cases, e.g. “aliasing” dimension coordinates, there are more straightforward approaches.

Consider an attrs class describing a 2D structured grid, with integer fields for the size of each dimension and an array field containing some data variable living on grid cells. This is a canonical case for xattree and the dim() and array() field functions.

[1]:
import numpy as np
import pandas as pd
import xarray as xr
from numpy.typing import NDArray
[2]:
from xattree import array, dim, xattree
[3]:
@xattree
class Grid:
    rows: int = dim(default=3)
    cols: int = dim(default=3)
    arr: NDArray[np.float64] = array(default=0.0, dims=("rows", "cols"))
[4]:
grid = Grid()
grid.data
[4]:
<xarray.DataTree 'grid'>
Group: /
    Dimensions:  (rows: 3, cols: 3)
    Coordinates:
      * rows     (rows) int64 24B 0 1 2
      * cols     (cols) int64 24B 0 1 2
    Data variables:
        arr      (rows, cols) float64 72B 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
    Attributes:
        host:      Grid(rows=3, cols=3, arr=<xarray.DataArray 'arr' (rows: 3, col...
        metadata:  {'arr': {}, 'rows': {}, 'cols': {}}

But the field names aren’t quite right as coordinate names. While it’s natural for rows and cols to become dimension coordinates, and for the dimensions to be named as such, we expect i and j as coordinate names.

We might think first to try Dataset.rename():

[5]:
grid.data.dataset.rename({"rows": "i", "cols": "j"})
[5]:
<xarray.Dataset> Size: 120B
Dimensions:  (i: 3, j: 3)
Coordinates:
  * i        (i) int64 24B 0 1 2
  * j        (j) int64 24B 0 1 2
Data variables:
    arr      (i, j) float64 72B 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
Attributes:
    host:      Grid(rows=3, cols=3, arr=<xarray.DataArray 'arr' (rows: 3, col...
    metadata:  {'arr': {}, 'rows': {}, 'cols': {}}

But this renames not only the coordinates but also the dimensions. Ideally, the dimensions would stay rows/cols while the coordinates are named i/j.

xattree provides a simple way to achieve this: pass a new name to the coord parameter of dim(). This creates a coordinate variable with the given name but leaves the dimension name unchanged. As expected for a dimension coordinate, the new coordinate variable has a PandasIndex attached.

[6]:
@xattree
class Grid:
    rows: int = dim(default=3, coord="i")
    cols: int = dim(default=3, coord="j")
    arr: NDArray[np.float64] = array(default=0.0, dims=("rows", "cols"))
[7]:
grid = Grid()
grid.data
[7]:
<xarray.DataTree 'grid'>
Group: /
    Dimensions:  (rows: 3, cols: 3)
    Coordinates:
      * i        (rows) int64 24B 0 1 2
      * j        (cols) int64 24B 0 1 2
    Dimensions without coordinates: rows, cols
    Data variables:
        arr      (rows, cols) float64 72B 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
    Attributes:
        host:      Grid(rows=3, cols=3, arr=<xarray.DataArray 'arr' (rows: 3, col...
        metadata:  {'arr': {}, 'rows': {}, 'cols': {}}

Try out the label-based indexing.

[8]:
grid.data.arr.sel(i=0)
[8]:
<xarray.DataArray 'arr' (cols: 3)> Size: 24B
array([0., 0., 0.])
Coordinates:
    i        int64 8B 0
  * j        (cols) int64 24B 0 1 2
Dimensions without coordinates: cols
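
For comparison, roughly the same layout can be built by hand in plain xarray. The sketch below is an assumption about what an equivalent manual setup might look like, not a description of what xattree does internally. It uses only Dataset.assign_coords() and Dataset.set_xindex().

```python
import numpy as np
import xarray as xr

# A 3x3 array on dimensions "rows" and "cols", with no coordinates yet.
ds = xr.Dataset({"arr": (("rows", "cols"), np.zeros((3, 3)))})

# Attach coordinates named "i" and "j" to the existing dimensions.
# Their names differ from the dimension names, so no index is created by default.
ds = ds.assign_coords(i=("rows", np.arange(3)), j=("cols", np.arange(3)))

# Give each coordinate a default pandas-backed index so that .sel() works.
ds = ds.set_xindex("i").set_xindex("j")

# Select the first row by label; the dimension names are untouched.
row = ds.arr.sel(i=0)
```

The dimensions remain rows and cols throughout; only the coordinate names differ.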

A more general approach is to create a custom index.

As a first step, we can set coord=False on the dim() call, which will prevent xattree from creating a coordinate variable for a dimension field.

[9]:
@xattree
class Grid:
    rows: int = dim(default=3, coord=False)
    cols: int = dim(default=3, coord=False)
    arr: NDArray[np.float64] = array(default=0.0, dims=("rows", "cols"))

Now rows and cols are no longer coordinates, so the dataset has no coordinate variables at all. The dimension sizes are stored as dataset attributes instead.

[10]:
grid = Grid()
grid.data
[10]:
<xarray.DataTree 'grid'>
Group: /
    Dimensions:  (rows: 3, cols: 3)
    Dimensions without coordinates: rows, cols
    Data variables:
        arr      (rows, cols) float64 72B 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
    Attributes:
        host:      Grid(rows=3, cols=3, arr=<xarray.DataArray 'arr' (rows: 3, col...
        rows:      3
        cols:      3
        metadata:  {'arr': {}, 'rows': {}, 'cols': {}}

To support i and j labels, we can create a custom index class and register it with xattree.

A custom index is a class subclassing Index. The Index interface is a set of methods that allow you to create, manipulate, and query the index.

For our purposes, an index which simply aliases the rows and cols fields to i and j, respectively, will suffice.

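First, write an aliasing function: given a dataset and a dimension name, it builds a PandasIndex over that dimension under a new coordinate name. If the dimension is not present in the dataset, the size is read from the dataset attributes instead.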

[11]:
from xarray.indexes import Index, PandasIndex
[12]:
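def alias(dataset: xr.Dataset, old_name: str, new_name: str) -> PandasIndex:
    """Alias a dimension coordinate to a coordinate with a different name."""
    try:
        # Prefer the dimension size if the dimension already exists
        size = dataset.sizes[old_name]
    except KeyError:
        # Otherwise fall back to the size stored as a dataset attribute
        size = dataset.attrs[old_name]
    # The index is named new_name but still spans dimension old_name
    return PandasIndex(pd.RangeIndex(size, name=new_name), dim=old_name)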
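
To see what alias() produces, it can help to build the same kind of index directly. The sketch below skips the dataset lookup and hard-codes a size of 3 (an assumption for illustration only). It relies on the PandasIndex constructor, create_variables() and sel().

```python
import pandas as pd
from xarray.indexes import PandasIndex

# An index over dimension "rows" whose coordinate is named "i".
index = PandasIndex(pd.RangeIndex(3, name="i"), dim="rows")

# The coordinate variable is named after the pandas index but lives on "rows".
variables = index.create_variables()
coord = variables["i"]

# A label query on "i" resolves to a positional indexer on "rows".
result = index.sel({"i": 1})
```

The returned selection result maps the dimension name, not the coordinate name, to the position. This is what lets i act as an alias for rows.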

Now create a meta-index, with which we can combine two “aliased” 1D indexes into a 2D index.

[13]:
from xarray.core.indexing import merge_sel_results
[14]:
class MetaIndex(Index):
    def __init__(self, indices):
        self._indices = indices

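    @classmethod
    def from_variables(cls, variables, *, options):
        # Wrap one PandasIndex per coordinate variable in a MetaIndex instance
        return cls(
            {k: PandasIndex.from_variables({k: v}, options=options) for k, v in variables.items()}
        )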

    def create_variables(self, variables=None):
        idx_vars = {}
        for index in self._indices.values():
            idx_vars.update(index.create_variables(variables))
        return idx_vars

    def sel(self, labels):
        results = []
        for k, index in self._indices.items():
            if k in labels:
                results.append(index.sel({k: labels[k]}))
        return merge_sel_results(results)
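
The meta-index can also be exercised without xattree by attaching it to a plain Dataset. This is a minimal sketch under some assumptions: the class is a trimmed copy of the one above (repeated so the block runs on its own), the grid size is hard-coded to 3, and the wiring through xr.Coordinates is one possible way to attach a custom index, not necessarily what xattree does. Note that merge_sel_results is imported from an internal xarray module, as in the cell above, so it may move between xarray versions.

```python
import numpy as np
import pandas as pd
import xarray as xr
from xarray.core.indexing import merge_sel_results  # internal helper
from xarray.indexes import Index, PandasIndex


class MetaIndex(Index):
    """Trimmed copy of the class above: delegates to named 1D indexes."""

    def __init__(self, indices):
        self._indices = indices

    def create_variables(self, variables=None):
        idx_vars = {}
        for index in self._indices.values():
            idx_vars.update(index.create_variables(variables))
        return idx_vars

    def sel(self, labels):
        # Query only the wrapped indexes that a label was given for.
        results = [
            index.sel({k: labels[k]})
            for k, index in self._indices.items()
            if k in labels
        ]
        return merge_sel_results(results)


index = MetaIndex(
    {
        "i": PandasIndex(pd.RangeIndex(3, name="i"), dim="rows"),
        "j": PandasIndex(pd.RangeIndex(3, name="j"), dim="cols"),
    }
)

# Both coordinates share the single MetaIndex instance.
coords = xr.Coordinates(
    coords=index.create_variables(),
    indexes={"i": index, "j": index},
)
ds = xr.Dataset({"arr": (("rows", "cols"), np.zeros((3, 3)))}, coords=coords)

row = ds.arr.sel(i=0)        # one row, still spanning "cols"
cell = ds.arr.sel(i=1, j=2)  # a single cell
```

Because the class does not override isel(), the index is dropped from the result of a selection and the remaining coordinate is left unindexed.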

Define a function to build an instance of the index from a dataset.

[15]:
def meta_index(dataset: xr.Dataset) -> MetaIndex:
    return MetaIndex(
        {
            "i": alias(dataset, "rows", "i"),
            "j": alias(dataset, "cols", "j"),
        }
    )

Finally, register it with the xattree decorator. The index parameter is a callable that takes a dataset and returns an index.

[16]:
@xattree(index=meta_index)
class Grid:
    rows: int = dim(default=3, coord=False)
    cols: int = dim(default=3, coord=False)
    arr: NDArray[np.float64] = array(default=0.0, dims=("rows", "cols"))
[17]:
grid = Grid()
grid.data
[17]:
<xarray.DataTree 'grid'>
Group: /
    Dimensions:  (rows: 3, cols: 3)
    Coordinates:
      * i        (rows) int64 24B 0 1 2
      * j        (cols) int64 24B 0 1 2
    Dimensions without coordinates: rows, cols
    Data variables:
        arr      (rows, cols) float64 72B 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
    Indexes:
      ┌ i        MetaIndex
      └ j
    Attributes:
        host:      Grid(rows=3, cols=3, arr=<xarray.DataArray 'arr' (rows: 3, col...
        rows:      3
        cols:      3
        metadata:  {'arr': {}, 'rows': {}, 'cols': {}}

Check the label-based indexing.

[18]:
grid.data.arr.sel(i=0)
[18]:
<xarray.DataArray 'arr' (cols: 3)> Size: 24B
array([0., 0., 0.])
Coordinates:
    i        int64 8B 0
    j        (cols) int64 24B 0 1 2
Dimensions without coordinates: cols

Make sure the index is working as expected.

[19]:
assert grid.rows == 3
assert grid.cols == 3
assert grid.data.i.shape == (3,)
assert grid.data.j.shape == (3,)
assert "i" in grid.data.coords
assert "j" in grid.data.coords
assert "rows" not in grid.data.coords
assert "cols" not in grid.data.coords
assert grid.data.arr.shape == (3, 3)

Like dimensions and coordinates, an index can be “lifted” to a context above the current class with the xattree decorator’s index_scope parameter.

[20]:
from xattree import ROOT, field
[21]:
@xattree(index=meta_index, index_scope=ROOT)
class Grid:
    rows: int = dim(default=3, coord=False, scope=ROOT)
    cols: int = dim(default=3, coord=False, scope=ROOT)
[22]:
@xattree
class Arrs:
    arr: NDArray[np.float64] = array(default=0.0, dims=("rows", "cols"))
[23]:
@xattree
class Root:
    grid: Grid = field()
    arrs: Arrs = field()
[24]:
grid = Grid()
root = Root(grid=grid)
arrs = Arrs(parent=root)
[25]:
root.data
[25]:
<xarray.DataTree 'root'>
Group: /
│   Dimensions:  (rows: 3, cols: 3)
│   Coordinates:
│     * i        (rows) int64 24B 0 1 2
│     * j        (cols) int64 24B 0 1 2
│   Dimensions without coordinates: rows, cols
│   Indexes:
│     ┌ i        MetaIndex
│     └ j
│   Attributes:
│       host:      Root(grid=Grid(rows=3, cols=3, name='grid', strict=True), arrs...
│       rows:      3
│       cols:      3
│       metadata:  {'grid': {}, 'arrs': {}, 'rows': {}, 'cols': {}}
├── Group: /grid
│       Indexes:
│         ┌ i        MetaIndex
│         └ j
│       Attributes:
│           host:      Grid(rows=3, cols=3, name='grid', strict=True)
│           rows:      3
│           cols:      3
│           metadata:  {'rows': {}, 'cols': {}}
└── Group: /arrs
        Dimensions:  (rows: 3, cols: 3)
        Dimensions without coordinates: rows, cols
        Data variables:
            arr      (rows, cols) float64 72B 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
        Indexes:
          ┌ i        MetaIndex
          └ j
        Attributes:
            host:      Arrs(arr=<xarray.DataArray 'arr' (rows: 3, cols: 3)> Size: 72B...
            metadata:  {'arr': {}}

Indexing still works as expected.

[26]:
assert arrs.arr.sel(i=0).shape == (3,)

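Derived dimensions, whose size is computed from other fields rather than passed to the initializer, are supported too. Below, nodes is set to rows * cols in __attrs_post_init__ and aliased to the coordinate n.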

[27]:
def meta_index(dataset: xr.Dataset) -> MetaIndex:
    return MetaIndex(
        {
            "i": alias(dataset, "rows", "i"),
            "j": alias(dataset, "cols", "j"),
            "n": alias(dataset, "nodes", "n"),
        }
    )
[28]:
@xattree(index=meta_index, index_scope=ROOT)
class Grid:
    rows: int = dim(default=3, coord=False, scope=ROOT)
    cols: int = dim(default=3, coord=False, scope=ROOT)
    nodes: int = dim(init=False, coord=False, scope=ROOT)

    def __attrs_post_init__(self):
        self.nodes = self.rows * self.cols
[29]:
@xattree
class Arrs:
    arr_a: NDArray[np.float64] = array(default=0.0, dims=("rows", "cols"))
    arr_b: NDArray[np.float64] = array(default=0.0, dims=("nodes",))
[30]:
@xattree
class Root:
    grid: Grid = field()
    arrs: Arrs = field()
[31]:
grid = Grid()
root = Root(grid=grid)
arrs = Arrs(parent=root)
[32]:
root.data
[32]:
<xarray.DataTree 'root'>
Group: /
│   Dimensions:  (rows: 3, cols: 3, nodes: 9)
│   Coordinates:
│     * i        (rows) int64 24B 0 1 2
│     * j        (cols) int64 24B 0 1 2
│     * n        (nodes) int64 72B 0 1 2 3 4 5 6 7 8
│   Dimensions without coordinates: rows, cols, nodes
│   Indexes:
│     ┌ i        MetaIndex
│     │ j
│     └ n
│   Attributes:
│       host:      Root(grid=Grid(rows=3, cols=3, nodes=9, name='grid', strict=Tr...
│       rows:      3
│       cols:      3
│       nodes:     9
│       metadata:  {'grid': {}, 'arrs': {}, 'rows': {}, 'cols': {}, 'nodes': {}}
├── Group: /grid
│       Indexes:
│         ┌ i        MetaIndex
│         │ j
│         └ n
│       Attributes:
│           host:      Grid(rows=3, cols=3, nodes=9, name='grid', strict=True)
│           rows:      3
│           cols:      3
│           nodes:     9
│           metadata:  {'rows': {}, 'cols': {}, 'nodes': {}}
└── Group: /arrs
        Dimensions:  (rows: 3, cols: 3, nodes: 9)
        Dimensions without coordinates: rows, cols, nodes
        Data variables:
            arr_a    (rows, cols) float64 72B 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
            arr_b    (nodes) float64 72B 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
        Indexes:
          ┌ i        MetaIndex
          │ j
          └ n
        Attributes:
            host:      Arrs(arr_a=<xarray.DataArray 'arr_a' (rows: 3, cols: 3)> Size:...
            metadata:  {'arr_a': {}, 'arr_b': {}}