Getting Started

Install the PASTE Python package

You can install the package from PyPI: https://pypi.org/project/paste-bio/

[1]:
import time

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import torch

from paste3.helper import get_common_genes, match_spots_using_spatial_heuristic
from paste3.paste import center_align, pairwise_align
from paste3.visualization import plot_slice, stack_slices_center, stack_slices_pairwise

Read data and create AnnData object

[2]:
data_dir = "../../../tests/data/input/"


# Assume that the coordinates of slices are named slice_name + "_coor.csv"
def load_slices(data_dir, slice_names):
    slices = []
    for slice_name in slice_names:
        slice_i = sc.read_csv(data_dir + slice_name + ".csv")
        slice_i_coor = np.genfromtxt(data_dir + slice_name + "_coor.csv", delimiter=",")
        slice_i.obsm["spatial"] = slice_i_coor
        # Preprocess slices
        sc.pp.filter_genes(slice_i, min_counts=15)
        sc.pp.filter_cells(slice_i, min_counts=100)
        slices.append(slice_i)
    return slices


slices = load_slices(data_dir, ["slice1", "slice2", "slice3", "slice4"])
slice1, slice2, slice3, slice4 = slices
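The two preprocessing calls in load_slices first drop genes with fewer than 15 total counts across all spots (sc.pp.filter_genes(min_counts=15)), then drop spots with fewer than 100 total counts over the remaining genes (sc.pp.filter_cells(min_counts=100)). Scanpy also records these totals as n_counts in .var and .obs, which is where the n_counts columns shown below come from. The NumPy sketch below reproduces only the filtering logic on a dense spots × genes matrix. It is an illustration, not Scanpy's implementation, and filter_counts is a hypothetical helper name.

```python
import numpy as np


def filter_counts(X, min_gene_counts=15, min_spot_counts=100):
    """Sketch of gene filtering followed by spot filtering on a dense
    spots x genes count matrix, using the thresholds from load_slices."""
    X = np.asarray(X, dtype=float)
    # Keep genes whose total count across all spots reaches the threshold
    gene_totals = X.sum(axis=0)
    keep_genes = np.flatnonzero(gene_totals >= min_gene_counts)
    X = X[:, keep_genes]
    # Then keep spots whose total count over the remaining genes reaches the threshold
    spot_totals = X.sum(axis=1)
    keep_spots = np.flatnonzero(spot_totals >= min_spot_counts)
    return X[keep_spots], keep_spots, keep_genes


counts = np.array([[60, 50, 1], [10, 5, 2], [70, 40, 3]])
filtered, kept_spots, kept_genes = filter_counts(counts)
print(kept_genes, kept_spots)  # [0 1] [0 2]
```

Because genes are filtered first, a spot's total is computed only over the genes that survived, so the order of the two calls can matter.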

Each AnnData object consists of a gene expression matrix (.X) and a spatial coordinate matrix (.obsm["spatial"]).

[3]:
slice1.X
[3]:
array([[12.,  0.,  6., ...,  0.,  0.,  0.],
       [ 7.,  0.,  1., ...,  1.,  0.,  0.],
       [15.,  1.,  4., ...,  0.,  0.,  1.],
       ...,
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 1.,  0.,  0., ...,  0.,  0.,  0.],
       [ 5.,  0.,  1., ...,  1.,  0.,  0.]], dtype=float32)
[4]:
slice1.obsm["spatial"][0:5, :]
[4]:
array([[13.064,  6.086],
       [12.116,  7.015],
       [13.945,  6.999],
       [12.987,  7.011],
       [15.011,  7.984]])

Note that you can label the spots however you like. Here, each spot's label is simply its coordinates, written as <x>x<y> (for example, 13.064x6.086).

[5]:
slice1.obs
[5]:
n_counts
13.064x6.086 2181.0
12.116x7.015 2295.0
13.945x6.999 3375.0
12.987x7.011 2935.0
15.011x7.984 2964.0
... ...
21.953x24.847 541.0
20.98x24.963 860.0
20.063x24.964 508.0
19.007x25.045 626.0
21.957x25.871 2515.0

254 rows × 1 columns
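If you keep this <x>x<y> labeling convention for your own data, labels and coordinates can be converted back and forth. The sketch below assumes every label has exactly that form; how the labels in the input CSV files were originally generated is not shown in this tutorial, and coords_to_names / names_to_coords are hypothetical helper names.

```python
import numpy as np


def coords_to_names(coords):
    """Build '<x>x<y>' labels like those in slice1.obs.index."""
    return [f"{x}x{y}" for x, y in coords]


def names_to_coords(names):
    """Recover an (n_spots, 2) coordinate array from '<x>x<y>' labels."""
    return np.array([[float(v) for v in name.split("x")] for name in names])


coords = np.array([[13.064, 6.086], [20.98, 24.963]])
names = coords_to_names(coords)
print(names)  # ['13.064x6.086', '20.98x24.963']
```

For example, names_to_coords(slice1.obs.index) should give back the same values as slice1.obsm["spatial"] if the labels were built this way.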

[6]:
slice1.var
[6]:
n_counts
GAPDH 2233.0
UBE2G2 78.0
MAPKAPK2 255.0
NDUFA7 96.0
ASNA1 172.0
... ...
DIP2C 31.0
LYPLA2 19.0
RGP1 24.0
BPGM 17.0
HPS6 16.0

7998 rows × 1 columns

We can visualize the spatial coordinates of our slices using plot_slice.

[7]:
slice_colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3"]

fig, axs = plt.subplots(2, 2, figsize=(7, 7))
plot_slice(slice1, slice_colors[0], ax=axs[0, 0])
plot_slice(slice2, slice_colors[1], ax=axs[0, 1])
plot_slice(slice3, slice_colors[2], ax=axs[1, 0])
plot_slice(slice4, slice_colors[3], ax=axs[1, 1])
plt.show()
../_images/notebooks_paste_tutorial_13_0.png

We can also plot using Scanpy’s spatial plotting function.

[8]:
sc.pl.spatial(slice1, color="n_counts", spot_size=1)
../_images/notebooks_paste_tutorial_15_0.png

Pairwise Alignment

Run PASTE's pairwise_align on each pair of consecutive slices.

[9]:
start = time.time()

pi12, _ = pairwise_align(slice1, slice2)
pi23, _ = pairwise_align(slice2, slice3)
pi34, _ = pairwise_align(slice3, slice4)

print("Runtime: " + str(time.time() - start))
(INFO) (paste.py) (16-Nov-24 17:57:37) GPU is not available, resorting to torch CPU.
(INFO) (paste.py) (16-Nov-24 17:57:37) GPU is not available, resorting to torch CPU.
(INFO) (paste.py) (16-Nov-24 17:57:38) GPU is not available, resorting to torch CPU.
Runtime: 0.3858058452606201
[10]:
pd.DataFrame(pi12.cpu().numpy())
[10]:
0 1 2 3 4 5 6 7 8 9 ... 240 241 242 243 244 245 246 247 248 249
0 0.003937 0.000000 0.000000 0.0 0.0 0.000000 0.000000 0.0 0.0 0.0 ... 0.0 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.000000e+00 0.000000 0.000000
1 0.000063 0.003874 0.000000 0.0 0.0 0.000000 0.000000 0.0 0.0 0.0 ... 0.0 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.000000e+00 0.000000 0.000000
2 0.000000 0.000000 0.000000 0.0 0.0 0.003937 0.000000 0.0 0.0 0.0 ... 0.0 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.000000e+00 0.000000 0.000000
3 0.000000 0.000126 0.003811 0.0 0.0 0.000000 0.000000 0.0 0.0 0.0 ... 0.0 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.000000e+00 0.000000 0.000000
4 0.000000 0.000000 0.000000 0.0 0.0 0.000000 0.003874 0.0 0.0 0.0 ... 0.0 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.000000e+00 0.000000 0.000000
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
249 0.000000 0.000000 0.000000 0.0 0.0 0.000000 0.000000 0.0 0.0 0.0 ... 0.0 0.000000 0.000000 0.000000 0.0 0.0 0.0 3.937008e-03 0.000000 0.000000
250 0.000000 0.000000 0.000000 0.0 0.0 0.000000 0.000000 0.0 0.0 0.0 ... 0.0 0.000000 0.000000 0.000000 0.0 0.0 0.0 6.299213e-05 0.003874 0.000000
251 0.000000 0.000000 0.000000 0.0 0.0 0.000000 0.000000 0.0 0.0 0.0 ... 0.0 0.000000 0.000126 0.003811 0.0 0.0 0.0 3.035766e-17 0.000000 0.000000
252 0.000000 0.000000 0.000000 0.0 0.0 0.000000 0.000000 0.0 0.0 0.0 ... 0.0 0.003748 0.000000 0.000189 0.0 0.0 0.0 0.000000e+00 0.000000 0.000000
253 0.000000 0.000000 0.000000 0.0 0.0 0.000000 0.000000 0.0 0.0 0.0 ... 0.0 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.000000e+00 0.000000 0.003937

254 rows × 250 columns
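pi12 is a 254 × 250 matrix (spots of slice 1 by spots of slice 2): entry (i, j) is the probability mass the alignment assigns to matching spot i of slice 1 with spot j of slice 2. In the output above, each visible row sums to about 1/254 ≈ 0.003937, and column 0 sums to 0.003937 + 0.000063 = 0.004 = 1/250, which is consistent with uniform marginals. The sketch below, using a hypothetical summarize_plan helper, checks that property and turns the soft mapping into a hard one by taking each row's largest entry.

```python
import numpy as np


def summarize_plan(pi, atol=1e-8):
    """Check for uniform marginals and return each row's best-matching column."""
    pi = np.asarray(pi, dtype=float)
    n, m = pi.shape
    rows_uniform = np.allclose(pi.sum(axis=1), 1.0 / n, atol=atol)
    cols_uniform = np.allclose(pi.sum(axis=0), 1.0 / m, atol=atol)
    # Hard assignment: the column that receives the most mass from each row
    best_match = pi.argmax(axis=1)
    return rows_uniform, cols_uniform, best_match


toy_pi = np.array([[1 / 3, 0.0], [1 / 6, 1 / 6], [0.0, 1 / 3]])
print(summarize_plan(toy_pi))
```

On the real output you could call summarize_plan(pi12.cpu().numpy()); depending on solver precision, atol may need loosening. Ties (as in the toy's middle row) resolve to the first column.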

Sequential pairwise slice alignment plots

[11]:
pis = [pi12, pi23, pi34]
slices = [slice1, slice2, slice3, slice4]

new_slices, _, _ = stack_slices_pairwise(slices, pis)
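stack_slices_pairwise uses the mappings to place all slices in a shared coordinate system. The PASTE paper frames this step as a weighted Procrustes problem: find the translation and rotation that best superimpose one slice's spots onto their mapped partners, with each pair weighted by its entry in pi. The following is a minimal NumPy sketch of that idea with the standard Kabsch reflection guard. It is not paste3's code, and weighted_procrustes is a hypothetical name.

```python
import numpy as np


def weighted_procrustes(X, Y, pi):
    """Hypothetical sketch: rigidly align Y (m x 2) to X (n x 2) given a plan pi (n x m).

    Returns X and Y, both translated so their pi-weighted centroids sit at the
    origin, with Y additionally rotated to best match X.
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    pi = np.asarray(pi, dtype=float)
    # Weighted centroids: each spot is weighted by the mass it sends or receives
    w_x = pi.sum(axis=1) / pi.sum()
    w_y = pi.sum(axis=0) / pi.sum()
    X_c = X - w_x @ X
    Y_c = Y - w_y @ Y
    # Cross-covariance between matched spots, then the optimal rotation via SVD
    H = Y_c.T @ pi.T @ X_c
    U, _, Vt = np.linalg.svd(H)
    # Kabsch guard: flip the last axis if needed so R is a rotation, not a reflection
    d = np.sign(np.linalg.det(Vt.T @ U.T))
    D = np.diag([1.0] * (H.shape[0] - 1) + [d])
    R = Vt.T @ D @ U.T
    return X_c, Y_c @ R.T
```

With the real data, something like weighted_procrustes(slice1.obsm["spatial"], slice2.obsm["spatial"], pi12.cpu().numpy()) would align slice 2 onto slice 1 under these assumptions.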

Now that we’ve aligned the spatial coordinates, we can plot them all on the same coordinate system.

[12]:
slice_colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3"]

plt.figure(figsize=(7, 7))
for i in range(len(new_slices)):
    plot_slice(new_slices[i], slice_colors[i], s=400)
plt.legend(
    handles=[
        mpatches.Patch(color=slice_colors[0], label="1"),
        mpatches.Patch(color=slice_colors[1], label="2"),
        mpatches.Patch(color=slice_colors[2], label="3"),
        mpatches.Patch(color=slice_colors[3], label="4"),
    ]
)
plt.gca().invert_yaxis()
plt.axis("off")
plt.show()
../_images/notebooks_paste_tutorial_23_0.png

We can also plot each pair of consecutive aligned slices together.

[13]:
slice_colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3"]

fig, axs = plt.subplots(2, 2, figsize=(7, 7))
plot_slice(new_slices[0], slice_colors[0], ax=axs[0, 0])
plot_slice(new_slices[1], slice_colors[1], ax=axs[0, 0])
plot_slice(new_slices[1], slice_colors[1], ax=axs[0, 1])
plot_slice(new_slices[2], slice_colors[2], ax=axs[0, 1])
plot_slice(new_slices[2], slice_colors[2], ax=axs[1, 0])
plot_slice(new_slices[3], slice_colors[3], ax=axs[1, 0])
fig.delaxes(axs[1, 1])
plt.show()
../_images/notebooks_paste_tutorial_25_0.png

We can also plot the slices in 3-D.

[14]:
import plotly.express as px
import plotly.io as pio

pio.renderers.default = "notebook"

slice_colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3"]

# scale the distance between layers along the z-axis
z_scale = 2

# one row per spot: x, y, a z-value encoding the slice index, and the slice label
values = []
for i, aligned_slice in enumerate(new_slices):
    for x, y in aligned_slice.obsm["spatial"]:
        values.append([x, y, i * z_scale, str(i)])
df = pd.DataFrame(values, columns=["x", "y", "z", "slice"])
fig = px.scatter_3d(
    df, x="x", y="y", z="z", color="slice", color_discrete_sequence=slice_colors
)
fig.update_layout(scene_aspectmode="data")
fig.show()
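The nested loop above builds the 3-D table one spot at a time. A vectorized NumPy version of the same stacking is sketched below; stack_in_3d is a hypothetical helper, and it assumes each slice's coordinates form an (n_spots, 2) array.

```python
import numpy as np


def stack_in_3d(coords_list, z_scale=2.0):
    """Stack 2-D spot coordinates into one (n_total, 3) array, one z-level per slice."""
    xyz_blocks = []
    labels = []
    for i, coords in enumerate(coords_list):
        coords = np.asarray(coords, dtype=float)
        # Slice i sits at height i * z_scale
        z = np.full((len(coords), 1), i * z_scale)
        xyz_blocks.append(np.hstack([coords, z]))
        labels.extend([str(i)] * len(coords))
    return np.vstack(xyz_blocks), labels
```

Under that assumption, xyz, labels = stack_in_3d([s.obsm["spatial"] for s in new_slices], z_scale=2) followed by pd.DataFrame(xyz, columns=["x", "y", "z"]).assign(slice=labels) should give the same table as the loop.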

Center Alignment

First, we reload and preprocess the data, because running pairwise_align above alters it.

[15]:
slices = load_slices(data_dir, ["slice1", "slice2", "slice3", "slice4"])
slice1, slice2, slice3, slice4 = slices

Run PASTE center_align.

[16]:
slices = [slice1, slice2, slice3, slice4]
initial_slice = slice1.copy()
lmbda = len(slices) * [1 / len(slices)]
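lmbda holds one weight per slice, and the uniform choice above gives every slice equal influence on the center slice. If you want to weight slices unequally, one option is to rescale your weights so they sum to 1 like the uniform default. Whether center_align requires normalized weights is an assumption here, not something this tutorial states, and normalize_weights is a hypothetical helper.

```python
def normalize_weights(weights):
    """Rescale non-negative per-slice weights so they sum to 1."""
    if any(w < 0 for w in weights):
        raise ValueError("weights must be non-negative")
    total = float(sum(weights))
    if total == 0:
        raise ValueError("weights must not all be zero")
    return [w / total for w in weights]


uniform = normalize_weights([1, 1, 1, 1])  # same as 4 * [1 / 4]
emphasize_first = normalize_weights([2, 1, 1, 1])  # hypothetical: slice 1 counts double
```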

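For center alignment, we can optionally give PASTE initial mappings between the center slice and each original slice to help the algorithm. Below, we first restrict all slices to their common genes and then compute one initial mapping per slice with match_spots_using_spatial_heuristic.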

[17]:
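# Restrict all slices to the genes they have in common
slices, _ = get_common_genes(slices)

# One initial mapping per slice, from slices[0] to slices[i], as a float64 tensor
b = []
for i in range(len(slices)):
    b.append(
        torch.Tensor(
            match_spots_using_spatial_heuristic(slices[0].X, slices[i].X)
        ).double()
    )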
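The exact algorithm behind match_spots_using_spatial_heuristic is not described in this tutorial. As a rough illustration of what a heuristic initial mapping can look like, the hypothetical nearest_neighbor_init below centers and scales two per-spot feature matrices and sends each spot of the first to its nearest spot in the second. This is not paste3's implementation; unlike a proper transport plan, its column sums are not constrained.

```python
import numpy as np


def nearest_neighbor_init(A, B):
    """Hypothetical initial coupling: map each row of A to its nearest row of B
    after centering and scaling both sets. Not paste3's implementation."""
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    # Center each set and scale by its overall spread so the sets are comparable
    A = (A - A.mean(axis=0)) / A.std()
    B = (B - B.mean(axis=0)) / B.std()
    # Squared Euclidean distances between every row of A and every row of B
    d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1)
    nearest = d2.argmin(axis=1)
    pi = np.zeros((len(A), len(B)))
    # Each row carries mass 1/n, all of it on its nearest neighbour
    pi[np.arange(len(A)), nearest] = 1.0 / len(A)
    return pi
```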
[18]:
start = time.time()

center_slice, pis = center_align(
    initial_slice, slices, lmbda, random_seed=5, pi_inits=b
)

print("Runtime: " + str(time.time() - start))
(INFO) (paste.py) (16-Nov-24 17:57:49) GPU is not available, resorting to torch CPU.
(INFO) (paste.py) (16-Nov-24 17:57:49) Solving Center Mapping NMF Problem.
(INFO) (paste.py) (16-Nov-24 17:57:52) Iteration: 0
(INFO) (paste.py) (16-Nov-24 17:57:52) Solving Pairwise Slice Alignment Problem.
(INFO) (paste.py) (16-Nov-24 17:57:52) Slice 0
(INFO) (paste.py) (16-Nov-24 17:57:52) Slice 1
/opt/hostedtoolcache/Python/3.12.7/x64/lib/python3.12/site-packages/sklearn/decomposition/_nmf.py:1759: ConvergenceWarning:

Maximum number of iterations 200 reached. Increase it to improve convergence.

(INFO) (paste.py) (16-Nov-24 17:57:52) Slice 2
(INFO) (paste.py) (16-Nov-24 17:57:52) Slice 3
(INFO) (paste.py) (16-Nov-24 17:57:52) center_ot done
(INFO) (paste.py) (16-Nov-24 17:57:52) Solving Center Mapping NMF Problem.
(INFO) (paste.py) (16-Nov-24 17:57:55) Objective -13.865903767431423 | Difference: 13.865903767431423
(INFO) (paste.py) (16-Nov-24 17:57:55) Iteration: 1
(INFO) (paste.py) (16-Nov-24 17:57:55) Solving Pairwise Slice Alignment Problem.
(INFO) (paste.py) (16-Nov-24 17:57:55) Slice 0
(INFO) (paste.py) (16-Nov-24 17:57:55) Slice 1
/opt/hostedtoolcache/Python/3.12.7/x64/lib/python3.12/site-packages/sklearn/decomposition/_nmf.py:1759: ConvergenceWarning:

Maximum number of iterations 200 reached. Increase it to improve convergence.

(INFO) (paste.py) (16-Nov-24 17:57:55) Slice 2
(INFO) (paste.py) (16-Nov-24 17:57:55) Slice 3
(INFO) (paste.py) (16-Nov-24 17:57:55) center_ot done
(INFO) (paste.py) (16-Nov-24 17:57:55) Solving Center Mapping NMF Problem.
(INFO) (paste.py) (16-Nov-24 17:57:58) Objective 1.3829621916807069 | Difference: 15.24886595911213
(INFO) (paste.py) (16-Nov-24 17:57:58) Iteration: 2
(INFO) (paste.py) (16-Nov-24 17:57:58) Solving Pairwise Slice Alignment Problem.
(INFO) (paste.py) (16-Nov-24 17:57:58) Slice 0
(INFO) (paste.py) (16-Nov-24 17:57:58) Slice 1
/opt/hostedtoolcache/Python/3.12.7/x64/lib/python3.12/site-packages/sklearn/decomposition/_nmf.py:1759: ConvergenceWarning:

Maximum number of iterations 200 reached. Increase it to improve convergence.

(INFO) (paste.py) (16-Nov-24 17:57:58) Slice 2
(INFO) (paste.py) (16-Nov-24 17:57:58) Slice 3
(INFO) (paste.py) (16-Nov-24 17:57:58) center_ot done
(INFO) (paste.py) (16-Nov-24 17:57:58) Solving Center Mapping NMF Problem.
(INFO) (paste.py) (16-Nov-24 17:58:01) Objective 1.3880932065404366 | Difference: 0.005131014859729666
(INFO) (paste.py) (16-Nov-24 17:58:01) Iteration: 3
(INFO) (paste.py) (16-Nov-24 17:58:01) Solving Pairwise Slice Alignment Problem.
(INFO) (paste.py) (16-Nov-24 17:58:01) Slice 0
(INFO) (paste.py) (16-Nov-24 17:58:01) Slice 1
/opt/hostedtoolcache/Python/3.12.7/x64/lib/python3.12/site-packages/sklearn/decomposition/_nmf.py:1759: ConvergenceWarning:

Maximum number of iterations 200 reached. Increase it to improve convergence.

(INFO) (paste.py) (16-Nov-24 17:58:01) Slice 2
(INFO) (paste.py) (16-Nov-24 17:58:01) Slice 3
(INFO) (paste.py) (16-Nov-24 17:58:01) center_ot done
(INFO) (paste.py) (16-Nov-24 17:58:01) Solving Center Mapping NMF Problem.
(INFO) (paste.py) (16-Nov-24 17:58:04) Objective 1.3915232061917202 | Difference: 0.003429999651283655
(INFO) (paste.py) (16-Nov-24 17:58:04) Iteration: 4
(INFO) (paste.py) (16-Nov-24 17:58:04) Solving Pairwise Slice Alignment Problem.
(INFO) (paste.py) (16-Nov-24 17:58:04) Slice 0
(INFO) (paste.py) (16-Nov-24 17:58:04) Slice 1
/opt/hostedtoolcache/Python/3.12.7/x64/lib/python3.12/site-packages/sklearn/decomposition/_nmf.py:1759: ConvergenceWarning:

Maximum number of iterations 200 reached. Increase it to improve convergence.

(INFO) (paste.py) (16-Nov-24 17:58:04) Slice 2
(INFO) (paste.py) (16-Nov-24 17:58:04) Slice 3
(INFO) (paste.py) (16-Nov-24 17:58:04) center_ot done
(INFO) (paste.py) (16-Nov-24 17:58:04) Solving Center Mapping NMF Problem.
(INFO) (paste.py) (16-Nov-24 17:58:07) Objective 1.3930712361281365 | Difference: 0.0015480299364163397
(INFO) (paste.py) (16-Nov-24 17:58:07) Iteration: 5
(INFO) (paste.py) (16-Nov-24 17:58:07) Solving Pairwise Slice Alignment Problem.
(INFO) (paste.py) (16-Nov-24 17:58:07) Slice 0
(INFO) (paste.py) (16-Nov-24 17:58:07) Slice 1
/opt/hostedtoolcache/Python/3.12.7/x64/lib/python3.12/site-packages/sklearn/decomposition/_nmf.py:1759: ConvergenceWarning:

Maximum number of iterations 200 reached. Increase it to improve convergence.

(INFO) (paste.py) (16-Nov-24 17:58:07) Slice 2
(INFO) (paste.py) (16-Nov-24 17:58:07) Slice 3
(INFO) (paste.py) (16-Nov-24 17:58:07) center_ot done
(INFO) (paste.py) (16-Nov-24 17:58:07) Solving Center Mapping NMF Problem.
(INFO) (paste.py) (16-Nov-24 17:58:10) Objective 1.395250483498415 | Difference: 0.0021792473702784143
(INFO) (paste.py) (16-Nov-24 17:58:10) Iteration: 6
(INFO) (paste.py) (16-Nov-24 17:58:10) Solving Pairwise Slice Alignment Problem.
(INFO) (paste.py) (16-Nov-24 17:58:10) Slice 0
(INFO) (paste.py) (16-Nov-24 17:58:10) Slice 1
/opt/hostedtoolcache/Python/3.12.7/x64/lib/python3.12/site-packages/sklearn/decomposition/_nmf.py:1759: ConvergenceWarning:

Maximum number of iterations 200 reached. Increase it to improve convergence.

(INFO) (paste.py) (16-Nov-24 17:58:10) Slice 2
(INFO) (paste.py) (16-Nov-24 17:58:10) Slice 3
(INFO) (paste.py) (16-Nov-24 17:58:10) center_ot done
(INFO) (paste.py) (16-Nov-24 17:58:10) Solving Center Mapping NMF Problem.
(INFO) (paste.py) (16-Nov-24 17:58:13) Objective 1.3953790526176948 | Difference: 0.00012856911927983106
Runtime: 23.361464977264404
/opt/hostedtoolcache/Python/3.12.7/x64/lib/python3.12/site-packages/sklearn/decomposition/_nmf.py:1759: ConvergenceWarning:

Maximum number of iterations 200 reached. Increase it to improve convergence.

Alternatively, we can run center_align without providing initial mappings:

[19]:
# center_slice, pis = center_align(initial_slice, slices, lmbda, random_seed=5)

center_align returns center_slice, an AnnData object that also stores the low-dimensional (NMF) representation of the inferred center slice in center_slice.uns["paste_W"] and center_slice.uns["paste_H"].

[20]:
center_slice.uns["paste_W"]
[20]:
array([[2.56406116e-02, 1.17735608e+00, 1.45992053e-01, ...,
        1.67005408e-02, 8.65816920e-03, 1.64249334e-02],
       [1.93240013e-03, 3.38372966e-01, 1.80787149e-01, ...,
        1.61882753e-02, 1.01921081e-03, 4.46527714e-02],
       [4.19665832e-01, 1.88694801e-01, 2.90303240e-02, ...,
        4.75023311e-02, 1.62389887e-02, 1.17923350e-01],
       ...,
       [2.91596165e-02, 8.19205072e-03, 1.07181455e-01, ...,
        4.61294426e-04, 4.20524981e-04, 2.33305738e-01],
       [4.31159630e-02, 3.82986839e-02, 1.78688007e-01, ...,
        9.82425693e-03, 3.23265860e-03, 2.58806995e-01],
       [1.83314337e-01, 1.68948804e-05, 2.22536788e-02, ...,
        1.16837119e-05, 4.58040466e-03, 9.97049918e-02]])
[21]:
center_slice.uns["paste_H"]
[21]:
array([[9.00595555e-01, 1.28095211e-01, 7.99820106e-02, ...,
        5.45672185e-03, 4.42535682e-02, 2.77099853e-02],
       [1.84556880e+00, 1.12907521e-01, 3.19053463e-01, ...,
        1.01912357e-01, 1.03440813e-01, 3.08299135e-02],
       [1.93982105e+00, 1.42403630e-01, 1.30713051e-01, ...,
        7.35120019e-02, 3.50568645e-02, 2.86018512e-02],
       ...,
       [3.05294205e+00, 1.88580791e-01, 3.20976581e-01, ...,
        2.51081658e-02, 5.00736543e-06, 1.10129713e-02],
       [9.20159500e-01, 1.65184151e-01, 1.04867309e-01, ...,
        1.26153742e-02, 1.25054026e-02, 2.44712742e-02],
       [5.73756503e+00, 7.43537812e-04, 3.13147999e-01, ...,
        6.80617759e-06, 3.99149863e-02, 2.55629904e-13]])
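In an NMF, a nonnegative spots × genes matrix is approximated by the product W @ H, where W is spots × components and H is components × genes. Assuming paste_W and paste_H follow that layout, the sketch below reconstructs the approximate expression matrix and lists the top-loading genes in each component. reconstruct_center and top_genes_per_component are hypothetical helpers shown on toy arrays.

```python
import numpy as np


def reconstruct_center(W, H):
    """Approximate center-slice expression (spots x genes) as the NMF product W @ H."""
    return np.asarray(W) @ np.asarray(H)


def top_genes_per_component(H, gene_names, k=3):
    """List the k genes with the largest loadings in each row (component) of H."""
    order = np.argsort(np.asarray(H), axis=1)[:, ::-1][:, :k]
    return [[gene_names[j] for j in row] for row in order]


W = np.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
H = np.array([[1.0, 2.0, 0.0, 0.0], [0.0, 0.0, 3.0, 1.0]])
print(reconstruct_center(W, H))
print(top_genes_per_component(H, ["A", "B", "C", "D"], k=2))  # [['B', 'A'], ['C', 'D']]
```

On the real output, you could try reconstruct_center(center_slice.uns["paste_W"], center_slice.uns["paste_H"]) and compare it with center_slice.X, and pass list(center_slice.var_names) as gene_names, assuming the columns of H follow the order of var_names.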

Center slice alignment plots

Next, we can use the outputs of center_align to align the slices.

[22]:
center, new_slices, _, _ = stack_slices_center(center_slice, slices, pis)

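Now that we’ve aligned the spatial coordinates, we can plot them all on the same coordinate system. The center slice (orange) is drawn first and is not in the legend; because its coordinates match those of slice 1, it may be hidden beneath the aligned slices.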

[23]:
center_color = "orange"
slices_colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3"]

plt.figure(figsize=(7, 7))
plot_slice(center, center_color, s=400)
for i in range(len(new_slices)):
    plot_slice(new_slices[i], slices_colors[i], s=400)

plt.legend(
    handles=[
        mpatches.Patch(color=slices_colors[0], label="1"),
        mpatches.Patch(color=slices_colors[1], label="2"),
        mpatches.Patch(color=slices_colors[2], label="3"),
        mpatches.Patch(color=slices_colors[3], label="4"),
    ]
)
plt.gca().invert_yaxis()
plt.axis("off")
plt.show()
../_images/notebooks_paste_tutorial_45_0.png

Next, we plot each aligned slice against the center slice.

Note that because the center slice was initialized from slice 1 (initial_slice = slice1.copy()), its coordinates coincide with those of slice 1, so the two overlap in the first panel below and cannot both be seen.

[24]:
center_color = "orange"
slice_colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3"]

fig, axs = plt.subplots(2, 2, figsize=(7, 7))
plot_slice(center, center_color, ax=axs[0, 0])
plot_slice(new_slices[0], slice_colors[0], ax=axs[0, 0])

plot_slice(center, center_color, ax=axs[0, 1])
plot_slice(new_slices[1], slice_colors[1], ax=axs[0, 1])

plot_slice(center, center_color, ax=axs[1, 0])
plot_slice(new_slices[2], slice_colors[2], ax=axs[1, 0])

plot_slice(center, center_color, ax=axs[1, 1])
plot_slice(new_slices[3], slice_colors[3], ax=axs[1, 1])
plt.show()
../_images/notebooks_paste_tutorial_48_0.png

GPU Implementation

POT (Python Optimal Transport) lets us write backend-agnostic code, so the same computations can run with NumPy, PyTorch, and other backends (https://pythonot.github.io/gen_modules/ot.backend.html).

We have updated our code to support running on a GPU through PyTorch.
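The core idea behind backend-agnostic code is to use only operations that every supported array type provides. POT formalizes this with backend objects such as NumpyBackend and TorchBackend; the much simpler sketch below only shows the underlying principle. transport_cost is a hypothetical helper that relies on elementwise * and .sum(), which NumPy arrays and PyTorch tensors both support, so it accepts either type (as long as both arguments are the same type).

```python
import numpy as np


def transport_cost(M, pi):
    """Total cost <M, pi> = sum_ij M[i, j] * pi[i, j].

    Uses only elementwise '*' and '.sum()', which NumPy arrays and PyTorch tensors
    both provide, so the same function accepts either array type.
    """
    return (M * pi).sum()


M = np.array([[0.0, 1.0], [1.0, 0.0]])
print(transport_cost(M, np.eye(2) / 2))  # 0.0
print(transport_cost(M, np.array([[0.0, 0.5], [0.5, 0.0]])))  # 1.0
```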

First, make sure PyTorch is installed. One way to check is to list the backends available to POT; a TorchBackend entry means PyTorch was found:

[25]:
import ot

ot.backend.get_backend_list()
[25]:
[<ot.backend.NumpyBackend at 0x7fec2e5a4fb0>,
 <ot.backend.TorchBackend at 0x7fec36e147d0>]

Next, check whether a GPU is available. PASTE performs this check automatically (and falls back to the CPU if none is found, as the log messages above show), but running it yourself helps when debugging why your GPU is not being used.

[26]:
import torch

torch.cuda.is_available()
[26]:
False

Running PASTE with GPU

Note: because the breast dataset is small, the CPU may actually be faster than the GPU in this case. The speedup from a GPU is generally larger for larger datasets.
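Single wall-clock measurements like the Runtime lines in this tutorial can be noisy, especially when the first call includes one-time setup. To compare CPU and GPU runs more fairly, you could repeat each call and take the median. The sketch below uses only the standard library; median_runtime is a hypothetical helper, and time.perf_counter is used because it is designed for measuring intervals.

```python
import statistics
import time


def median_runtime(fn, *args, repeats=3, **kwargs):
    """Call fn(*args, **kwargs) `repeats` times; return (median seconds, last result)."""
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    durations = []
    result = None
    for _ in range(repeats):
        start = time.perf_counter()  # high-resolution clock intended for intervals
        result = fn(*args, **kwargs)
        durations.append(time.perf_counter() - start)
    return statistics.median(durations), result


seconds, total = median_runtime(sum, range(1_000_000), repeats=5)
```

For example, median_runtime(pairwise_align, slice1, slice2, use_gpu=True) would time the GPU path, assuming pairwise_align is called as in this tutorial.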

First, we read in our data.

[27]:
data_dir = "../../../tests/data/input/"


# Assume that the coordinates of slices are named slice_name + "_coor.csv"
def load_slices(data_dir, slice_names):
    slices = []
    for slice_name in slice_names:
        slice_i = sc.read_csv(data_dir + slice_name + ".csv")
        slice_i_coor = np.genfromtxt(data_dir + slice_name + "_coor.csv", delimiter=",")
        slice_i.obsm["spatial"] = slice_i_coor
        # Preprocess slices
        sc.pp.filter_genes(slice_i, min_counts=15)
        sc.pp.filter_cells(slice_i, min_counts=100)
        slices.append(slice_i)
    return slices


slices = load_slices(data_dir, ["slice1", "slice2", "slice3", "slice4"])
slice1, slice2, slice3, slice4 = slices

Next, running on a GPU is as easy as passing use_gpu=True. (No GPU is available on the machine that produced this output, so PASTE falls back to the CPU, as the log shows.)

[28]:
start = time.time()

pi12, _ = pairwise_align(slice1, slice2, use_gpu=True)
pi23, _ = pairwise_align(slice2, slice3, use_gpu=True)
pi34, _ = pairwise_align(slice3, slice4, use_gpu=True)

print("Runtime: " + str(time.time() - start))
(INFO) (paste.py) (16-Nov-24 17:58:14) GPU is not available, resorting to torch CPU.
(INFO) (paste.py) (16-Nov-24 17:58:14) GPU is not available, resorting to torch CPU.
(INFO) (paste.py) (16-Nov-24 17:58:14) GPU is not available, resorting to torch CPU.
Runtime: 0.36208081245422363
[29]:
pd.DataFrame(pi12.cpu().numpy())
[29]:
0 1 2 3 4 5 6 7 8 9 ... 240 241 242 243 244 245 246 247 248 249
0 0.003937 0.000000 0.000000 0.0 0.0 0.000000 0.000000 0.0 0.0 0.0 ... 0.0 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.000000e+00 0.000000 0.000000
1 0.000063 0.003874 0.000000 0.0 0.0 0.000000 0.000000 0.0 0.0 0.0 ... 0.0 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.000000e+00 0.000000 0.000000
2 0.000000 0.000000 0.000000 0.0 0.0 0.003937 0.000000 0.0 0.0 0.0 ... 0.0 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.000000e+00 0.000000 0.000000
3 0.000000 0.000126 0.003811 0.0 0.0 0.000000 0.000000 0.0 0.0 0.0 ... 0.0 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.000000e+00 0.000000 0.000000
4 0.000000 0.000000 0.000000 0.0 0.0 0.000000 0.003874 0.0 0.0 0.0 ... 0.0 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.000000e+00 0.000000 0.000000
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
249 0.000000 0.000000 0.000000 0.0 0.0 0.000000 0.000000 0.0 0.0 0.0 ... 0.0 0.000000 0.000000 0.000000 0.0 0.0 0.0 3.937008e-03 0.000000 0.000000
250 0.000000 0.000000 0.000000 0.0 0.0 0.000000 0.000000 0.0 0.0 0.0 ... 0.0 0.000000 0.000000 0.000000 0.0 0.0 0.0 6.299213e-05 0.003874 0.000000
251 0.000000 0.000000 0.000000 0.0 0.0 0.000000 0.000000 0.0 0.0 0.0 ... 0.0 0.000000 0.000126 0.003811 0.0 0.0 0.0 3.035766e-17 0.000000 0.000000
252 0.000000 0.000000 0.000000 0.0 0.0 0.000000 0.000000 0.0 0.0 0.0 ... 0.0 0.003748 0.000000 0.000189 0.0 0.0 0.0 0.000000e+00 0.000000 0.000000
253 0.000000 0.000000 0.000000 0.0 0.0 0.000000 0.000000 0.0 0.0 0.0 ... 0.0 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.000000e+00 0.000000 0.003937

254 rows × 250 columns

We do the same with center_align().

Note: this time, we skip the initial mappings (pi_inits=b) that we provided above.

[30]:
slices = load_slices(data_dir, ["slice1", "slice2", "slice3", "slice4"])
slice1, slice2, slice3, slice4 = slices

slices = [slice1, slice2, slice3, slice4]
initial_slice = slice1.copy()
lmbda = len(slices) * [1 / len(slices)]
[31]:
start = time.time()

center_slice, pis = center_align(
    initial_slice,
    slices,
    lmbda,
    random_seed=5,
    use_gpu=True,
)

print("Runtime: " + str(time.time() - start))
(INFO) (paste.py) (16-Nov-24 17:58:15) GPU is not available, resorting to torch CPU.
(INFO) (paste.py) (16-Nov-24 17:58:15) Solving Center Mapping NMF Problem.
(INFO) (paste.py) (16-Nov-24 17:58:20) Iteration: 0
(INFO) (paste.py) (16-Nov-24 17:58:20) Solving Pairwise Slice Alignment Problem.
(INFO) (paste.py) (16-Nov-24 17:58:20) Slice 0
(INFO) (paste.py) (16-Nov-24 17:58:20) Slice 1
/opt/hostedtoolcache/Python/3.12.7/x64/lib/python3.12/site-packages/sklearn/decomposition/_nmf.py:1759: ConvergenceWarning:

Maximum number of iterations 200 reached. Increase it to improve convergence.

(INFO) (paste.py) (16-Nov-24 17:58:20) Slice 2
(INFO) (paste.py) (16-Nov-24 17:58:20) Slice 3
(INFO) (paste.py) (16-Nov-24 17:58:21) center_ot done
(INFO) (paste.py) (16-Nov-24 17:58:21) Solving Center Mapping NMF Problem.
(INFO) (paste.py) (16-Nov-24 17:58:23) Objective -27.38540251265489 | Difference: 27.38540251265489
(INFO) (paste.py) (16-Nov-24 17:58:23) Iteration: 1
(INFO) (paste.py) (16-Nov-24 17:58:23) Solving Pairwise Slice Alignment Problem.
(INFO) (paste.py) (16-Nov-24 17:58:23) Slice 0
(INFO) (paste.py) (16-Nov-24 17:58:23) Slice 1
/opt/hostedtoolcache/Python/3.12.7/x64/lib/python3.12/site-packages/sklearn/decomposition/_nmf.py:1759: ConvergenceWarning:

Maximum number of iterations 200 reached. Increase it to improve convergence.

(INFO) (paste.py) (16-Nov-24 17:58:23) Slice 2
(INFO) (paste.py) (16-Nov-24 17:58:24) Slice 3
(INFO) (paste.py) (16-Nov-24 17:58:24) center_ot done
(INFO) (paste.py) (16-Nov-24 17:58:24) Solving Center Mapping NMF Problem.
(INFO) (paste.py) (16-Nov-24 17:58:26) Objective 1.381581304723351 | Difference: 28.76698381737824
(INFO) (paste.py) (16-Nov-24 17:58:26) Iteration: 2
(INFO) (paste.py) (16-Nov-24 17:58:26) Solving Pairwise Slice Alignment Problem.
(INFO) (paste.py) (16-Nov-24 17:58:26) Slice 0
(INFO) (paste.py) (16-Nov-24 17:58:26) Slice 1
/opt/hostedtoolcache/Python/3.12.7/x64/lib/python3.12/site-packages/sklearn/decomposition/_nmf.py:1759: ConvergenceWarning:

Maximum number of iterations 200 reached. Increase it to improve convergence.

(INFO) (paste.py) (16-Nov-24 17:58:26) Slice 2
(INFO) (paste.py) (16-Nov-24 17:58:27) Slice 3
(INFO) (paste.py) (16-Nov-24 17:58:27) center_ot done
(INFO) (paste.py) (16-Nov-24 17:58:27) Solving Center Mapping NMF Problem.
(INFO) (paste.py) (16-Nov-24 17:58:29) Objective 1.392083257538912 | Difference: 0.010501952815560989
(INFO) (paste.py) (16-Nov-24 17:58:29) Iteration: 3
(INFO) (paste.py) (16-Nov-24 17:58:29) Solving Pairwise Slice Alignment Problem.
(INFO) (paste.py) (16-Nov-24 17:58:29) Slice 0
(INFO) (paste.py) (16-Nov-24 17:58:29) Slice 1
(INFO) (paste.py) (16-Nov-24 17:58:29) Slice 2
/opt/hostedtoolcache/Python/3.12.7/x64/lib/python3.12/site-packages/sklearn/decomposition/_nmf.py:1759: ConvergenceWarning:

Maximum number of iterations 200 reached. Increase it to improve convergence.

(INFO) (paste.py) (16-Nov-24 17:58:29) Slice 3
(INFO) (paste.py) (16-Nov-24 17:58:30) center_ot done
(INFO) (paste.py) (16-Nov-24 17:58:30) Solving Center Mapping NMF Problem.
(INFO) (paste.py) (16-Nov-24 17:58:32) Objective 1.3945581180803914 | Difference: 0.0024748605414794955
(INFO) (paste.py) (16-Nov-24 17:58:32) Iteration: 4
(INFO) (paste.py) (16-Nov-24 17:58:32) Solving Pairwise Slice Alignment Problem.
(INFO) (paste.py) (16-Nov-24 17:58:32) Slice 0
(INFO) (paste.py) (16-Nov-24 17:58:32) Slice 1
/opt/hostedtoolcache/Python/3.12.7/x64/lib/python3.12/site-packages/sklearn/decomposition/_nmf.py:1759: ConvergenceWarning:

Maximum number of iterations 200 reached. Increase it to improve convergence.

(INFO) (paste.py) (16-Nov-24 17:58:32) Slice 2
(INFO) (paste.py) (16-Nov-24 17:58:32) Slice 3
(INFO) (paste.py) (16-Nov-24 17:58:32) center_ot done
(INFO) (paste.py) (16-Nov-24 17:58:32) Solving Center Mapping NMF Problem.
(INFO) (paste.py) (16-Nov-24 17:58:35) Objective 1.3955875667153284 | Difference: 0.0010294486349369247
(INFO) (paste.py) (16-Nov-24 17:58:35) Iteration: 5
(INFO) (paste.py) (16-Nov-24 17:58:35) Solving Pairwise Slice Alignment Problem.
(INFO) (paste.py) (16-Nov-24 17:58:35) Slice 0
(INFO) (paste.py) (16-Nov-24 17:58:35) Slice 1
(INFO) (paste.py) (16-Nov-24 17:58:35) Slice 2
/opt/hostedtoolcache/Python/3.12.7/x64/lib/python3.12/site-packages/sklearn/decomposition/_nmf.py:1759: ConvergenceWarning:

Maximum number of iterations 200 reached. Increase it to improve convergence.

(INFO) (paste.py) (16-Nov-24 17:58:35) Slice 3
(INFO) (paste.py) (16-Nov-24 17:58:35) center_ot done
(INFO) (paste.py) (16-Nov-24 17:58:35) Solving Center Mapping NMF Problem.
(INFO) (paste.py) (16-Nov-24 17:58:38) Objective 1.3956006249691741 | Difference: 1.305825384578796e-05
Runtime: 22.791017055511475
/opt/hostedtoolcache/Python/3.12.7/x64/lib/python3.12/site-packages/sklearn/decomposition/_nmf.py:1759: ConvergenceWarning:

Maximum number of iterations 200 reached. Increase it to improve convergence.