Diamond Tiling

A scheduling technique for stencil codes.

In [39]:
import islpy as isl
dim_type = isl.dim_type
import islplot.plotter as iplt
import matplotlib.pyplot as plt
In [40]:
def plot_set(s, shape_color="blue", point_color="orange"):
    if shape_color is not None:
        iplt.plot_set_shapes(s, color=shape_color)

    if point_color is not None:
        iplt.plot_set_points(s, color=point_color)
    plt.xlabel("i")
    plt.ylabel("j")
    plt.gca().set_aspect("equal")
    plt.grid()

Loop Domain

In [41]:
d = isl.BasicSet("[nx,nt] -> {[ix, it]: 0<=ix<nx and 0<=it<nt and nx = 43}")
d
Out[41]:
BasicSet("[nx, nt] -> { [ix, it] : nx = 43 and 0 <= ix <= 42 and 0 <= it < nt }")

Index -> Tile Index Map

In [42]:
m = isl.BasicMap("[nx,nt] -> "
                 "{[ix, it] -> [tx, tt, parity]: "
                 "tx - tt = floor((ix - it)/16) and "
                 "tx+(tt+parity) = floor((ix + it)/16) and 0<=parity<2}")
m
Out[42]:
BasicMap("[nx, nt] -> { [ix, it] -> [tx, tt, parity] : -ix + it + 16tx <= 16tt <= 15 - ix + it + 16tx and 0 <= parity <= 1 and -15 + ix + it - 16tx - 16tt <= 16parity <= ix + it - 16tx - 16tt }")
In [43]:
plt.figure(figsize=(12, 8))

for color, parity in [
        ("blue", 0),
        ("orange", 1),
        ]:
    tilerange = isl.BasicMap("[nx,nt] -> {[ix, it] -> [tx, tt,parity]: tx<3 and tt<3 and parity=%d}" % parity)
    plot_dom = (m.intersect_domain(d) & tilerange).domain()
    plot_set(plot_dom.project_out(dim_type.param, 0, 2), shape_color=None, point_color=color)

This is how you can find the values of (tx, tt, parity) for a point in the loop domain:

In [44]:
point = isl.BasicMap("[nx,nt] -> {[ix, it] -> [tx, tt, parity]: ix = 0 and it = 17}")
m & point
Out[44]:
BasicMap("[nx, nt] -> { [ix, it] -> [tx, tt, parity] : ix = 0 and it = 17 and tx = -1 and tt = 1 and parity = 1 }")

We'd expect m to be a function, but not injective (each tile contains multiple points):

In [45]:
m.is_single_valued()
Out[45]:
True
In [46]:
m.is_injective()
Out[46]:
False

Index -> Tile index, intra-tile index map

We'll now work to make the mapping bijective, i.e. establish a point-to-point mapping that we can use to rewrite the loop domain:

In [47]:
m = isl.BasicMap(
    "[nx,nt] -> {[ix, it] -> [tx, tt, tparity, itt, itx]: "
    "16*(tx - tt) + itx - itt = ix - it and "
    "16*(tx + tt + tparity) + itt + itx = ix + it and "
    "0<=tparity<2 and 0 <= itx - itt < 16 and 0 <= itt+itx < 16}")
m
Out[47]:
BasicMap("[nx, nt] -> { [ix, it] -> [tx, tt, tparity, itt, itx] : itt = it - 16tt - 8tparity and itx = ix - 16tx - 8tparity and -ix + it + 16tx <= 16tt <= 15 - ix + it + 16tx and 0 <= tparity <= 1 and -15 + ix + it - 16tx - 16tt <= 16tparity <= ix + it - 16tx - 16tt }")
In [48]:
m.is_bijective()
Out[48]:
True
In [49]:
plt.figure(figsize=(12, 8))

for color, parity in [
        ("blue", 0),
        ("orange", 1),
        ]:
    tilerange = isl.BasicMap("[nx,nt] -> {[ix, it] -> [tx, tt, tparity, itt, itx]: tx<3 and tt<3 and tparity=%d}" % parity)
    plot_dom = (
        m.intersect_domain(d)
        & tilerange).domain()
    plot_set(plot_dom.project_out(dim_type.param, 0, 2), shape_color=None, point_color=color)
In [50]:
point = isl.BasicMap("[nx,nt] -> {[ix, it] -> [tx, tt, tparity, itx, itt]: ix = 0 and it = 16}")
m & point
Out[50]:
BasicMap("[nx, nt] -> { [ix, it] -> [tx, tt, tparity, itt, itx] : ix = 0 and it = 16 and tx = 0 and tt = 1 and tparity = 0 and itt = 0 and itx = 0 }")
In [51]:
m.project_out(dim_type.in_, 0, 1)
Out[51]:
BasicMap("[nx, nt] -> { [it] -> [tx, tt, tparity, itt, itx] : itt = it - 16tt - 8tparity and 0 <= tparity <= 1 and itx >= it - 16tt - 8tparity and -it + 16tt + 8tparity <= itx <= 15 - it + 16tt + 8tparity and itx <= 15 + it - 16tt - 8tparity }")
In [52]:
m
Out[52]:
BasicMap("[nx, nt] -> { [ix, it] -> [tx, tt, tparity, itt, itx] : itt = it - 16tt - 8tparity and itx = ix - 16tx - 8tparity and -ix + it + 16tx <= 16tt <= 15 - ix + it + 16tx and 0 <= tparity <= 1 and -15 + ix + it - 16tx - 16tt <= 16tparity <= ix + it - 16tx - 16tt }")
In [53]:
tile = isl.BasicMap("[nx,nt] -> {[ix, it] -> [tx, tt, tparity, itx, itt]: tt = 1 and tx = 1 and tparity = 0}")
plot_tile = (m & tile).range().project_out(dim_type.set, 0, 3).project_out(dim_type.param, 0, 2)
plot_tile
plot_set(plot_tile, shape_color=None)

What's next from here? Hexagonal tiles are a popular option that improves utilization at the "top" and "bottom" of the diamonds.

In [ ]: