{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# This cell is to allow automatic notebook generation for docs\n", "# You may want to comment this out if you have paste3 installed\n", "import sys\n", "from pathlib import Path\n", "\n", "sys.path.insert(0, str(Path.cwd().parent.parent.parent / \"src\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Using the PASTE algorithm\n", "\n", "This notebook highlights the creation of slices (AnnData objects), usage of the `pairwise_align` and `center_align` functions of `paste3`, along with stacking and plotting functionalities.\n", "\n", "**This notebook primarily highlights how you would use the `paste3` package in `PASTE` (i.e. full alignment) mode, when the slices overlap over the full 2D assayed region, with a similar field of view and similar number and proportion of cell types.**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import time\n", "\n", "import matplotlib.patches as mpatches\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import scanpy as sc\n", "import torch\n", "\n", "from paste3.helper import get_common_genes, match_spots_using_spatial_heuristic\n", "from paste3.paste import center_align, pairwise_align\n", "from paste3.visualization import plot_slice, stack_slices_center, stack_slices_pairwise" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Read data and create AnnData objects" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data_dir = \"../../../tests/data/\"\n", "\n", "\n", "# Assume that the coordinates of slices are named slice_name + \"_coor.csv\"\n", "def load_slices(data_dir, slice_names):\n", " slices = []\n", " for slice_name in slice_names:\n", " slice_i = sc.read_csv(data_dir + slice_name + \".csv\")\n", " slice_i_coor = np.genfromtxt(data_dir + slice_name + 
\"_coor.csv\", delimiter=\",\")\n", " slice_i.obsm[\"spatial\"] = slice_i_coor\n", " # Preprocess slices\n", " sc.pp.filter_genes(slice_i, min_counts=15)\n", " sc.pp.filter_cells(slice_i, min_counts=100)\n", " slices.append(slice_i)\n", " return slices\n", "\n", "\n", "slices = load_slices(data_dir, [\"slice1\", \"slice2\", \"slice3\", \"slice4\"])\n", "slice1, slice2, slice3, slice4 = slices" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each AnnData object consists of a gene expression matrix and a spatial coordinate matrix." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "slice1.X" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "slice1.obsm[\"spatial\"][0:5, :]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note, you can choose to label the spots however you want. In this case, we use the default coordinates." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "slice1.obs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "slice1.var" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can visualize the spatial coordinates of our slices using `plot_slice`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "slice_colors = [\"#e41a1c\", \"#377eb8\", \"#4daf4a\", \"#984ea3\"]\n", "\n", "fig, axs = plt.subplots(2, 2, figsize=(7, 7))\n", "plot_slice(slice1, slice_colors[0], ax=axs[0, 0])\n", "plot_slice(slice2, slice_colors[1], ax=axs[0, 1])\n", "plot_slice(slice3, slice_colors[2], ax=axs[1, 0])\n", "plot_slice(slice4, slice_colors[3], ax=axs[1, 1])\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also plot using Scanpy's spatial plotting function." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sc.pl.spatial(slice1, color=\"n_counts\", spot_size=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Pairwise Alignment" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Run PASTE `pairwise_align`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "start = time.time()\n", "\n", "pi12, _ = pairwise_align(slice1, slice2)\n", "pi23, _ = pairwise_align(slice2, slice3)\n", "pi34, _ = pairwise_align(slice3, slice4)\n", "\n", "print(\"Runtime: \" + str(time.time() - start))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pd.DataFrame(pi12.cpu().numpy())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Sequential pairwise slice alignment plots" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pis = [pi12, pi23, pi34]\n", "slices = [slice1, slice2, slice3, slice4]\n", "\n", "new_slices, _, _ = stack_slices_pairwise(slices, pis)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we've aligned the spatial coordinates, we can plot them all on the same coordinate system." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "slice_colors = [\"#e41a1c\", \"#377eb8\", \"#4daf4a\", \"#984ea3\"]\n", "\n", "plt.figure(figsize=(7, 7))\n", "for i in range(len(new_slices)):\n", " plot_slice(new_slices[i], slice_colors[i], s=400)\n", "plt.legend(\n", " handles=[\n", " mpatches.Patch(color=slice_colors[0], label=\"1\"),\n", " mpatches.Patch(color=slice_colors[1], label=\"2\"),\n", " mpatches.Patch(color=slice_colors[2], label=\"3\"),\n", " mpatches.Patch(color=slice_colors[3], label=\"4\"),\n", " ]\n", ")\n", "plt.gca().invert_yaxis()\n", "plt.axis(\"off\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also plot pairwise layers together." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "slice_colors = [\"#e41a1c\", \"#377eb8\", \"#4daf4a\", \"#984ea3\"]\n", "\n", "fig, axs = plt.subplots(2, 2, figsize=(7, 7))\n", "plot_slice(new_slices[0], slice_colors[0], ax=axs[0, 0])\n", "plot_slice(new_slices[1], slice_colors[1], ax=axs[0, 0])\n", "plot_slice(new_slices[1], slice_colors[1], ax=axs[0, 1])\n", "plot_slice(new_slices[2], slice_colors[2], ax=axs[0, 1])\n", "plot_slice(new_slices[2], slice_colors[2], ax=axs[1, 0])\n", "plot_slice(new_slices[3], slice_colors[3], ax=axs[1, 0])\n", "fig.delaxes(axs[1, 1])\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also plot the slices in 3-D." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import plotly.express as px\n", "import plotly.io as pio\n", "\n", "pio.renderers.default = \"notebook\"\n", "\n", "slices_colors = [\"#e41a1c\", \"#377eb8\", \"#4daf4a\", \"#984ea3\"]\n", "\n", "# scale the distance between layers\n", "z_scale = 2\n", "\n", "values = []\n", "for i, L in enumerate(new_slices):\n", " for x, y in L.obsm[\"spatial\"]:\n", " values.append([x, y, i * z_scale, str(i)])\n", "df = pd.DataFrame(values, columns=[\"x\", \"y\", \"z\", \"slice\"])\n", "fig = px.scatter_3d(\n", " df, x=\"x\", y=\"y\", z=\"z\", color=\"slice\", color_discrete_sequence=slice_colors\n", ")\n", "fig.update_layout(scene_aspectmode=\"data\")\n", "fig.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Center Alignment" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we will read in and preprocess the data (if you ran `pairwise_align` above, it will be altered)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "slices = load_slices(data_dir, [\"slice1\", \"slice2\", \"slice3\", \"slice4\"])\n", "slice1, slice2, slice3, slice4 = slices" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Run PASTE `center_align`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "slices = [slice1, slice2, slice3, slice4]\n", "initial_slice = slice1.copy()\n", "lmbda = len(slices) * [1 / len(slices)]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, for center alignment, we can provide initial mappings between the center and original slices to PASTE to improve the algorithm. However, note this is optional." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "slices, _ = get_common_genes(slices)\n", "\n", "b = []\n", "for i in range(len(slices)):\n", " b.append(\n", " torch.Tensor(\n", " match_spots_using_spatial_heuristic(slices[0].X, slices[i].X)\n", " ).double()\n", " )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "start = time.time()\n", "\n", "center_slice, pis = center_align(\n", " initial_slice, slices, lmbda, random_seed=5, pi_inits=b\n", ")\n", "\n", "print(\"Runtime: \" + str(time.time() - start))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Alternatively, we can run `center_align` without providing initial mappings, as shown below." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# center_slice, pis = paste.center_align(initial_slice, slices, lmbda, random_seed = 5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`center_slice` is an AnnData object that also includes the low-dimensional representation of our inferred center slice." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "center_slice.uns[\"paste_W\"]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "center_slice.uns[\"paste_H\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Center slice alignment plots" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we can use the outputs of `center_align` to align the slices." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "center, new_slices, _, _ = stack_slices_center(center_slice, slices, pis)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we've aligned the spatial coordinates, we can plot them all on the same coordinate system. Note that the center slice is drawn first (in orange) but has no legend entry; because it shares the coordinates of slice 1, it ends up hidden beneath slice 1." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "center_color = \"orange\"\n", "slices_colors = [\"#e41a1c\", \"#377eb8\", \"#4daf4a\", \"#984ea3\"]\n", "\n", "plt.figure(figsize=(7, 7))\n", "plot_slice(center, center_color, s=400)\n", "for i in range(len(new_slices)):\n", " plot_slice(new_slices[i], slices_colors[i], s=400)\n", "\n", "plt.legend(\n", " handles=[\n", " mpatches.Patch(color=slices_colors[0], label=\"1\"),\n", " mpatches.Patch(color=slices_colors[1], label=\"2\"),\n", " mpatches.Patch(color=slices_colors[2], label=\"3\"),\n", " mpatches.Patch(color=slices_colors[3], label=\"4\"),\n", " ]\n", ")\n", "plt.gca().invert_yaxis()\n", "plt.axis(\"off\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we plot each slice compared to the center." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that since we used slice1 as the coordinates for the center slice, they remain the same, and thus we cannot see both in our plots below." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "center_color = \"orange\"\n", "slice_colors = [\"#e41a1c\", \"#377eb8\", \"#4daf4a\", \"#984ea3\"]\n", "\n", "fig, axs = plt.subplots(2, 2, figsize=(7, 7))\n", "plot_slice(center, center_color, ax=axs[0, 0])\n", "plot_slice(new_slices[0], slice_colors[0], ax=axs[0, 0])\n", "\n", "plot_slice(center, center_color, ax=axs[0, 1])\n", "plot_slice(new_slices[1], slice_colors[1], ax=axs[0, 1])\n", "\n", "plot_slice(center, center_color, ax=axs[1, 0])\n", "plot_slice(new_slices[2], slice_colors[2], ax=axs[1, 0])\n", "\n", "plot_slice(center, center_color, ax=axs[1, 1])\n", "plot_slice(new_slices[3], slice_colors[3], ax=axs[1, 1])\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "307.2px" }, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 4 }