import numpy as np
from sklearn.neighbors import NearestNeighbors
from datetime import datetime
from scipy.stats import norm as normal
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.spatial import distance
from scipy.sparse import csr_matrix, csgraph, find
import math
import pandas as pd
import numpy as np
from numpy import ndarray
from scipy.sparse import issparse, spmatrix
import hnswlib
import time
import matplotlib
import igraph as ig
import matplotlib.pyplot as plt
from datetime import datetime
import matplotlib.cm as cm
import pygam as pg
from sklearn.preprocessing import normalize
from typing import Optional, Union
#from utils_via import *
from pyVIA.utils_via import *
import random
from scipy.spatial.distance import pdist, squareform
from sklearn.preprocessing import StandardScaler
def _pl_velocity_embedding(via_object, X_emb, smooth_transition, b, use_sequentially_augmented=False):
    '''
    Project single-cell transition probabilities onto a 2D embedding as per-cell velocity vectors.

    :param via_object: VIA object carrying (or able to compute) the single-cell transition matrix
    :param X_emb: ndarray (n_samples x 2) embedding coordinates
    :param smooth_transition: smoothing parameter forwarded to via_object.sc_transition_matrix()
    :param b: parameter forwarded to via_object.sc_transition_matrix()
    :param use_sequentially_augmented: forwarded to via_object.sc_transition_matrix()
    :return: V_emb ndarray (n_samples x 2), one velocity vector per cell in embedding space
    '''
    n_obs = X_emb.shape[0]
    V_emb = np.zeros(X_emb.shape)
    print('inside _plotting')
    # lazily compute and cache the single-cell transition matrix on via_object
    if via_object.single_cell_transition_matrix is None:
        print('inside _plotting compute sc_transitionmatrix')
        via_object.single_cell_transition_matrix = via_object.sc_transition_matrix(
            smooth_transition, b, use_sequentially_augmented=use_sequentially_augmented)
    else:
        print('get _plotting compute sc_transitionmatrix')
    T = via_object.single_cell_transition_matrix
    # velocity of a cell = transition-probability-weighted mean displacement towards its neighbours,
    # centred by subtracting the unweighted mean displacement
    for cell in range(n_obs):
        nbr_idx = T[cell].indices
        delta = X_emb[nbr_idx] - X_emb[cell, None]  # shape (n_neighbors, 2)
        delta /= l2_norm(delta)[:, None]
        delta[np.isnan(delta)] = 0  # zero diff in a steady-state
        # neighbour edge weights weight the overall displacement (velocity) of this cell
        edge_probs = T[cell].data
        V_emb[cell] = edge_probs.dot(delta) - edge_probs.mean() * delta.sum(0)
    V_emb /= 3 * quiver_autoscale(X_emb, V_emb)
    return V_emb
def geodesic_distance(data: ndarray, knn: int = 10, root: int = 0, mst_mode: bool = False,
                      cluster_labels: ndarray = None):
    '''
    Shortest-path (geodesic) distance from a root vertex to every vertex of a knn graph built on `data`.

    :param data: ndarray (n_samples x n_dims) low-dimensional coordinates
    :param knn: number of nearest neighbours used to build the graph
    :param root: index of the root sample; when cluster_labels is given, the root cluster is cluster_labels[root]
    :param mst_mode: when True, reinforce the knn graph with its minimum spanning tree edges
    :param cluster_labels: optional per-sample cluster labels; when given, distances are computed on the cluster graph
    :return: list of shortest weighted distances from the root (0 when no path exists)
    '''
    n_samples = data.shape[0]
    # knn graph on the low-dimensional data; raw neighbour distances serve as edge weights
    knn_struct = construct_knn_utils(data, knn=knn)
    neighbors, distances = knn_struct.knn_query(data, k=knn)
    keep = np.full_like(distances, True, dtype=np.bool_)
    # https://igraph.org/python/versions/0.10.1/tutorials/shortest_paths/shortest_paths.html
    # drop self-loops
    keep &= (neighbors != np.arange(neighbors.shape[0])[:, np.newaxis])
    rows = np.array([np.repeat(i, len(x)) for i, x in enumerate(neighbors)])[keep]
    cols = neighbors[keep]
    weights = distances[keep]  # weights are the actual edge distances
    result = csr_matrix((weights, (rows, cols)), shape=(len(neighbors), len(neighbors)), dtype=np.float32)
    if mst_mode:
        print(f'MST geodesic mode')
        from scipy.sparse.csgraph import minimum_spanning_tree
        # NOTE(review): MST edges already present in `result` get their weights doubled here — confirm intended
        result = result + minimum_spanning_tree(result)
    result.eliminate_zeros()
    sources, targets = result.nonzero()
    G = ig.Graph(list(zip(sources.tolist(), targets.tolist())), edge_attrs={'weight': result.data.tolist()})
    if cluster_labels is not None:
        # collapse the single-cell graph into a cluster-level graph and work on that instead
        graph = ig.VertexClustering(G, membership=cluster_labels).cluster_graph(combine_edges='sum')
        graph = recompute_weights(graph, Counter(cluster_labels))  # returns csr matrix
        G = ig.Graph(list(zip(*graph.nonzero())),
                     edge_attrs={'weight': graph.data / (np.std(graph.data))})
        root = cluster_labels[root]
    # shortest distance from the root to each vertex
    geo_distance_list = []
    print(f'start computing shortest paths')
    for i in range(G.vcount()):
        if cluster_labels is None:
            if i % 1000 == 0: print(f'{datetime.now()}\t{i} out of {n_samples} complete')
        epath = G.get_shortest_paths(root, to=i, weights=G.es["weight"], output="epath")
        if len(epath[0]) > 0:
            # total weight accumulated along the shortest path
            geo_distance_list.append(sum(G.es[e]["weight"] for e in epath[0]))
        else:
            geo_distance_list.append(0)
    return geo_distance_list
def corr_geodesic_distance_lowdim(embedding, knn=10, time_labels: list = [], root: int = 0,
                                  saveto='',
                                  mst_mode: bool = False, cluster_labels: ndarray = None):
    '''
    Correlation between geodesic distance from a root (on a knn graph of the embedding) and known time labels.

    :param embedding: ndarray (n_samples x n_dims) low-dimensional embedding
    :param knn: number of nearest neighbours used to build the graph
    :param time_labels: known (true) time value per sample
    :param root: index of the root sample
    :param saveto: optional csv path; when non-empty the distance dataframe is written there.
        (FIX: previously defaulted to a hard-coded user-specific path, which failed on other machines;
        now consistent with via_mds/via_forcelayout which only save when saveto is provided)
    :param mst_mode: reinforce the knn graph with its minimum spanning tree (see geodesic_distance)
    :param cluster_labels: optional cluster labels; correlation is then computed at cluster level
    :return: Pearson correlation between geodesic distance and true time
    '''
    geodesic_dist = geodesic_distance(embedding, knn=knn, root=root, mst_mode=mst_mode, cluster_labels=cluster_labels)
    df_ = pd.DataFrame()
    df_['true_time'] = time_labels
    if cluster_labels is not None:
        # average the true time per cluster so rows align with the cluster-level geodesic distances
        df_['cluster_labels'] = cluster_labels
        df_ = df_.sort_values(['cluster_labels'], ascending=True).groupby('cluster_labels').mean()
        print('df_groupby', df_.head())
    df_['geo'] = geodesic_dist
    df_['geo'] = df_['geo'].fillna(0)
    correlation = df_['geo'].corr(df_['true_time'])
    print(f'{datetime.now()}\tcorrelation geo 2d and true time, {correlation}')
    if len(saveto) > 0:
        df_.to_csv(saveto)
    return correlation
def make_edgebundle_milestone(embedding: ndarray = None, sc_graph=None, via_object=None, sc_pt: list = None,
                              initial_bandwidth=0.03, decay=0.7, n_milestones: int = None, milestone_labels: list = [],
                              sc_labels_numeric: list = None, weighted: bool = True, global_visual_pruning: float = 0.5,
                              terminal_cluster_list: list = [], single_cell_lineage_prob: ndarray = None,
                              random_state: int = 0):
    '''
    Perform edge-bundling at a "milestone" level and return a hammer bundle of milestone-level edges.
    Milestones are more granular than the parc-clusters but less granular than single cells, and therefore
    less computationally expensive to bundle. Requires some type of embedding (n_samples x 2) to be available.

    :param embedding: optional (not required if via_object is provided) single-cell embedding (n_samples x 2).
        via_mds tends to look more streamlined/continuous; umap is a bit "clustery"
    :param sc_graph: optional (not required if via_object is provided) igraph single-cell affinity graph
        (the via attribute ig_full_graph)
    :param via_object: via object (the simplest way to run this function)
    :param sc_pt: single-cell pseudotime; defaults to via_object.single_cell_pt_markov
    :param initial_bandwidth: increasing bw increases merging of minor edges
    :param decay: increasing decay increases merging of minor edges  # https://datashader.org/user_guide/Networks.html
    :param n_milestones: number of milestones (kmeans clusters); autocomputed when None
    :param milestone_labels: default []. Usually autocomputed, but can be provided as single-cell level labels
        (clusters/groups that act as milestone groupings of the single cells)
    :param sc_labels_numeric: default None, which uses via_object's time_series_labels (when available),
        else pseudotime. Otherwise provide a list of numeric sequential/chronological values
    :param weighted: use the milestone graph's edge weights when bundling
    :param global_visual_pruning: std-based pruning level applied to the milestone cluster graph
    :param terminal_cluster_list: default [] automatically uses all terminal clusters; otherwise a subset
        of the terminal cluster numbers
    :param single_cell_lineage_prob: optional lineage probabilities (n_samples x n_terminal_states)
    :param random_state: seed used for kmeans milestone computation (and autocomputed embedding)
    :return: dict with keys 'hammerbundle' (hb with coords in hb.x and hb.y), 'milestone_embedding'
        (dataframe with 'x','y' per milestone plus 'cluster_pop'), 'edges' (dataframe with
        ['source','target'] milestone per edge), 'sc_milestone_labels' (milestone label per single cell)
    '''
    if embedding is None:
        if via_object is not None: embedding = via_object.embedding
    if sc_graph is None:
        if via_object is not None: sc_graph = via_object.ig_full_graph
    if embedding is None:
        if via_object is None:
            print(f'{datetime.now()}\tERROR: Please provide via_object')
            return
        else:
            print(
                f'{datetime.now()}\tWARNING: VIA will now autocompute an embedding. It would be better to precompute an embedding using embedding = via_umap() or via_mds() and setting this as the embedding attribute via_object = embedding.')
            embedding = via_mds(via_object=via_object, random_seed=random_state)

    n_samples = embedding.shape[0]
    if n_milestones is None:
        # FIX: previously dereferenced via_object.nsamples unconditionally, which crashed
        # when only embedding + sc_graph were provided without a via_object
        nsamp = via_object.nsamples if via_object is not None else n_samples
        n_milestones = min(nsamp, min(250, int(0.1 * nsamp)))
        print(f'{datetime.now()}\t n_milestones is {n_milestones}')
    if len(milestone_labels) == 0:
        # milestones are found by kmeans clustering of the embedding
        print(f'{datetime.now()}\tStart finding milestones')
        from sklearn.cluster import KMeans
        kmeans = KMeans(n_clusters=n_milestones, random_state=random_state, n_init=10).fit(embedding)
        milestone_labels = kmeans.labels_.flatten().tolist()
        print(f'{datetime.now()}\tEnd milestones with {n_milestones}')
    if sc_labels_numeric is None:
        if via_object is not None:
            sc_labels_numeric = via_object.time_series_labels
        else:
            print(
                f'{datetime.now()}\tWill use via-pseudotime for edges, otherwise consider providing a list of numeric labels (single cell level) or via_object')
    if sc_pt is None:
        # NOTE(review): assumes via_object is not None when sc_pt is omitted — confirm callers
        sc_pt = via_object.single_cell_pt_markov

    # collapse the single-cell graph into a milestone-level cluster graph, then prune weak edges
    vertex_milestone_graph = ig.VertexClustering(sc_graph, membership=milestone_labels).cluster_graph(
        combine_edges='sum')
    print(f'{datetime.now()}\tRecompute weights')
    vertex_milestone_graph = recompute_weights(vertex_milestone_graph, Counter(milestone_labels))
    print(f'{datetime.now()}\tpruning milestone graph based on recomputed weights')
    # was at 0.1 global_pruning for 2000+ milestones
    edgeweights_pruned_milestoneclustergraph, edges_pruned_milestoneclustergraph, comp_labels = pruning_clustergraph(
        vertex_milestone_graph,
        global_pruning_std=global_visual_pruning,
        preserve_disconnected=True,
        preserve_disconnected_after_pruning=False, do_max_outgoing=False)

    print(f'{datetime.now()}\tregenerate igraph on pruned edges')
    vertex_milestone_graph = ig.Graph(edges_pruned_milestoneclustergraph,
                                      edge_attrs={'weight': edgeweights_pruned_milestoneclustergraph}).simplify(
        combine_edges='sum')
    vertex_milestone_csrgraph = get_sparse_from_igraph(vertex_milestone_graph, weight_attr='weight')
    weights_for_layout = np.asarray(vertex_milestone_csrgraph.data)
    # clip weights to prevent a distorted visual scale
    weights_for_layout = np.clip(weights_for_layout, np.percentile(weights_for_layout, 20),
                                 np.percentile(weights_for_layout, 80))
    weights_for_layout = weights_for_layout / np.std(weights_for_layout)
    vertex_milestone_graph = ig.Graph(list(zip(*vertex_milestone_csrgraph.nonzero())),
                                      edge_attrs={'weight': list(weights_for_layout)})

    # per-single-cell node dataframe, later averaged per milestone
    data_node = [node for node in range(embedding.shape[0])]
    nodes = pd.DataFrame(data_node, columns=['id'])
    nodes.set_index('id', inplace=True)
    nodes['x'] = embedding[:, 0]
    nodes['y'] = embedding[:, 1]
    nodes['pt'] = sc_pt
    if via_object is not None:
        # FIX: only fall back to via_object attributes when the caller did not provide values;
        # previously user-supplied terminal_cluster_list / single_cell_lineage_prob were always overwritten
        if len(terminal_cluster_list) == 0:
            terminal_cluster_list = via_object.terminal_clusters
        if single_cell_lineage_prob is None:
            # rownorming does not make a big difference (default not rownormed)
            single_cell_lineage_prob = via_object.single_cell_bp_rownormed
    if (len(terminal_cluster_list) > 0) and (single_cell_lineage_prob is not None):
        for i, c_i in enumerate(terminal_cluster_list):
            nodes['sc_lineage_probability_' + str(c_i)] = single_cell_lineage_prob[:, i]
    if sc_labels_numeric is not None:
        print(
            f'{datetime.now()}\tSetting numeric label as time_series_labels or other sequential metadata for coloring edges')
        nodes['numeric label'] = sc_labels_numeric
    else:
        print(f'{datetime.now()}\tSetting numeric label as single cell pseudotime for coloring edges')
        nodes['numeric label'] = sc_pt

    nodes['kmeans'] = milestone_labels
    group_pop = []
    for i in sorted(set(milestone_labels)):
        group_pop.append(milestone_labels.count(i))
    # nodes_mean holds the averaged 'x' and 'y' milestone locations based on the embedding
    nodes_mean = nodes.groupby('kmeans').mean()
    nodes_mean['cluster population'] = group_pop

    edges = pd.DataFrame([e.tuple for e in vertex_milestone_graph.es], columns=['source', 'target'])
    edges['weight0'] = vertex_milestone_graph.es['weight']
    edges = edges[edges['source'] != edges['target']]
    # bundling seems to work better on an unweighted representation, later using segment length to color-code significance
    if weighted == True:
        edges['weight'] = edges['weight0']
    else:
        edges['weight'] = 1
    print(f'{datetime.now()}\tMaking smooth edges')
    # hb.x and hb.y contain all x/y coords of the points making up the edge lines;
    # each new line segment is separated by a nan value
    # https://datashader.org/_modules/datashader/bundling.html#hammer_bundle
    hb = hammer_bundle(nodes_mean, edges, weight='weight', initial_bandwidth=initial_bandwidth,
                       decay=decay)  # default bw=0.05, dec=0.7
    hb_dict = {}
    hb_dict['hammerbundle'] = hb
    hb_dict['milestone_embedding'] = nodes_mean
    hb_dict['edges'] = edges[['source', 'target']]
    hb_dict['sc_milestone_labels'] = milestone_labels
    return hb_dict
def plot_gene_trend_heatmaps(via_object, df_gene_exp: pd.DataFrame, marker_lineages: list = [], fontsize: int = 8,
                             cmap: str = 'viridis', normalize: bool = True, ytick_labelrotation: int = 0,
                             fig_width: int = 7):
    '''
    Plot gene trends as heatmaps: one heatmap per lineage (identified by terminal cluster number).
    By default all detected lineages are plotted.

    :param via_object: via object
    :param df_gene_exp: pandas DataFrame of single-cell level expression [cells x genes]
    :param marker_lineages: list, default [] plots all detected lineages. Optionally a list of integers
        corresponding to the cluster numbers of terminal cell fates
    :param fontsize: int default = 8
    :param cmap: str default = 'viridis'
    :param normalize: bool default True — standardize each gene's trend across pseudotime.
        (NOTE: this parameter shadows sklearn.preprocessing.normalize imported at module level;
        name kept for API compatibility)
    :param ytick_labelrotation: int default = 0
    :param fig_width: figure width; height scales with the number of genes and lineages
    :return: fig and list of axes
    '''
    import seaborn as sns
    if len(marker_lineages) == 0: marker_lineages = via_object.terminal_clusters
    dict_trends = get_gene_trend(via_object=via_object, marker_lineages=marker_lineages, df_gene_exp=df_gene_exp)
    branches = list(dict_trends.keys())
    print('branches', branches)
    genes = dict_trends[branches[0]]['trends'].index
    height = len(genes) * len(branches)
    # one vertically stacked subplot per lineage
    fig = plt.figure(figsize=[fig_width, height])
    ax_list = []
    for i, branch in enumerate(branches):
        ax = fig.add_subplot(len(branches), 1, i + 1)
        df_trends = dict_trends[branch]['trends']
        # standardize each gene (feature); StandardScaler scales columns, hence the double transpose
        if normalize == True:
            df_trends = pd.DataFrame(
                StandardScaler().fit_transform(df_trends.T).T,
                index=df_trends.index,
                columns=df_trends.columns)
        ax.set_title('Lineage: ' + str(branch) + '-' + str(dict_trends[branch]['name']), fontsize=int(fontsize * 1.3))
        b = sns.heatmap(df_trends, yticklabels=True, xticklabels=False, cmap=cmap)
        b.tick_params(labelsize=fontsize, labelrotation=ytick_labelrotation)
        b.figure.axes[-1].tick_params(labelsize=fontsize)  # colorbar tick label size
        ax_list.append(ax)
    # label pseudotime on the bottom-most heatmap only
    b.set_xlabel("pseudotime", fontsize=int(fontsize * 1.3))
    return fig, ax_list
def plot_scatter(embedding: ndarray, labels: list, cmap='rainbow', s=5, alpha=0.3, edgecolors='None', title: str = '',
                 text_labels: bool = True, color_dict=None, via_object=None, sc_index_terminal_states: list = None,
                 true_labels: list = [], show_legend: bool = True, hide_axes_ticks: bool = True, color_labels_reverse: bool = False):
    '''
    General scatter plotting tool for numeric and categorical labels at the single-cell level.

    :param embedding: ndarray n_samples x 2
    :param labels: list of single-cell labels (numbers or strings)
    :param cmap: str default = 'rainbow'
    :param s: int size of scatter dot
    :param alpha: float, 0 (transparent) to 1 (opaque), default 0.3
    :param edgecolors: marker edge color passed to scatter
    :param title: str plot title
    :param text_labels: bool default True — annotate each group at its mean location
    :param color_dict: {'true_label_group_1': #COLOR, 'true_label_group_2': #COLOR2, ...} where the
        dictionary keys correspond to the provided labels
    :param via_object: optional via object used to mark terminal states
    :param sc_index_terminal_states: list of integers, one cell index per terminal state (used when via_object is None)
    :param true_labels: list of single-cell labels used to annotate the terminal states
    :param show_legend: bool default True
    :param hide_axes_ticks: bool default True
    :param color_labels_reverse: sort order of categorical labels for color assignment
    :return: matplotlib pyplot fig, ax
    '''
    fig, ax = plt.subplots()
    # string labels -> discrete palette + legend; numeric labels -> continuous colormap + colorbar
    if (isinstance(labels[0], str)) == True:
        categorical = True
    else:
        categorical = False
    ax.set_facecolor('white')
    if color_dict is not None:
        for key in color_dict:
            loc_key = np.where(np.asarray(labels) == key)[0]
            ax.scatter(embedding[loc_key, 0], embedding[loc_key, 1], color=color_dict[key], label=key, s=s,
                       alpha=alpha, edgecolors=edgecolors)
            x_mean = embedding[loc_key, 0].mean()
            y_mean = embedding[loc_key, 1].mean()
            if text_labels == True: ax.text(x_mean, y_mean, key, style='italic', fontsize=10, color="black")
    elif categorical == True:
        color_dict = {}
        set_labels = list(set(labels))
        set_labels.sort(reverse=color_labels_reverse)  # used to be reverse=True until Feb 2024
        for index, value in enumerate(set_labels):
            color_dict[value] = index
        # FIX: cm.get_cmap() was removed in matplotlib >= 3.9; use the colormap registry when available
        try:
            palette = matplotlib.colormaps[cmap].resampled(len(color_dict.keys()))
        except AttributeError:
            palette = cm.get_cmap(cmap, len(color_dict.keys()))
        cmap_ = palette(range(len(color_dict.keys())))
        for key in color_dict:
            loc_key = np.where(np.asarray(labels) == key)[0]
            ax.scatter(embedding[loc_key, 0], embedding[loc_key, 1], color=cmap_[color_dict[key]], label=key, s=s,
                       alpha=alpha, edgecolors=edgecolors)
            x_mean = embedding[loc_key, 0].mean()
            y_mean = embedding[loc_key, 1].mean()
            if text_labels == True: ax.text(x_mean, y_mean, key, style='italic', fontsize=10, color="black")
    else:
        im = ax.scatter(embedding[:, 0], embedding[:, 1], c=labels,
                        cmap=cmap, s=s, alpha=alpha)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.05)
        fig.colorbar(im, cax=cax, orientation='vertical',
                     label='pseudotime')
    if via_object is not None:
        # mark one representative (late-pseudotime) cell per terminal cluster
        tsi_list = []
        for tsi in via_object.terminal_clusters:
            loc_i = np.where(np.asarray(via_object.labels) == tsi)[0]
            val_pt = [via_object.single_cell_pt_markov[i] for i in loc_i]
            th_pt = np.percentile(val_pt, 50)  # keep cells above the median pseudotime of the cluster
            loc_i = [loc_i[i] for i in range(len(val_pt)) if val_pt[i] >= th_pt]
            temp = np.mean(via_object.data[loc_i], axis=0)
            labelsq, distances = via_object.knn_struct.knn_query(temp, k=1)
            tsi_list.append(labelsq[0][0])
        ax.set_title(
            'root:' + str(via_object.root_user[0]) + 'knn' + str(via_object.knn) + 'Ncomp' + str(via_object.ncomp))
        for i in tsi_list:
            # NOTE(review): assumes true_labels covers index i — an empty default true_labels would raise IndexError
            ax.text(embedding[i, 0], embedding[i, 1], str(true_labels[i]) + '_Cell' + str(i))
            ax.scatter(embedding[i, 0], embedding[i, 1], c='black', s=10)
    if (via_object is None) & (sc_index_terminal_states is not None):
        for i in sc_index_terminal_states:
            ax.text(embedding[i, 0], embedding[i, 1], str(true_labels[i]) + '_Cell' + str(i))
            ax.scatter(embedding[i, 0], embedding[i, 1], c='black', s=10)
    if len(title) == 0:
        ax.set_title(label='scatter plot', color='blue')
    else:
        ax.set_title(label=title, color='blue')
    ax.grid(False)
    # hide axes ticks and the frame
    if hide_axes_ticks:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')
    ax.grid(False)
    fig.patch.set_visible(True)
    if show_legend:
        # FIX: the original called ax.legend() twice; the first (fontsize=12, frameon=False) legend
        # was immediately replaced by the second, so only one call is kept
        legend = ax.legend(bbox_to_anchor=(0, 0, 1.2, 1), loc='lower right', borderaxespad=0)
        # FIX: Legend.legendHandles was renamed legend_handles in matplotlib 3.7 and removed in 3.9
        handles = getattr(legend, 'legend_handles', None)
        if handles is None:
            handles = legend.legendHandles
        for handle in handles:
            handle.set_alpha(1)
            handle.set_sizes([40])
    return fig, ax
def _make_knn_embeddedspace(embedding):
    '''
    Build an hnswlib knn index over the embedded (2D) space, used for projecting lineage
    trajectories onto the 2D plot.

    :param embedding: ndarray (n_samples x n_dims) embedding coordinates
    :return: an hnswlib.Index populated with the embedding points
    '''
    index = hnswlib.Index(space='l2', dim=embedding.shape[1])
    index.init_index(max_elements=embedding.shape[0], ef_construction=200, M=16)
    index.add_items(embedding)
    index.set_ef(50)
    return index
def via_forcelayout(X_pca, viagraph_full: csr_matrix = None, k: int = 10,
                    n_milestones=2000, time_series_labels: list = [],
                    knn_seq: int = 5, saveto='', random_seed: int = 0) -> ndarray:
    '''
    Compute a force-directed layout on a milestone subsample and project it back onto all cells.
    #TODO not complete

    :param X_pca: ndarray (n_samples x n_pcs) dimension-reduced input
    :param viagraph_full: optional. If calling before via, None; if calling after or from within via,
        the via-graph is used to reinforce the layout
    :param k: number of nearest neighbours for the milestone knn graph
    :param n_milestones: number of milestones subsampled for the layout
    :param time_series_labels: optional numeric sequential labels (single-cell level)
    :param knn_seq: number of sequential neighbours used to augment the knn when time-series labels exist
    :param saveto: optional csv path to save the layout
    :param random_seed: seed for the milestone subsampling
    :return: ndarray (n_samples x 2) force layout coordinates
    '''
    # subsampling csr_full_graph alone yields a fragmented graph (the subsample covers too small a
    # fraction of the possible edges), so a fresh knn on the milestones is added back in below
    print(f"{datetime.now()}\tCommencing Force Layout")
    np.random.seed(random_seed)
    # FIX: random.sample() draws from the stdlib RNG, which was previously left unseeded,
    # so the layout was not reproducible despite random_seed
    random.seed(random_seed)
    milestone_indices = random.sample(range(X_pca.shape[0]), n_milestones)  # sampling without replacement
    if viagraph_full is not None:
        milestone_knn = viagraph_full[milestone_indices]
        milestone_knn = milestone_knn[:, milestone_indices]
        milestone_knn = normalize(milestone_knn, axis=1)

    knn_struct = construct_knn_utils(X_pca[milestone_indices, :], knn=k)
    # add a new knn (milestone_knn_new) built on the subsampled indices to ensure connectivity;
    # otherwise the graph is fragmented when relying only on the subsampled graph
    if time_series_labels is None: time_series_labels = []
    if len(time_series_labels) >= 1: time_series_labels = np.array(time_series_labels)[milestone_indices].tolist()
    milestone_knn_new = affinity_milestone_knn(data=X_pca[milestone_indices, :], knn_struct=knn_struct, k=k,
                                               time_series_labels=time_series_labels, knn_seq=knn_seq)
    print('milestone knn new', milestone_knn_new.shape, milestone_knn_new.data[0:10])
    if viagraph_full is None:
        milestone_knn = milestone_knn_new
    else:
        milestone_knn = milestone_knn + milestone_knn_new
    print('final reinforced milestone knn', milestone_knn.shape, 'number of nonzero edges', len(milestone_knn.data))

    print('force layout')
    g_layout = ig.Graph(list(zip(*milestone_knn.nonzero())))  # , edge_attrs={'weight': weights_for_layout})
    layout = g_layout.layout_fruchterman_reingold()
    force_layout = np.asarray(layout.coords)

    # knn used to estimate full-sample embedding values from the milestone-subset layout
    neighbor_array, distance_array = knn_struct.knn_query(X_pca, k=k)
    print('shape of ', X_pca.shape, neighbor_array.shape)
    row_mean = np.mean(distance_array, axis=1)
    row_var = np.var(distance_array, axis=1)
    row_znormed_dist_array = -(distance_array - row_mean[:, np.newaxis]) / row_var[:, np.newaxis]
    # when k is very small the affinities can explode because var is ~0
    row_znormed_dist_array = np.nan_to_num(row_znormed_dist_array, copy=True, nan=1, posinf=1, neginf=1)
    row_znormed_dist_array[row_znormed_dist_array > 10] = 0
    affinity_array = np.exp(row_znormed_dist_array)
    affinity_array = normalize(affinity_array, norm='l1', axis=1)  # row stochastic
    row_list = []
    n_neighbors = neighbor_array.shape[1]
    n_cells = neighbor_array.shape[0]
    print('ncells and neighs', n_cells, n_neighbors)
    row_list.extend(list(np.transpose(np.ones((n_neighbors, n_cells)) * range(0, n_cells)).flatten()))
    col_list = neighbor_array.flatten().tolist()
    list_affinity = affinity_array.flatten().tolist()
    csr_knn = csr_matrix((list_affinity, (row_list, col_list)),
                         shape=(n_cells, len(milestone_indices)))  # n_samples x n_milestones
    milestone_force = csr_matrix(force_layout)  # TODO remove: testing force layout
    full_force = csr_knn * milestone_force  # sparse matrix product projects milestones to all cells
    full_force = np.asarray(full_force.todense())
    plt.scatter(full_force[:, 0].tolist(), full_force[:, 1].tolist(), s=1, alpha=0.3, c='red')
    plt.title('full mds')
    plt.show()
    full_force = np.reshape(full_force, (n_cells, 2))
    if len(saveto) > 0:
        U_df = pd.DataFrame(full_force)
        U_df.to_csv(saveto)
    return full_force
def via_mds(via_object=None, X_pca: ndarray = None, viagraph_full: csr_matrix = None, k: int = 15,
            random_seed: int = 0, diffusion_op: int = 1, n_milestones=2000, time_series_labels: list = [],
            knn_seq: int = 5, k_project_milestones: int = 3, t_difference: int = 2, saveto='',
            embedding_type: str = 'mds', double_diffusion: bool = False) -> ndarray:
    '''
    Fast computation of a 2D embedding.

    FOR EXAMPLE:
    via_object.embedding = via.via_mds(via_object = v0)
    plot_scatter(embedding = via_object.embedding, labels = via_object.true_labels)

    :param via_object: via object; when given, X_pca and viagraph_full default to its attributes
    :param X_pca: dimension-reduced input (only needed if via_object is not passed)
    :param viagraph_full: optional. If calling before/without via, None and a milestone graph will be
        computed; if calling after or from within via, the via-graph reinforces the milestone graph layout
    :param k: number of knn for the via_mds reinforcement graph on milestones. default 15; 5-20 is reasonable
    :param random_seed: random seed integer
    :param diffusion_op: diffusion power; higher values generate more smoothing
    :param n_milestones: number of milestones used to generate the initial embedding
    :param time_series_labels: numeric list representing sequential information
    :param knn_seq: if time-series data is available, augments the knn with sequential neighbors
        (2-10 are reasonable values), default 5
    :param k_project_milestones: number of milestones in the milestone-knngraph used to compute the
        single-cell projection
    :param t_difference: forwarded to affinity_milestone_knn for time-series augmentation
    :param saveto: optional csv path to save the embedding
    :param embedding_type: default 'mds', or 'umap'
    :param double_diffusion: default False. Set True for sharper strokes/lineages
    :return: numpy array of size n_samples x 2
    '''
    # subsampling csr_full_graph alone yields a fragmented graph, but omitting it entirely compromises
    # the embedding's global trajectory structure — so both the subsampled via-graph and a fresh
    # milestone knn are combined below
    print(f"{datetime.now()}\tCommencing Via-MDS")
    if via_object is not None:
        if X_pca is None: X_pca = via_object.data
        if viagraph_full is None: viagraph_full = via_object.csr_full_graph
    n_samples = X_pca.shape[0]
    if n_milestones is None:
        n_milestones = min(n_samples, max(2000, int(0.01 * n_samples)))
    if n_milestones > n_samples:
        n_milestones = min(n_samples, max(2000, int(0.01 * n_samples)))
        print(f"{datetime.now()}\tResetting n_milestones to {n_milestones} as n_samples > original n_milestones")

    np.random.seed(random_seed)
    # FIX: random.sample() draws from the stdlib RNG, which was previously left unseeded,
    # so the milestone subsample (and hence the embedding) was not reproducible despite random_seed
    random.seed(random_seed)
    milestone_indices = random.sample(range(X_pca.shape[0]), n_milestones)  # sampling without replacement
    if viagraph_full is not None:
        milestone_knn = viagraph_full[milestone_indices]
        milestone_knn = milestone_knn[:, milestone_indices]
        # row-normalizing emphasises edges that pass an even more stringent nearest-neighbour requirement
        # (they were selected from the full set of cells rather than a subset of milestones)
        milestone_knn = normalize(milestone_knn, axis=1)

    # FIX: removed a dead no-op statement (`X_pca[milestone_indices, :]` with no assignment)
    knn_struct = construct_knn_utils(X_pca[milestone_indices, :], knn=k)
    # add a new knn (milestone_knn_new) built on the subsampled indices to ensure connectivity;
    # otherwise the graph is fragmented when relying only on the subsampled graph
    if time_series_labels is None: time_series_labels = []
    if len(time_series_labels) >= 1: time_series_labels = np.array(time_series_labels)[milestone_indices].tolist()
    milestone_knn_new = affinity_milestone_knn(data=X_pca[milestone_indices, :], knn_struct=knn_struct, k=k,
                                               time_series_labels=time_series_labels, knn_seq=knn_seq,
                                               t_difference=t_difference)
    if viagraph_full is None:
        milestone_knn = milestone_knn_new
    else:
        milestone_knn = milestone_knn + milestone_knn_new

    # build a knn to project the input n_samples based on the milestone knn
    neighbor_array, distance_array = knn_struct.knn_query(X_pca, k=k_project_milestones)  # [n_samples x n_milestones]
    row_mean = np.mean(distance_array, axis=1)
    row_var = np.var(distance_array, axis=1)
    row_znormed_dist_array = -(distance_array - row_mean[:, np.newaxis]) / row_var[:, np.newaxis]
    # when k is very small the affinities can explode because var is ~0
    row_znormed_dist_array = np.nan_to_num(row_znormed_dist_array, copy=True, nan=1, posinf=1, neginf=1)
    row_znormed_dist_array[row_znormed_dist_array > 10] = 0
    affinity_array = np.exp(row_znormed_dist_array)
    affinity_array = normalize(affinity_array, norm='l1', axis=1)  # row stochastic
    row_list = []
    n_neighbors = neighbor_array.shape[1]
    n_cells = neighbor_array.shape[0]
    row_list.extend(list(np.transpose(np.ones((n_neighbors, n_cells)) * range(0, n_cells)).flatten()))
    col_list = neighbor_array.flatten().tolist()
    list_affinity = affinity_array.flatten().tolist()
    csr_knn = csr_matrix((list_affinity, (row_list, col_list)),
                         shape=(n_cells, len(milestone_indices)))  # n_samples x n_milestones
    print(f"{datetime.now()}\tStart computing with diffusion power:{diffusion_op}")
    if embedding_type == 'mds':
        milestone_mds = sgd_mds(via_graph=milestone_knn, X_pca=X_pca[milestone_indices, :], diff_op=diffusion_op,
                                ndims=2,
                                random_seed=random_seed, double_diffusion=double_diffusion)  # returns an ndarray
    elif embedding_type == 'umap':
        milestone_mds = via_umap(X_input=X_pca[milestone_indices, :], graph=milestone_knn)
    print(f"{datetime.now()}\tEnd computing mds with diffusion power:{diffusion_op}")

    # project the milestone embedding onto all cells via the sparse affinity matrix
    milestone_mds = csr_matrix(milestone_mds)
    full_mds = csr_knn * milestone_mds
    full_mds = np.asarray(full_mds.todense())
    full_mds = np.reshape(full_mds, (n_cells, 2))
    if len(saveto) > 0:
        U_df = pd.DataFrame(full_mds)
        U_df.to_csv(saveto)
    return full_mds
[docs]def via_atlas_emb(via_object=None, X_input: ndarray = None, graph: csr_matrix = None, n_components: int = 2,
alpha: float = 1.0, negative_sample_rate: int = 5,
gamma: float = 1.0, spread: float = 1.0, min_dist: float = 0.1, init_pos: Union[str, ndarray] = 'via',
random_state: int = 0,
n_epochs: int = 100, distance_metric: str = 'euclidean', layout: Optional[list] = None,
cluster_membership: Optional[list] = None, parallel: bool = False, saveto='',
n_jobs: int = 2) -> ndarray:
'''
Run dimensionality reduction using the VIA modified HNSW graph using via cluster graph initialization when Via_object is provided
:param via_object: if via_object is provided then X_input and graph are ignored
:param X_input: ndarray nsamples x features (PCs)
:param graph: csr_matrix of knngraph. This usually is via's pruned, sequentially augmented sc-knn graph accessed as an attribute of via via_object.csr_full_graph
:param n_components:
:param alpha:
:param negative_sample_rate:
:param gamma: Weight to apply to negative samples.
:param spread: The effective scale of embedded points. In combination with min_dist this determines how clustered/clumped the embedded points are.
:param min_dist: The effective minimum distance between embedded points. Smaller values will result in a more clustered/clumped embedding where nearby points on the manifold are drawn closer together, while larger values will result on a more even dispersal of points
:param init_pos: either a string (default) 'via' (uses via graph to initialize), or 'spectral'. Or a n_cellx2 dimensional ndarray with initial coordinates
:param random_state:
:param n_epochs: The number of training epochs to be used in optimizing the low dimensional embedding. Larger values result in more accurate embeddings. If 0 is specified a value will be selected based on the size of the input dataset (200 for large datasets, 500 for small).
:param distance_metric:
:param layout: ndarray . custom initial layout. (n_cells x2). also requires cluster_membership labels
:param cluster_membership: via_object.labels (cluster level labels of length n_samples corresponding to the layout)
:return: ndarray of shape (nsamples,n_components)
'''
if via_object is None:
if (X_input is None) or (graph is None):
print(f"{datetime.now()}\tERROR: please provide both X_input and graph")
if via_object is not None:
if X_input is None: X_input = via_object.data
print('X-input', X_input.shape)
graph = via_object.csr_full_graph
if cluster_membership is None: cluster_membership = via_object.labels
# X_input = via0.data
n_cells = X_input.shape[0]
print('len membership and n_cells', len(cluster_membership), n_cells)
print(f'n cell {n_cells}')
# graph = graph+graph.T
# graph = via0.csr_full_graph
print(f"{datetime.now()}\tComputing embedding on sc-Viagraph")
from umap.umap_ import find_ab_params, simplicial_set_embedding
# graph is a csr matrix
# weight all edges as 1 in order to prevent umap from pruning weaker edges away
layout_array = np.zeros(shape=(n_cells, 2))
if (init_pos == 'via') and (via_object is None):
# list of lists [[x,y], [x1,y1], []]
if (layout is None) or (cluster_membership is None):
print('please provide via object or values for arguments: layout and cluster_membership')
else:
for i in range(n_cells):
layout_array[i, 0] = layout[cluster_membership[i]][0]
layout_array[i, 1] = layout[cluster_membership[i]][1]
init_pos = layout_array
print(f'{datetime.now()}\tusing via cluster graph to initialize embedding')
elif (init_pos == 'via') and (via_object is not None):
layout = via_object.graph_node_pos
cluster_membership = via_object.labels
for i in range(n_cells):
layout_array[i, 0] = layout[cluster_membership[i]][0]
layout_array[i, 1] = layout[cluster_membership[i]][1]
init_pos = layout_array
print(f'{datetime.now()}\tusing via cluster graph to initialize embedding')
elif init_pos == 'spatial':
init_pos = layout
a, b = find_ab_params(spread, min_dist)
# print('a,b, spread, dist', a, b, spread, min_dist)
t0 = time.time()
# m = graph.data.max()
graph.data = np.clip(graph.data, np.percentile(graph.data, 1), np.percentile(graph.data, 99))
# graph.data = 1 + graph.data/m
# graph.data.fill(1)
# print('average graph.data', round(np.mean(graph.data),4), round(np.max(graph.data),2))
# graph.data = graph.data + np.mean(graph.data)
# transpose =graph.transpose()
# prod_matrix = graph.multiply(transpose)
# graph = graph + transpose - prod_matrix
if parallel:
import numba
print('before setting numba threads')
print(f'there are {numba.get_num_threads()} threads')
numba.set_num_threads(n_jobs)
print(f'there are now {numba.get_num_threads()} threads')
random_state = np.random.RandomState(random_state)
if parallel:
print('using parallel, the random_state will not be used.')
do_randomize_init = True
if do_randomize_init:
init_pos = init_pos + random_state.normal(
scale=0.001, size=init_pos.shape
).astype(np.float32)
X_emb, aux_data = simplicial_set_embedding(data=X_input, graph=graph, n_components=n_components,
initial_alpha=alpha,
a=a, b=b, n_epochs=n_epochs, metric_kwds={}, gamma=gamma,
metric=distance_metric,
negative_sample_rate=negative_sample_rate, init=init_pos,
random_state=random_state,
verbose=1, output_dens=False, densmap_kwds={}, densmap=False,
parallel=parallel)
if len(saveto) > 0:
U_df = pd.DataFrame(X_emb)
U_df.to_csv(saveto)
return X_emb
def run_umap_hnsw(via_object=None, X_input: ndarray = None, graph: csr_matrix = None, n_components: int = 2,
                  alpha: float = 1.0, negative_sample_rate: int = 5,
                  gamma: float = 1.0, spread: float = 1.0, min_dist: float = 0.1,
                  init_pos: Union[str, ndarray] = 'spectral', random_state: int = 0,
                  n_epochs: int = 0, distance_metric: str = 'euclidean', layout: Optional[list] = None,
                  cluster_membership: Optional[list] = None, saveto='') -> ndarray:
    '''
    Deprecated wrapper kept for backward compatibility; forwards all arguments to via_umap().
    See the via_umap docstring for parameter details.

    :return: ndarray of shape (n_samples, n_components) -- the UMAP embedding computed by via_umap
    '''
    # Fix: the old default cluster_membership=[] was a mutable default argument, and,
    # being non-None, it also bypassed via_umap's fallback to via_object.labels
    # (via_umap only substitutes labels when cluster_membership is None).
    print(f"{datetime.now()}\tWarning: in future call via_umap() to run this function")
    return via_umap(via_object=via_object, X_input=X_input, graph=graph, n_components=n_components, alpha=alpha,
                    negative_sample_rate=negative_sample_rate,
                    gamma=gamma, spread=spread, min_dist=min_dist, init_pos=init_pos, random_state=random_state,
                    n_epochs=n_epochs, distance_metric=distance_metric, layout=layout,
                    cluster_membership=cluster_membership, saveto=saveto)
[docs]def plot_population_composition(via_object, time_labels: list = None, celltype_list: list = None, cmap: str = 'rainbow',
legend: bool = True,
alpha: float = 0.5, linewidth: float = 0.2, n_intervals: int = 20, xlabel: str = 'time',
ylabel: str = '', title: str = 'Cell populations', color_dict: dict = None,
fraction: bool = True):
'''
:param via_object: optional. this is required unless both time_labels and cell_labels are provided as arguments to the function
:param time_labels: list length n_cells of pseudotime or known stage numeric labels
:param cell_labels: list of cell type or cluster length n_cells
:return: ax
'''
if time_labels is None:
pt = via_object.single_cell_pt_markov
maxpt = max(pt)
pt = [i / maxpt for i in pt]
else:
pt = time_labels
maxpt = max(pt)
minpt = min(pt)
if celltype_list is None: celltype_list = via_object.true_label
df_full = pd.DataFrame()
df_full['pt'] = [i for i in pt]
df_full['celltype'] = celltype_list
print(f'head df full {df_full.head()}')
n_intervals = n_intervals
interval_step = (max(pt) - min(pt)) / n_intervals
interval_i = 0
from collections import Counter
set_celltype_sorted = list(sorted(list(set(celltype_list))))
df_population = pd.DataFrame(0, index=[minpt + (i) * interval_step for i in range(n_intervals)],
columns=set_celltype_sorted)
index_i = 0
while interval_i <= max(pt) + 0.01:
df_temp = df_full[((df_full['pt'] < interval_i + interval_step) & (df_full['pt'] >= minpt + interval_i))]
dict_temp = Counter(df_temp['celltype'])
if fraction:
n_samples_temp = df_temp.shape[0]
for key in dict_temp:
dict_temp[key] = dict_temp[key] / n_samples_temp
print('dict temp', dict_temp)
dict_temp = dict(sorted(dict_temp.items()))
interval_i += interval_step
for key_pop_i in dict_temp:
df_population.loc[minpt + (index_i + 1) * interval_step, key_pop_i] = dict_temp[key_pop_i]
index_i += 1
title = title + 'n_intervals' + str(n_intervals)
if color_dict is not None:
ax = df_population.plot.area(grid=False, legend=legend, color=color_dict, alpha=alpha, linewidth=linewidth,
xlabel=xlabel, ylabel=ylabel, title=title)
else:
ax = df_population.plot.area(grid=False, legend=legend, colormap=cmap, alpha=alpha, linewidth=linewidth,
xlabel=xlabel, ylabel=ylabel, title=title)
return ax
def plot_differentiation_flow(via_object, idx: list = None, dpi=150, marker_lineages=[], label_node: list = [],
                              do_log_flow: bool = True, fontsize: int = 8, alpha_factor: float = 0.9,
                              majority_cluster_population_dict: dict = None, cmap_sankey='rainbow',
                              title_str: str = 'Differentiation Flow', root_cluster_list: list = None):
    '''
    #SANKEY PLOTS
    Render holoviews/bokeh Sankey diagram(s) of the differentiation flow from each root cluster
    to the terminal clusters, with edges weighted by cluster-adjacency flow / branch probability.

    G is the igraph knn (low K) used for shortest path in high dim space. no idx needed as it's made on full sample
    knn_hnsw is the knn made in the embedded space used for query to find the nearest point in the downsampled embedding
    that corresponds to the single cells in the full graph

    :param via_object:
    :param idx: if one uses a downsampled embedding of the original data, then idx is the selected indices of the downsampled samples used in the visualization
    :param dpi:
    :param marker_lineages: Default is to use all lineage pathways. other provide a list of lineage number (terminal cluster number).
    :param label_node: list of labels for each cell (could be cell type, stage level) length is n_cells
    :param do_log_flow: bool True (default) take the natural log (1+edge flow value)
    :param fontsize:
    :param alpha_factor: float transparency
    :param majority_cluster_population_dict: accepted for API compatibility; recomputed internally below
    :param cmap_sankey: colormap name used to color sankey nodes/edges by majority cell type
    :param title_str: title shown on the sankey figure
    :param root_cluster_list: list of roots by cluster number e.g. [5] means a good root is cluster number 5
    :return: None (diagrams are displayed via bokeh show())
    '''
    import math
    if len(marker_lineages) == 0:
        marker_lineages = via_object.terminal_clusters
        # NOTE(review): root_cluster_list is only defaulted inside this branch; a caller passing
        # marker_lineages but no root_cluster_list reaches enumerate(None) below -- confirm intended.
        if root_cluster_list is None:
            root_cluster_list = via_object.root
    else:
        marker_lineages = [i for i in marker_lineages if i in via_object.labels]  # via_object.terminal_clusters]
    print(f'{datetime.now()}\tMarker_lineages: {marker_lineages}')
    # dead code retained from an embedding-based variant of this plot
    '''
    if embedding is None:
        if via_object.embedding is None:
            print('ERROR: please provide a single cell embedding or run re-via with do_compute_embedding==True using either embedding_type = via-umap OR via-mds')
            return
        else:
            print(f'automatically setting embedding to via_object.embedding')
            embedding = via_object.embedding
    '''
    # make the sankey node labels either using via_obect.true_label or the labels provided by the user
    print(f'{datetime.now()}\tStart dictionary modes')
    # majority (modal) cell type per cluster -> used to label sankey nodes
    df_mode = pd.DataFrame()
    df_mode['cluster'] = via_object.labels
    # df_mode['celltype'] = pre_labels_celltype_df['fine'].tolist()#v0.true_label
    if len(label_node) > 0:
        df_mode['celltype'] = label_node  # v0.true_label
    else:
        df_mode['celltype'] = via_object.true_label
    majority_cluster_population_dict = df_mode.groupby(['cluster'])['celltype'].agg(
        lambda x: pd.Series.mode(x)[0])  # agg(pd.Series.mode would give all modes) #series
    majority_cluster_population_dict = majority_cluster_population_dict.to_dict()
    print(f'{datetime.now()}\tEnd dictionary modes')
    if idx is None: idx = np.arange(0, via_object.nsamples)
    # G = via_object.full_graph_shortpath
    n_original_comp, n_original_comp_labels = connected_components(via_object.csr_full_graph, directed=False)
    # G = via_object.full_graph_paths(via_object.data, n_original_comp)
    # knn_hnsw = _make_knn_embeddedspace(embedding)
    y_root = []
    x_root = []
    root1_list = []
    # single-cell branch probabilities (cells x terminal clusters), sanitized then row-normalized
    p1_sc_bp = np.nan_to_num(via_object.single_cell_bp[idx, :], nan=0.0, posinf=0.0, neginf=0.0)
    # row normalize
    row_sums = p1_sc_bp.sum(axis=1)
    p1_sc_bp = p1_sc_bp / row_sums[:,
                          np.newaxis]  # make rowsums a column vector where i'th entry is sum of i'th row in p1-sc-bp
    print(f'{datetime.now()}\tCheck sc pb {p1_sc_bp[0, :].sum()} ')
    p1_labels = np.asarray(via_object.labels)[idx]
    p1_cc = via_object.connected_comp_labels
    p1_sc_pt_markov = list(np.asarray(via_object.single_cell_pt_markov)[idx])
    X_data = via_object.data
    X_ds = X_data[idx, :]
    # hnsw index over the (possibly downsampled) input space; built here but only
    # referenced by the commented-out root-locating code below
    p_ds = hnswlib.Index(space='l2', dim=X_ds.shape[1])
    p_ds.init_index(max_elements=X_ds.shape[0], ef_construction=200, M=16)
    p_ds.add_items(X_ds)
    p_ds.set_ef(50)
    num_cluster = len(set(via_object.labels))
    # cluster-level graph used to find shortest cluster-to-cluster paths root -> terminal
    G_orange = ig.Graph(n=num_cluster, edges=via_object.edgelist_maxout,
                        edge_attrs={'weight': via_object.edgeweights_maxout})
    for ii, r_i in enumerate(root_cluster_list):
        sankey_edges = []
        '''
        loc_i = np.where(p1_labels == via_object.root[ii])[0]
        x = [embedding[xi, 0] for xi in loc_i]
        y = [embedding[yi, 1] for yi in loc_i]
        labels_root, distances_root = knn_hnsw.knn_query(np.array([np.mean(x), np.mean(y)]), k=1)
        x_root.append(embedding[labels_root, 0][0])
        y_root.append(embedding[labels_root, 1][0])
        labelsroot1, distances1 = via_object.knn_struct.knn_query(X_ds[labels_root[0][0], :], k=1)
        root1_list.append(labelsroot1[0][0])
        print('f getting majority comp')
        '''
        '''
        #VERY SLOW maybe try with dataframe mode df.groupby(['team'])['points'].agg(pd.Series.mode)
        for labels_i in via_object.labels:
            loc_labels = np.where(np.asarray(via_object.labels) == labels_i)[0]
            majority_composition = func_mode(list(np.asarray(via_object.true_label)[loc_labels]))
            majority_cluster_population_dict[labels_i] = majority_composition
        print('f End getting majority comp')
        '''
        # accumulate one weighted (source, dest, value) edge per hop of every root->terminal path
        for fst_i in marker_lineages:
            path_orange = G_orange.get_shortest_paths(root_cluster_list[ii], to=fst_i)[0]
            '''
            if fst_i in [1,22,71,89,136,10,83,115]: #CNS and periderm for Zebrahub Lange we want the root to be the early CNS
                #path_orange = G_orange.get_shortest_paths(52, to=fst_i)[0]
                #path_orange = G_orange.get_shortest_paths(via_object.root[ii], to=fst_i)[0]
                path_orange = G_orange.get_shortest_paths(3, to=fst_i)[0] #for Zebralange use this CNS root
            elif fst_i in [10,69,90,95]:
                path_orange = G_orange.get_shortest_paths(75, to=fst_i)[0] # for Zebralange use this endoderm (pharynx, liver, intestine) root
            else:
                path_orange = G_orange.get_shortest_paths(via_object.root[ii], to=fst_i)[0]
            '''
            # path_orange = G_orange.get_shortest_paths(3, to=fst_i)[0]
            # if the roots is in the same component as the terminal cluster, then print the path to output
            if len(path_orange) > 0:
                print(
                    f'{datetime.now()}\tCluster path on clustergraph starting from Root Cluster {root_cluster_list[ii]} to Terminal Cluster {fst_i}: {path_orange}')
                do_sankey = True
                '''
                cluster_population_dict = {}
                for group_i in set(via_object.labels):
                    loc_i = np.where(via_object.labels == group_i)[0]
                    cluster_population_dict[group_i] = len(loc_i)
                '''
                if do_sankey:
                    # heavy plotting deps imported lazily (and repeatedly, per path) so the
                    # module can be used without holoviews/bokeh installed
                    import holoviews as hv
                    hv.extension('bokeh')
                    from bokeh.plotting import show
                    from holoviews import opts, dim
                    print(f"{datetime.now()}\tHoloviews for TC {fst_i}")
                    cluster_adjacency = via_object.cluster_adjacency
                    # row normalize
                    row_sums = cluster_adjacency.sum(axis=1)
                    cluster_adjacency_rownormed = cluster_adjacency / row_sums[:, np.newaxis]
                    for n_i in range(len(path_orange) - 1):
                        source = path_orange[n_i]
                        dest = path_orange[n_i + 1]
                        # intermediate hops use normalized cluster adjacency; the final hop into a
                        # terminal cluster uses the cluster branch probability instead
                        if n_i < len(path_orange) - 2:
                            if do_log_flow:
                                val_edge = round(math.log1p(cluster_adjacency_rownormed[source, dest]),
                                                 2)  # * cluster_population_dict[source] # natural logarithm (base e) of 1 + x
                            else:
                                val_edge = round(cluster_adjacency_rownormed[source, dest], 2)
                                # print("clipping val edge")
                                # if val_edge > 0.5: val_edge = 0.5
                        else:
                            if dest in via_object.terminal_clusters:
                                ts_array_original = np.asarray(via_object.terminal_clusters)
                                loc_ts_current = np.where(ts_array_original == dest)[0][0]
                                print(f'dest {dest}, is at loc {loc_ts_current} on the bp_array')
                                if do_log_flow:
                                    val_edge = round(math.log1p(via_object.cluster_bp[source, loc_ts_current]),
                                                     2)  # * cluster_population_dict[source]
                                else:
                                    val_edge = round(via_object.cluster_bp[source, loc_ts_current], 2)
                                    # print("clipping val edge")
                                    # if val_edge > 0.5: val_edge = 0.5
                            else:
                                if do_log_flow:
                                    val_edge = round(math.log1p(cluster_adjacency_rownormed[source, dest]),
                                                     2)  # * cluster_population_dict[source] # natural logarithm (base e) of 1 + x
                                else:
                                    val_edge = round(cluster_adjacency_rownormed[source, dest], 2)
                                    # print("clipping val edge")
                                    # val_edge = 0.5
                        # sankey_edges.append((majority_cluster_population_dict[source]+'_C'+str(source), majority_cluster_population_dict[dest]+'_C'+str(dest), val_edge))#, majority_cluster_population_dict[source],majority_cluster_population_dict[dest]))
                        sankey_edges.append((source, dest,
                                             val_edge))  # ,majority_cluster_population_dict[source]+'_C'+str(source),'magenta' ))
        # NOTE(review): indentation reconstructed from flattened source -- the sankey is assembled
        # once per root over the edges accumulated from all its lineage paths. The names hv/show/
        # opts/dim are bound inside the loop above, so this section assumes at least one path existed.
        # print(f'pre-final sankey set of edges and vals {len(sankey_edges)}, {sankey_edges}')
        source_dest = list(set(sankey_edges))
        # print(f'final sankey set of edges and vals {len(source_dest)}, {source_dest}')
        source_dest_df = pd.DataFrame(source_dest, columns=['Source', 'Dest', 'Count'])  # ,'Label','Color'])
        nodes_in_source_dest = list(set(set(source_dest_df.Source) | set(source_dest_df.Dest)))
        nodes_in_source_dest.sort()
        # reindex participating clusters to consecutive ints 0..k-1 for compact sankey node ids
        convert_old_to_new = {}
        convert_new_to_old = {}
        majority_newcluster_population_dict = {}
        # NOTE(review): loop variable `ii` shadows the outer root-loop `ii`; harmless because the
        # outer loop reassigns `ii` each iteration, but worth renaming -- confirm.
        for ei, ii in enumerate(nodes_in_source_dest):
            convert_old_to_new[ii] = ei
            convert_new_to_old[ei] = ii
            majority_newcluster_population_dict[ei] = majority_cluster_population_dict[ii]
        source_dest_new = []
        for tuple_ in source_dest:
            source_dest_new.append((convert_old_to_new[tuple_[0]], convert_old_to_new[tuple_[1]], tuple_[2]))
        # print('new source dest after reindexing', source_dest_new)
        # nodes = [majority_cluster_population_dict[i] for i in range(len(majority_cluster_population_dict))]
        # nodes = [majority_cluster_population_dict[i] for i in nodes_in_source_dest]
        # node labels: "<majority celltype>_C<original cluster number>"
        nodes = [majority_newcluster_population_dict[key] + '_C' + str(convert_new_to_old[key]) for key in
                 majority_newcluster_population_dict]
        # nodes = ['C' + str(convert_new_to_old[key]) for key in majority_newcluster_population_dict]
        # print('nodes', len(nodes), nodes,)
        nodes = hv.Dataset(enumerate(nodes), 'index', 'label')
        from holoviews.plotting.util import process_cmap
        print(f'{datetime.now()}\tStart sankey')
        cmap_list = process_cmap("glasbey_hv")
        p2 = hv.Sankey((source_dest_new, nodes), ['Source', "Dest"])
        p2_2 = hv.Sankey((source_dest_new, nodes), ['Source', "Dest"])
        print(f'{datetime.now()}\tmake sankey color dict')
        # Make color map
        # Extract Unique values dictionary values
        # Using set comprehension + values() + sorted()
        set_majority_truth = list(set(list(majority_newcluster_population_dict.values())))
        set_majority_truth.sort(reverse=True)
        color_dict = {}
        for index, value in enumerate(set_majority_truth):
            # assign each celltype a number
            color_dict[value] = index
        # one color per distinct majority cell type, mapped to hex for holoviews
        palette = cm.get_cmap(cmap_sankey, len(color_dict.keys()))
        cmap_ = palette(range(len(color_dict.keys())))
        cmap_colors_dict_sankey = {}
        for key in majority_newcluster_population_dict:
            cmap_colors_dict_sankey[int(key)] = matplotlib.colors.rgb2hex(
                cmap_[color_dict[majority_newcluster_population_dict[key]]])
        print(f'{datetime.now()}\tset options and render')
        # first diagram: no node labels (fontsize 1)
        p2.opts(
            opts.Sankey(show_values=False, edge_cmap=cmap_colors_dict_sankey, edge_color=dim('Source').str(),
                        node_color=dim('Source').str(),
                        edge_line_width=2, width=1800, height=1200, cmap=cmap_colors_dict_sankey, node_padding=15,
                        fontsize={'labels': 1}, title=title_str))
        show(hv.render(p2))
        p2_2.opts(
            opts.Sankey(show_values=False, edge_cmap=cmap_colors_dict_sankey, edge_color=dim('Source').str(),
                        node_color=dim('Source').str(),
                        edge_line_width=2, width=1800, height=1200, node_padding=20, cmap=cmap_colors_dict_sankey,
                        title=title_str))
        # show(hv.render(p2_2))
        # second diagram: re-styled with the celltype node labels, then rendered
        p2_2.opts(
            opts.Sankey(labels='label', edge_cmap=cmap_colors_dict_sankey, edge_color=dim('Source').str(),
                        node_color=dim('Source').str(),
                        edge_line_width=2, width=1800, height=1200, node_padding=20, cmap=cmap_colors_dict_sankey,
                        title=title_str))
        show(hv.render(p2_2))
        '''
        p0 = hv.Sankey(source_dest_df)
        show(hv.render(p0))
        p = hv.Sankey(source_dest_df, kdims=["Source", "Dest"], vdims=["Count"])
        p.opts(
            opts.Sankey(edge_color=dim('Source').str(), node_color=dim('Source').str(),
                        edge_line_width=2,
                        edge_cmap='tab20', node_cmap='tab20', width=1800, height=1800, title='test title', node_padding=3))
        show(hv.render(p))
        '''
    # https://stackoverflow.com/questions/57085026/how-do-i-colour-the-individual-categories-in-a-holoviews-sankey-diagram
    # https://stackoverflow.com/questions/76505156/draw-sankey-diagram-with-holoviews-and-bokeh
    # https://holoviews.org/reference/elements/bokeh/Sankey.html
    # https://malouche.github.io/notebooks/Sankey_graphs.html
    # https://github.com/holoviz/holoviews/issues/3501
    return
def plot_sc_lineage_probability(via_object, embedding: ndarray = None, idx: list = None, cmap_name='plasma', dpi=150,
                                scatter_size=None, marker_lineages=[], fontsize: int = 8, alpha_factor: float = 0.9,
                                majority_cluster_population_dict: dict = None, cmap_sankey='rainbow',
                                do_sankey: bool = False):
    '''
    Plot single-cell branch (lineage) probabilities, one subplot per marker lineage / terminal cluster.

    G is the igraph knn (low K) used for shortest path in high dim space. no idx needed as it's made on full sample
    knn_hnsw is the knn made in the embedded space used for query to find the nearest point in the downsampled embedding
    that corresponds to the single cells in the full graph

    :param via_object:
    :param embedding: n_samples x 2. embedding is either the full or downsampled 2D representation of the full dataset.
    :param idx: if one uses a downsampled embedding of the original data, then idx is the selected indices of the downsampled samples used in the visualization
    :param cmap_name:
    :param dpi:
    :param scatter_size: if None, then auto determined based on n_cells
    :param marker_lineages: Default is to use all lineage pathways. other provide a list of lineage number (terminal cluster number).
    :param fontsize: fontsize of the subplot titles
    :param alpha_factor: float transparency
    :param majority_cluster_population_dict: unused in this function -- kept for signature compatibility
    :param cmap_sankey: unused in this function -- kept for signature compatibility
    :param do_sankey: unused in this function -- kept for signature compatibility
    :return: fig, axs
    '''
    # keep only lineages that are genuine terminal clusters
    if len(marker_lineages) == 0:
        marker_lineages = via_object.terminal_clusters
    else:
        marker_lineages = [i for i in marker_lineages if i in via_object.terminal_clusters]
    print(f'{datetime.now()}\tMarker_lineages: {marker_lineages}')
    if embedding is None:
        if via_object.embedding is None:
            print(
                f'{datetime.now()}\tERROR: please provide a single cell embedding or run re-via with do_compute_embedding==True using either embedding_type = via-umap OR via-mds')
            return
        else:
            print(f'{datetime.now()}\tAutomatically setting embedding to via_object.embedding')
            embedding = via_object.embedding
    if idx is None: idx = np.arange(0, via_object.nsamples)
    # G = via_object.full_graph_shortpath
    n_original_comp, n_original_comp_labels = connected_components(via_object.csr_full_graph, directed=False)
    # single-cell graph used below for shortest sc-to-sc paths
    G = via_object.full_graph_paths(via_object.data, n_original_comp)
    # knn over the 2D embedding: maps a mean (x, y) position back to a concrete cell
    knn_hnsw = _make_knn_embeddedspace(embedding)
    y_root = []
    x_root = []
    root1_list = []
    # single-cell branch probabilities (cells x terminal clusters), sanitized then row-normalized
    p1_sc_bp = np.nan_to_num(via_object.single_cell_bp[idx, :], nan=0.0, posinf=0.0, neginf=0.0)
    # row normalize
    row_sums = p1_sc_bp.sum(axis=1)
    p1_sc_bp = p1_sc_bp / row_sums[:,
                          np.newaxis]  # make rowsums a column vector where i'th entry is sum of i'th row in p1-sc-bp
    print(f'{datetime.now()}\tCheck sc pb {p1_sc_bp[0, :].sum()} ')
    p1_labels = np.asarray(via_object.labels)[idx]
    p1_cc = via_object.connected_comp_labels
    p1_sc_pt_markov = list(np.asarray(via_object.single_cell_pt_markov)[idx])
    X_data = via_object.data
    X_ds = X_data[idx, :]
    # hnsw index over the (possibly downsampled) high-dim input space
    p_ds = hnswlib.Index(space='l2', dim=X_ds.shape[1])
    p_ds.init_index(max_elements=X_ds.shape[0], ef_construction=200, M=16)
    p_ds.add_items(X_ds)
    p_ds.set_ef(50)
    num_cluster = len(set(via_object.labels))
    # cluster-level graph for root -> terminal shortest paths (logged below)
    G_orange = ig.Graph(n=num_cluster, edges=via_object.edgelist_maxout,
                        edge_attrs={'weight': via_object.edgeweights_maxout})
    # for each root cluster: find the single cell nearest its mean embedding position,
    # then map that cell back to the full graph index (stored in root1_list)
    for ii, r_i in enumerate(via_object.root):
        loc_i = np.where(p1_labels == via_object.root[ii])[0]
        x = [embedding[xi, 0] for xi in loc_i]
        y = [embedding[yi, 1] for yi in loc_i]
        labels_root, distances_root = knn_hnsw.knn_query(np.array([np.mean(x), np.mean(y)]), k=1)
        x_root.append(embedding[labels_root, 0][0])
        y_root.append(embedding[labels_root, 1][0])
        labelsroot1, distances1 = via_object.knn_struct.knn_query(X_ds[labels_root[0][0], :], k=1)
        root1_list.append(labelsroot1[0][0])
        sankey_edges = []
        print('f getting majority comp')
        '''
        #VERY SLOW maybe try with dataframe mode df.groupby(['team'])['points'].agg(pd.Series.mode)
        for labels_i in via_object.labels:
            loc_labels = np.where(np.asarray(via_object.labels) == labels_i)[0]
            majority_composition = func_mode(list(np.asarray(via_object.true_label)[loc_labels]))
            majority_cluster_population_dict[labels_i] = majority_composition
        print('f End getting majority comp')
        '''
        # log the cluster-level path from this root to every terminal cluster
        for fst_i in via_object.terminal_clusters:
            path_orange = G_orange.get_shortest_paths(via_object.root[ii], to=fst_i)[0]
            if len(path_orange) > 0:
                print(
                    f'{datetime.now()}\tCluster path on clustergraph starting from Root Cluster {via_object.root[ii]} to Terminal Cluster {fst_i}: {path_orange}')
    '''
    cluster_population_dict = {}
    for group_i in set(via_object.labels):
        loc_i = np.where(via_object.labels == group_i)[0]
        cluster_population_dict[group_i] = len(loc_i)
    '''
    # single-cell branch probability evolution probability
    # subplot grid: at most 3 columns, one panel per marker lineage
    n_terminal_clusters = len(marker_lineages)
    fig_ncols = min(3, n_terminal_clusters)
    fig_nrows, mod = divmod(n_terminal_clusters, fig_ncols)
    if mod == 0:
        if fig_nrows == 0:
            fig_nrows += 1
        else:
            fig_nrows = fig_nrows
    if mod != 0: fig_nrows += 1
    fig, axs = plt.subplots(fig_nrows, fig_ncols, dpi=dpi)
    ts_array_original = np.asarray(via_object.terminal_clusters)
    ti = 0  # counter for terminal cluster
    for r in range(fig_nrows):
        for c in range(fig_ncols):
            if ti < n_terminal_clusters:
                ts_current = marker_lineages[ti]
                loc_ts_current = np.where(ts_array_original == ts_current)[0][0]
                loc_labels = np.where(np.asarray(via_object.labels) == ts_current)[0]
                majority_composition = func_mode(list(np.asarray(via_object.true_label)[loc_labels]))
                # matplotlib returns a scalar Axes, 1D or 2D array depending on grid shape
                if fig_nrows == 1:
                    if fig_ncols == 1:
                        plot_sc_pb(axs, fig, embedding, p1_sc_bp[:, loc_ts_current],
                                   ti=str(ts_current) + '-' + str(majority_composition), cmap_name=cmap_name,
                                   scatter_size=scatter_size, fontsize=fontsize)
                    else:
                        plot_sc_pb(axs[c], fig, embedding, p1_sc_bp[:, loc_ts_current],
                                   ti=str(ts_current) + '-' + str(majority_composition), cmap_name=cmap_name,
                                   scatter_size=scatter_size, fontsize=fontsize, alpha_factor=alpha_factor)
                else:
                    plot_sc_pb(axs[r, c], fig, embedding, p1_sc_bp[:, loc_ts_current],
                               ti=str(ts_current) + '-' + str(majority_composition), cmap_name=cmap_name,
                               scatter_size=scatter_size, fontsize=fontsize, alpha_factor=alpha_factor)
                # representative terminal cell: among this cluster's cells in the upper half of
                # pseudotime, the one nearest their mean embedding position
                loc_i = np.where(p1_labels == ts_current)[0]
                val_pt = [p1_sc_pt_markov[i] for i in loc_i]
                th_pt = np.percentile(val_pt, 50)  # 50
                loc_i = [loc_i[i] for i in range(len(val_pt)) if val_pt[i] >= th_pt]
                x = [embedding[xi, 0] for xi in
                     loc_i]  # location of sc nearest to average location of terminal clus in the EMBEDDED space
                y = [embedding[yi, 1] for yi in loc_i]
                labels, distances = knn_hnsw.knn_query(np.array([np.mean(x), np.mean(y)]),
                                                       k=1)  # knn_hnsw is knn of embedded space
                x_sc = embedding[labels[0], 0]  # terminal sc location in the embedded space
                y_sc = embedding[labels[0], 1]
                labelsq1, distances1 = via_object.knn_struct.knn_query(X_ds[labels[0][0], :],
                                                                       k=1)  # find the nearest neighbor in the PCA-space full graph
                path = G.get_shortest_paths(root1_list[p1_cc[loc_ts_current]], to=labelsq1[0][0])  # weights='weight')
                # G is the knn of all sc points
                path_idx = []  # find the single-cell which is nearest to the average-location of a terminal cluster
                # get the nearest-neighbor in this downsampled PCA-space graph. These will make the new path-way points
                path = path[0]
                # clusters of path
                cluster_path = []
                for cell_ in path:
                    cluster_path.append(via_object.labels[cell_])
                # condense the single-cell path to a cluster-level path: keep endpoints, and keep
                # intermediate clusters only if the path dwells in them more than once
                revised_cluster_path = []
                revised_sc_path = []
                for enum_i, clus in enumerate(cluster_path):
                    num_instances_clus = cluster_path.count(clus)
                    if (clus == cluster_path[0]) | (clus == cluster_path[-1]):
                        revised_cluster_path.append(clus)
                        revised_sc_path.append(path[enum_i])
                    else:
                        if num_instances_clus > 1:  # typically intermediate stages spend a few transitions at the sc level within a cluster
                            if clus not in revised_cluster_path: revised_cluster_path.append(clus)  # cluster
                            revised_sc_path.append(path[enum_i])  # index of single cell
                # NOTE(review): `ti - 1` indexes relative to the pre-increment counter -- confirm
                # the intended root is reported here.
                print(
                    f"{datetime.now()}\tRevised Cluster level path on sc-knnGraph from Root Cluster {via_object.root[p1_cc[ti - 1]]} to Terminal Cluster {ts_current} along path: {revised_cluster_path}")
                ti += 1
            # hide frames and grids for every panel, including unused trailing panels
            fig.patch.set_visible(False)
            if fig_nrows == 1:
                if fig_ncols == 1:
                    axs.axis('off')
                    axs.grid(False)
                else:
                    axs[c].axis('off')
                    axs[c].grid(False)
            else:
                axs[r, c].axis('off')
                axs[r, c].grid(False)
    return fig, axs
def plot_viagraph(via_object, type_data='gene', df_genes=None, gene_list: list = [], arrow_head: float = 0.1,
                  edgeweight_scale: float = 1.5, cmap=None, label_text: bool = True, size_factor_node: float = 1,
                  tune_edges: bool = False, initial_bandwidth=0.05, decay=0.9, edgebundle_pruning=0.5):
    '''
    cluster level expression of gene/feature intensity

    :param via_object:
    :param type_data: 'gene' (default) selects the 'coolwarm' colormap when cmap is None; any other value selects 'viridis_r'
    :param df_genes: pd.Dataframe size n_cells x genes. Otherwise defaults to plotting pseudotime
    :param gene_list: list of gene names corresponding to the column name
    :param arrow_head:
    :param edgeweight_scale:
    :param cmap:
    :param label_text: bool to add numeric values of the gene exp level
    :param size_factor_node: size of graph nodes
    :param tune_edges: bool (false). if you want to change the number of edges visualized, then set this to True and modify the tuning parameters (initial_bandwidth, decay, edgebundle_pruning)
    :param initial_bandwidth: (float = 0.05) increasing bw increases merging of minor edges. Only used when tune_edges = True
    :param decay: (decay = 0.9) increasing decay increases merging of minor edges . Only used when tune_edges = True
    :param edgebundle_pruning: (float = 0.5). takes on values between 0-1. smaller value means more pruning away edges that can be visualised. Only used when tune_edges = True
    :return: fig, axs
    '''
    '''
    #draws the clustergraph for cluster level gene or pseudotime values
    # type_pt can be 'pt' pseudotime or 'gene' for gene expression
    # ax1 is the pseudotime graph
    '''
    n_genes = len(gene_list)
    pt = via_object.markov_hitting_times
    # no genes requested -> fall back to a single panel colored by pseudotime
    # NOTE(review): if gene_list is non-empty but df_genes is None, df_genes['cluster']
    # below raises; callers must supply df_genes together with gene_list -- confirm.
    if n_genes == 0:
        gene_list = ['pseudotime']
        df_genes = pd.DataFrame()
        df_genes['pseudotime'] = via_object.single_cell_pt_markov
        n_genes = 1
    if tune_edges:
        # re-bundle edges with user-supplied tuning; the node layout itself stays fixed
        hammer_bundle, layout = make_edgebundle_viagraph(via_object=via_object, layout=via_object.layout, decay=decay,
                                                         initial_bandwidth=initial_bandwidth,
                                                         edgebundle_pruning=edgebundle_pruning)  # hold the layout fixed. only change the edges
    else:
        hammer_bundle = via_object.hammerbundle_cluster
        layout = via_object.layout  # graph_node_pos
    fig, axs = plt.subplots(1, n_genes)
    if cmap is None: cmap = 'coolwarm' if type_data == 'gene' else 'viridis_r'
    node_pos = layout.coords  # via_object.graph_node_pos
    node_pos = np.asarray(node_pos)
    # aggregate the single-cell values to cluster-level means (one row per cluster)
    df_genes['cluster'] = via_object.labels
    df_genes = df_genes.groupby('cluster', as_index=False).mean()
    n_groups = len(set(via_object.labels))  # node_pos.shape[0]
    # cluster population -> node sizes (and cached on via_object for reuse)
    group_pop = np.zeros([n_groups, 1])
    via_object.cluster_population_dict = {}
    for group_i in set(via_object.labels):
        loc_i = np.where(via_object.labels == group_i)[0]
        group_pop[group_i] = len(loc_i)  # np.sum(loc_i) / 1000 + 1
        via_object.cluster_population_dict[group_i] = len(loc_i)
    # one panel per gene: bundled cluster-graph edges plus nodes colored by mean value
    for i in range(n_genes):
        if n_genes == 1:
            ax_i = axs
        else:
            ax_i = axs[i]
        gene_i = gene_list[i]
        # terminal clusters get a red outline, all others a gray (invisible-width) one
        c_edge, l_width = [], []
        for ei, pti in enumerate(pt):
            if ei in via_object.terminal_clusters:
                c_edge.append('red')
                l_width.append(1.5)
            else:
                c_edge.append('gray')
                l_width.append(0.0)
        ax_i = plot_viagraph_(ax_i, hammer_bundle=hammer_bundle, layout=layout,
                              CSM=via_object.CSM,
                              velocity_weight=via_object.velo_weight, pt=pt, headwidth_bundle=arrow_head,
                              alpha_bundle=0.4, linewidth_bundle=edgeweight_scale)
        # scale node sizes relative to the largest cluster
        group_pop_scale = .5 * group_pop * 1000 / max(group_pop)
        pos = ax_i.scatter(node_pos[:, 0], node_pos[:, 1], s=group_pop_scale * size_factor_node,
                           c=df_genes[gene_i].values, cmap=cmap,
                           edgecolors=c_edge, alpha=1, zorder=3, linewidth=l_width)
        if label_text == True:
            # annotate each node with "C<cluster> <mean value>"
            for ii in range(node_pos.shape[0]):
                ax_i.text(node_pos[ii, 0] + 0.1, node_pos[ii, 1] + 0.1,
                          'C' + str(ii) + ' ' + str(round(df_genes[gene_i].values[ii], 1)),
                          color='black', zorder=4, fontsize=6)
        # colorbar in its own axes so it does not shrink the main panel
        divider = make_axes_locatable(ax_i)
        cax = divider.append_axes('right', size='10%', pad=0.05)
        cbar = fig.colorbar(pos, cax=cax, orientation='vertical')
        cbar.ax.tick_params(labelsize=8)
        ax_i.set_title(gene_i)
        ax_i.grid(False)
        ax_i.set_xticks([])
        ax_i.set_yticks([])
        ax_i.axis('off')
    fig.patch.set_visible(False)
    return fig, axs
[docs]def plot_atlas_view(hammerbundle_dict=None, via_object=None, alpha_bundle_factor=1, linewidth_bundle=2,
facecolor: str = 'white', cmap: str = 'plasma', extra_title_text='', alpha_milestones: float = 0.3,
headwidth_bundle: float = 0.1, headwidth_alpha: float = 0.8, arrow_frequency: float = 0.05,
show_arrow: bool = True, sc_labels_sequential: list = None, sc_labels_expression: list = None,
initial_bandwidth=0.03, decay=0.7, n_milestones: int = None, scale_scatter_size_pop: bool = False,
show_milestones: bool = True, sc_labels: list = None, text_labels: bool = False,
lineage_pathway: list = [], dpi: int = 300, fontsize_title: int = 6, fontsize_labels: int = 6,
global_visual_pruning=0.5, use_sc_labels_sequential_for_direction: bool = False, sc_scatter_size=3,
sc_scatter_alpha: float = 0.4, add_sc_embedding: bool = True, size_milestones: int = 5,
colorbar_legend='pseudotime'):
'''
Edges can be colored by time-series numeric labels, pseudotime, lineage pathway probabilities, or gene expression. If not specificed then time-series is chosen if available, otherwise falls back to pseudotime. to use gene expression the sc_labels_expression is provided as a list.
To specify other numeric sequential data provide a list of sc_labels_sequential = [] n_samples in length. via_object.embedding must be an ndarray of shape (nsamples,2)
:param hammer_bundle_dict: dictionary with keys: hammerbundle object with coordinates of all the edges to draw. If hammer_bundle and layout are None, then this will be computed internally
:param via_object: type via object, if hammerbundle_dict is None, then you must provide a via_object. Ensure that via_object has embedding attribute
:param layout: coords of cluster nodes and optionally also contains the numeric value associated with each cluster (such as time-stamp) layout[['x','y','numeric label']] sc/cluster/milestone level
:param CSM: cosine similarity matrix. cosine similarity between the RNA velocity between neighbors and the change in gene expression between these neighbors. Only used when available
:param velocity_weight: percentage weightage given to the RNA velocity based transition matrix
:param pt: cluster-level pseudotime
:param alpha_bundle: alpha when drawing lines
:param linewidth_bundle: linewidth of bundled lines
:param edge_color:
:param alpha_milestones: float 0.3 alpha of milestones
:param size_milestones: scatter size of the milestones (use sc_size_scatter to control single cell scatter when using in conjunction with lineage probs/ sc embeddings)
:param arrow_frequency: min dist between arrows (bundled edges otherwise have overcrowding of arrows)
:param show_direction: True will draw arrows along the lines to indicate direction
:param milestone_edges: pandas DataFrame milestoone_edges[['source','target']]
:param milestone_numeric_values: the milestone average of numeric values such as time (days, hours), location (position), or other numeric value used for coloring edges in a sequential manner
if this is None then the edges are colored by length to distinguish short and long range edges
:param arrow_frequency: 0.05. higher means fewer arrows
:param n_milestones: int None. if no hammerbundle_dict is provided, but via_object is provided, then the user can specify level of granularity by setting the n_milestones. otherwise it will be automatically selected
:param scale_scatter_size_pop: bool default False
:param sc_labels_expression: list single cell numeric values used for coloring edges and nodes of corresponding milestones mean expression levels (len n_single_cell samples)
edges can be colored by time-series numeric (gene expression)/string (cell type) labels, pseudotime, or gene expression. If not specificed then time-series is chosen if available, otherwise falls back to pseudotime. to use gene expression the sc_labels_expression is provided as a list
:param sc_labels_sequential: list single cell numeric sequential values used for directionality inference as replacement for pseudotime or via_object.time_series_labels (len n_samples single cell)
:param sc_labels: list None list of single-cell level labels (categorial or discrete set of numerical values) to label the nodes
:param text_labels: bool False if you want to label the nodes based on sc_labels (or true_label if via_object is provided)
:param lineage_pathway: list of terminal states to plot lineage pathways
:param use_sc_labels_sequential_for_direction: use the sequential data (timeseries labels or other provided by user) to direct the arrows
:param lineage_alpha_threshold number representing the percentile (0-100) of lineage likelikhood in a particular lineage pathway, below which edges will be drawn with lower alpha transparency factor
:param sc_scatter_alpha: transparency of the background singlecell scatter when plotting lineages
:param add_sc_embedding: add background of single cell scatter plot for Atlas
:param scatter_size_sc_embedding
:param colorbar_legend str title of colorbar
:return: fig, axis with bundled edges plotted
'''
def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """Return a new colormap built from the [minval, maxval] slice of *cmap*.

    Samples *n* evenly spaced colors from the given range of *cmap* and
    wraps them in a fresh LinearSegmentedColormap.
    """
    import matplotlib.colors as colors
    sampled_colors = cmap(np.linspace(minval, maxval, n))
    label = f'trunc({cmap.name},{minval:.2f},{maxval:.2f})'
    return colors.LinearSegmentedColormap.from_list(label, sampled_colors)
sc_scatter_alpha = 1 - sc_scatter_alpha
cmap_name = cmap
headwidth_alpha_og = headwidth_alpha
linewidth_bundle_og = linewidth_bundle
alpha_bundle_factor_og = alpha_bundle_factor
if hammerbundle_dict is None:
if via_object is None:
print('if hammerbundle_dict is not provided, then you must provide via_object')
else:
hammerbundle_dict = via_object.hammerbundle_milestone_dict
if hammerbundle_dict is None:
if n_milestones is None: n_milestones = min(via_object.nsamples, 150)
if sc_labels_sequential is None:
if via_object.time_series_labels is not None:
sc_labels_sequential = via_object.time_series_labels
else:
sc_labels_sequential = via_object.single_cell_pt_markov
print(f'{datetime.now()}\tComputing Edges')
hammerbundle_dict = make_edgebundle_milestone(via_object=via_object,
embedding=via_object.embedding,
sc_graph=via_object.ig_full_graph,
n_milestones=n_milestones,
sc_labels_numeric=sc_labels_sequential,
initial_bandwidth=initial_bandwidth, decay=decay,
weighted=True,
global_visual_pruning=global_visual_pruning)
via_object.hammerbundle_dict = hammerbundle_dict
hammer_bundle = hammerbundle_dict['hammerbundle']
layout = hammerbundle_dict['milestone_embedding'][['x', 'y']].values
milestone_edges = hammerbundle_dict['edges']
if sc_labels_expression is None:
milestone_numeric_values = hammerbundle_dict['milestone_embedding']['numeric label']
else:
if (isinstance(sc_labels_expression[0], str)) == True:
color_dict = {}
set_labels = list(set(sc_labels_expression))
set_labels.sort(reverse=True)
for index, value in enumerate(set_labels):
color_dict[value] = index
milestone_numeric_values = [color_dict[i] for i in sc_labels_expression]
sc_labels_expression = milestone_numeric_values
else:
milestone_numeric_values = sc_labels_expression
milestone_pt = hammerbundle_dict['milestone_embedding']['pt']
if use_sc_labels_sequential_for_direction: milestone_pt = hammerbundle_dict['milestone_embedding'][
'numeric label']
if sc_labels_expression is not None: # if both sclabelexpression and sequential are provided, then sc_labels_expression takes precedence
df = pd.DataFrame()
df['sc_milestone_labels'] = hammerbundle_dict['sc_milestone_labels']
df['sc_expression'] = sc_labels_expression
df = df.groupby('sc_milestone_labels').mean()
milestone_numeric_values = df[
'sc_expression'].values # used to color edges. direction is based on milestone_pt
else:
hammer_bundle = hammerbundle_dict['hammerbundle']
layout = hammerbundle_dict['milestone_embedding'][['x', 'y']].values
milestone_edges = hammerbundle_dict['edges']
milestone_numeric_values = hammerbundle_dict['milestone_embedding']['numeric label']
if sc_labels_expression is not None: # if both sclabelexpression and sequential are provided, then sc_labels_expression takes precedence
df = pd.DataFrame()
df['sc_milestone_labels'] = hammerbundle_dict['sc_milestone_labels']
df['sc_expression'] = sc_labels_expression
df = df.groupby('sc_milestone_labels').mean()
milestone_numeric_values = df['sc_expression'].values # used to color edges
milestone_pt = hammerbundle_dict['milestone_embedding']['pt']
if use_sc_labels_sequential_for_direction: milestone_pt = hammerbundle_dict['milestone_embedding'][
'numeric label']
if len(lineage_pathway) == 0:
# fig, ax = plt.subplots(facecolor=facecolor)
fig_nrows, fig_ncols = 1, 1
else:
lineage_pathway_temp = [i for i in lineage_pathway if
i in via_object.terminal_clusters] # checking the clusters are actually in terminal_clusters
lineage_pathway = lineage_pathway_temp
n_terminal_clusters = len(lineage_pathway)
fig_ncols = min(3, n_terminal_clusters)
fig_nrows, mod = divmod(n_terminal_clusters, fig_ncols)
if mod == 0:
if fig_nrows == 0:
fig_nrows += 1
else:
fig_nrows = fig_nrows
if mod != 0: fig_nrows += 1
fig, ax = plt.subplots(fig_nrows, fig_ncols, dpi=dpi, facecolor=facecolor)
counter_ = 0
n_real_subplots = max(len(lineage_pathway), 1)
majority_composition = ''
for r in range(fig_nrows):
for c in range(fig_ncols):
if (counter_ < n_real_subplots):
if len(lineage_pathway) > 0:
milestone_numeric_values = hammerbundle_dict['milestone_embedding'][
'sc_lineage_probability_' + str(lineage_pathway[counter_])]
p1_sc_bp = np.nan_to_num(via_object.single_cell_bp, nan=0.0, posinf=0.0,
neginf=0.0) # single cell lineage probabilities sc pb
# row normalize
row_sums = p1_sc_bp.sum(axis=1)
p1_sc_bp = p1_sc_bp / row_sums[:,
np.newaxis] # make rowsums a column vector where i'th entry is sum of i'th row in p1-sc-bp
ts_cluster_number = lineage_pathway[counter_]
ts_array_original = np.asarray(via_object.terminal_clusters)
loc_ts_current = np.where(ts_array_original == ts_cluster_number)[0][0]
print(
f'location of {lineage_pathway[counter_]} is at {np.where(ts_array_original == ts_cluster_number)[0]} and {loc_ts_current}')
p1_sc_bp = p1_sc_bp[:, loc_ts_current]
# print(f'{datetime.now()}\tCheck sc pb {p1_sc_bp[0, :].sum()} ')
if via_object is not None:
ts_current = lineage_pathway[counter_]
loc_labels = np.where(np.asarray(via_object.labels) == ts_current)[0]
majority_composition = func_mode(list(np.asarray(via_object.true_label)[loc_labels]))
x_ = [l[0] for l in layout]
y_ = [l[1] for l in layout]
# min_x, max_x = min(x_), max(x_)
# min_y, max_y = min(y_), max(y_)
delta_x = max(x_) - min(x_)
delta_y = max(y_) - min(y_)
layout = np.asarray(layout)
# get each segment. these are separated by nans.
hbnp = hammer_bundle.to_numpy()
splits = (np.isnan(hbnp[:, 0])).nonzero()[0] # location of each nan values
edgelist_segments = []
start = 0
segments = []
arrow_coords = []
seg_len = [] # length of a segment
for stop in splits:
seg = hbnp[start:stop, :]
segments.append(seg)
seg_len.append(seg.shape[0])
start = stop
min_seg_length = min(seg_len)
max_seg_length = max(seg_len)
seg_len = np.asarray(seg_len)
seg_len = np.clip(seg_len, a_min=np.percentile(seg_len, 10),
a_max=np.percentile(seg_len, 90))
# mean_seg_length = sum(seg_len)/len(seg_len)
step = 1 # every step'th segment is plotted
cmap = matplotlib.cm.get_cmap(cmap)
if milestone_numeric_values is not None:
max_numerical_value = max(milestone_numeric_values)
min_numerical_value = min(milestone_numeric_values)
##inserting edits here
from matplotlib.patches import Rectangle
sc_embedding = via_object.embedding
max_r = np.max(via_object.embedding[:, 0]) + 1
max_l = np.min(via_object.embedding[:, 0]) - 1
max_up = np.max(via_object.embedding[:, 1]) + 1
max_dw = np.min(via_object.embedding[:, 1]) - 1
if add_sc_embedding:
if len(lineage_pathway) == 0:
print('inside add sc embedding second if')
if sc_labels_expression is not None:
gene_expression = False
if gene_expression:
val_alph = [i if i > 0.3 else 0 for i in sc_labels_expression]
max_alph = max(val_alph)
val_alph = [i / max_alph for i in val_alph]
ax.scatter(via_object.embedding[:, 0], via_object.embedding[:, 1], alpha=val_alph,
c=sc_labels_expression, s=sc_scatter_size, cmap=cmap_name,
zorder=2) # alpha=1 change back to
# ax.scatter(via_object.embedding[:, 0], via_object.embedding[:, 1], alpha=0.1, c='lightgray', s=5)
# new_cmap= truncate_colormap(cmap, 0.25, 1.0) #use this for gene expression plotting in zebrahub non-neuro ecto
else:
ax.scatter(via_object.embedding[:, 0], via_object.embedding[:, 1], alpha=1,
c=sc_labels_expression, s=sc_scatter_size, cmap=cmap,
zorder=1) # alpha=1 change back to
ax.add_patch(Rectangle((max_l, max_dw), max_r - max_l, max_up - max_dw, facecolor=facecolor, #'white'
alpha=sc_scatter_alpha))
if len(lineage_pathway) > 0:
if fig_nrows == 1:
if fig_ncols == 1:
plot_sc_pb(ax, fig, embedding=via_object.embedding, prob=p1_sc_bp,
ti=str(ts_current) + '-' + str(majority_composition), cmap_name=cmap_name,
scatter_size=sc_scatter_size, fontsize=4, alpha_factor=1, show_legend=False)
ax.add_patch(
Rectangle((max_l, max_dw), max_r - max_l, max_up - max_dw, facecolor=facecolor,#"white",
alpha=sc_scatter_alpha))
# ax.scatter(via_object.embedding[:, 0], via_object.embedding[:, 1], alpha=0.05, c='white', s=5)
else:
plot_sc_pb(ax[c], fig, embedding=sc_embedding, prob=p1_sc_bp,
ti=str(ts_current) + '-' + str(majority_composition), cmap_name=cmap_name,
scatter_size=sc_scatter_size, fontsize=4, alpha_factor=1, show_legend=False)
ax[c].add_patch(
Rectangle((max_l, max_dw), max_r - max_l, max_up - max_dw, facecolor=facecolor,#"white",
alpha=sc_scatter_alpha))
# ax[c].scatter(via_object.embedding[:, 0], via_object.embedding[:, 1], alpha=0.1, c='white', s=4)
else:
plot_sc_pb(ax[r, c], fig, embedding=sc_embedding, prob=p1_sc_bp,
ti=str(ts_current) + '-' + str(majority_composition), cmap_name=cmap_name,
scatter_size=sc_scatter_size, fontsize=4, alpha_factor=1, show_legend=False)
ax[r, c].add_patch(Rectangle((max_l, max_dw), max_r - max_l, max_up - max_dw, facecolor=facecolor,#"white",
alpha=sc_scatter_alpha)) # 0.7
# ax[r, c].scatter(layout[:, 0], layout[:, 1], s=40, c='white', cmap=cmap_name, alpha=0.5, edgecolors='none') # vmax=1)
# ax[r, c].scatter(via_object.embedding[:, 0], via_object.embedding[:, 1], alpha=0.1, c='white', s=4,edgecolors='none')
# end white edits
seg_count = 0
for seg in segments[::step]:
do_arrow = True
# seg_weight = max(0.3, math.log(1+seg[-1,2])) seg[-1,2] column index 2 has the weight information
seg_weight = seg[-1, 2] * seg_len[seg_count] / (
max_seg_length - min_seg_length) ##seg.shape[0] / (max_seg_length - min_seg_length)#seg.shape[0]
# cant' quite decide yet if sigmoid is desirable
# seg_weight=sigmoid_scalar(seg.shape[0] / (max_seg_length - min_seg_length), scale=5, shift=mean_seg_length / (max_seg_length - min_seg_length))
alpha_bundle = max(seg_weight * alpha_bundle_factor, 0.1) # max(0.1, math.log(1 + seg[-1, 2]))
if alpha_bundle > 1: alpha_bundle = 1
source_milestone = milestone_edges['source'].values[seg_count]
target_milestone = milestone_edges['target'].values[seg_count]
direction = milestone_pt[target_milestone] - milestone_pt[source_milestone]
if direction < 0:
direction = -1
else:
direction = 1
source_milestone_numerical_value = milestone_numeric_values[source_milestone]
target_milestone_numerical_value = milestone_numeric_values[target_milestone]
# print('source milestone', source_milestone_numerical_value)
# print('target milestone', target_milestone_numerical_value)
min_source_target_numerical_value = min(source_milestone_numerical_value,
target_milestone_numerical_value) # ORIGINALLY USING MIN()
# min_source_target_numerical_value =(source_milestone_numerical_value+ target_milestone_numerical_value)/2
max_source_target_numerical_value = max(source_milestone_numerical_value,
target_milestone_numerical_value)
# consider using the max value for lineage pathways to better highlight the high probabilties near the cell fate
if len(lineage_pathway) > 0: # print('change remove this back to >0 in plotting_via.py and gray segment zorder =2')
# rgba = cmap((min_source_target_numerical_value - min_numerical_value) / (max_numerical_value - min_numerical_value))
if min_source_target_numerical_value <= 0.3 * np.max(
milestone_numeric_values): # 0.1:#np.percentile(milestone_numeric_values,lineage_alpha_threshold):
alpha_bundle = 0.01 # 0.1#0.01
headwidth_alpha = 0.01 # 0.2
linewidth_bundle = 0.1 * linewidth_bundle_og
elif ((min_source_target_numerical_value > 0.3 * np.max(milestone_numeric_values)) & (
min_source_target_numerical_value < 0.7 * np.max(
milestone_numeric_values))): # 0.1:#np.percentile(milestone_numeric_values,lineage_alpha_threshold):
alpha_bundle = 0.05 # 0.2#max(min_source_target_numerical_value/np.max(milestone_numeric_values) *alpha_bundle,0.01)
headwidth_alpha = 0.01 # 0.2
linewidth_bundle = min_source_target_numerical_value / np.max(
milestone_numeric_values) * linewidth_bundle_og
else:
headwidth_alpha = headwidth_alpha_og
linewidth_bundle = linewidth_bundle_og * 1.4
rgba = cmap((min_source_target_numerical_value - min_numerical_value) / (
max_numerical_value - min_numerical_value))
# rgba = new_cmap((min_source_target_numerical_value - min_numerical_value) / ( max_numerical_value - min_numerical_value)) #use for non-neuro-ecto zebrahub gene expression
# else: rgba = cmap(min(seg_weight,0.95))#cmap(seg.shape[0]/(max_seg_length-min_seg_length))
# if seg_weight>0.05: seg_weight=0.1
# if seg_count%10000==0: print('seg weight', seg_weight)
seg = seg[:, 0:2].reshape(-1, 2)
seg_p = seg[~np.isnan(seg)].reshape((-1, 2))
if fig_nrows == 1:
if fig_ncols == 1:
# ax.plot(seg_p[:, 0], seg_p[:, 1], linewidth=0.2, alpha=0.1, color='gray')
ax.plot(seg_p[:, 0], seg_p[:, 1], linewidth=linewidth_bundle * seg_weight,
alpha=alpha_bundle, color=rgba) # , zorder=2)#edge_color )
else:
ax[c].plot(seg_p[:, 0], seg_p[:, 1], linewidth=linewidth_bundle * seg_weight,
alpha=alpha_bundle, color=rgba) # edge_color )
else:
ax[r, c].plot(seg_p[:, 0], seg_p[:, 1], linewidth=linewidth_bundle * seg_weight,
alpha=alpha_bundle, color=rgba) # edge_color )
if (show_arrow) & (seg_p.shape[0] > 3):
mid_point = math.floor(seg_p.shape[0] / 2)
if len(arrow_coords) > 0: # dont draw arrows in overlapping segments
for v1 in arrow_coords:
dist_ = dist_points(v1, v2=[seg_p[mid_point, 0], seg_p[mid_point, 1]])
if dist_ < arrow_frequency * delta_x: do_arrow = False
if dist_ < arrow_frequency * delta_y: do_arrow = False
if (do_arrow == True) & (seg_p.shape[0] > 3):
if fig_nrows == 1:
if fig_ncols == 1:
ax.arrow(seg_p[mid_point, 0], seg_p[mid_point, 1],
seg_p[mid_point + (direction * step), 0] - seg_p[mid_point, 0],
seg_p[mid_point + (direction * step), 1] - seg_p[mid_point, 1],
lw=0, length_includes_head=False, head_width=headwidth_bundle, color=rgba,
shape='full', alpha=headwidth_alpha, zorder=5)
else:
ax[c].arrow(seg_p[mid_point, 0], seg_p[mid_point, 1],
seg_p[mid_point + (direction * step), 0] - seg_p[mid_point, 0],
seg_p[mid_point + (direction * step), 1] - seg_p[mid_point, 1],
lw=0, length_includes_head=False, head_width=headwidth_bundle,
color=rgba, shape='full', alpha=headwidth_alpha, zorder=5)
else:
ax[r, c].arrow(seg_p[mid_point, 0], seg_p[mid_point, 1],
seg_p[mid_point + (direction * step), 0] - seg_p[mid_point, 0],
seg_p[mid_point + (direction * step), 1] - seg_p[mid_point, 1],
lw=0, length_includes_head=False, head_width=headwidth_bundle,
color=rgba, shape='full', alpha=headwidth_alpha, zorder=5)
arrow_coords.append([seg_p[mid_point, 0], seg_p[mid_point, 1]])
seg_count += 1
if show_milestones == False:
size_milestones = 0.01
show_milestones = True
scale_scatter_size_pop = False
if show_milestones == True:
milestone_numeric_values_normed = []
milestone_numeric_values_rgba = []
for ei, i in enumerate(milestone_numeric_values):
# if you have numeric values (days, hours) that need to be scaled to 0-1 so they can be used in cmap()
if len(lineage_pathway) > 0: # these are probabilities and we dont want to normalize the lineage likelihooods
rgba_ = cmap(i)
color_numeric = (i)
else:
rgba_ = cmap((i - min_numerical_value) / (max_numerical_value - min_numerical_value))
color_numeric = (i - min_numerical_value) / (max_numerical_value - min_numerical_value)
milestone_numeric_values_normed.append(color_numeric)
milestone_numeric_values_rgba.append(
rgba_) # need a list of rgb when also plotting labels as plot colors are done one-by-one
if scale_scatter_size_pop == True:
n_samples = layout.shape[0]
sqrt_nsamples = math.sqrt(n_samples)
group_pop_scale = [math.log(6 + i / sqrt_nsamples) for i in
hammerbundle_dict['milestone_embedding']['cluster population']]
size_scatter_scaled = [size_milestones * i for i in group_pop_scale]
else:
size_scatter_scaled = size_milestones # constant value
# NOTE # using vmax=1 in the scatter plot would mean that all values are plotted relative to a 0-1 scale and the legend for all plots is 0-1.
# If we want to allow that each legend is unique then there is autoscaling of the colors such that the max color is set to the max value of that particular subplot (even if that max value is well below 1)
if fig_nrows == 1:
if fig_ncols == 1:
im = ax.scatter(layout[:, 0], layout[:, 1], s=0.01,
c=milestone_numeric_values_normed, cmap=cmap_name,
edgecolors='None') # without alpha parameter which otherwise gets passed onto the colorbar
ax.scatter(layout[:, 0], layout[:, 1], s=size_scatter_scaled,
c=milestone_numeric_values_normed, cmap=cmap_name, alpha=alpha_milestones,
edgecolors='None')
else:
im = ax[c].scatter(layout[:, 0], layout[:, 1], s=0.01, c=milestone_numeric_values_normed,
cmap=cmap_name,
edgecolors='None') # without alpha parameter which otherwise gets passed onto the colorbar
ax[c].scatter(layout[:, 0], layout[:, 1], s=size_scatter_scaled,
c=milestone_numeric_values_normed, cmap=cmap_name, alpha=alpha_milestones,
edgecolors='None')
else:
'''
ax[r, c].scatter(layout[:, 0], layout[:, 1], s=size_scatter_scaled*3,
c=milestone_numeric_values_normed, cmap=cmap_name,
alpha=alpha_milestones*0.5, edgecolors='none', vmin=min_numerical_value) # vmax=1)
'''
im = ax[r, c].scatter(layout[:, 0], layout[:, 1], c=milestone_numeric_values_normed, s=0.01,
cmap=cmap_name, edgecolors='none', vmin=min_numerical_value)
ax[r, c].scatter(layout[:, 0], layout[:, 1], s=size_scatter_scaled,
c=milestone_numeric_values_normed, cmap=cmap_name,
alpha=alpha_milestones, edgecolors='none', vmin=min_numerical_value) # vmax=1)
''''
if len(lineage_pathway)>0:
#accentuate the scatter size for nodes significant to a lineage
for j in range(layout.shape[0]):
if milestone_numeric_values_normed[j] > 0.5*max_numerical_value:
if scale_scatter_size_pop:
ax[r, c].scatter(layout[j, 0], layout[j, 1], s=size_scatter_scaled[j] * 1.5,
c=milestone_numeric_values_normed[j], cmap=cmap_name,
alpha=alpha_milestones * 1.5, edgecolors='None',
vmin=min_numerical_value) # vmax=1)
else:
print(f'node {j} {milestone_numeric_values_normed[j]}')
ax[r, c].scatter(layout[j, 0], layout[j, 1], s=size_scatter_scaled*1.5,
c=milestone_numeric_values_normed[j], cmap=cmap_name,
alpha=alpha_milestones*1.5, edgecolors='None',
vmin=min_numerical_value) # vmax=1)
'''
if text_labels == True:
# if text labels is true but user has not provided any labels at the sc level from which to create milestone categorical labels
if sc_labels is None:
if via_object is not None:
sc_labels = via_object.true_label
else:
print(
f'{datetime.now()}\t ERROR: in order to show labels, please provide list of sc_labels at the single cell level OR via_object')
for i in range(layout.shape[0]):
sc_milestone_labels = hammerbundle_dict['sc_milestone_labels']
loc_milestone = np.where(np.asarray(sc_milestone_labels) == i)[0]
mode_label = func_mode(list(np.asarray(sc_labels)[loc_milestone]))
if scale_scatter_size_pop == True:
if fig_nrows == 1:
if fig_ncols == 1:
ax.scatter(layout[i, 0], layout[i, 1], s=size_scatter_scaled[i],
c=np.array([milestone_numeric_values_rgba[i]]),
alpha=alpha_milestones, edgecolors='None', label=mode_label)
else:
ax[c].scatter(layout[i, 0], layout[i, 1], s=size_scatter_scaled[i],
c=np.array([milestone_numeric_values_rgba[i]]),
alpha=alpha_milestones, edgecolors='None', label=mode_label)
else:
ax[r, c].scatter(layout[i, 0], layout[i, 1], s=size_scatter_scaled[i],
c=np.array([milestone_numeric_values_rgba[i]]),
alpha=alpha_milestones, edgecolors='None', label=mode_label)
else:
if fig_nrows == 1:
if fig_ncols == 1:
ax.scatter(layout[i, 0], layout[i, 1], s=size_scatter_scaled,
c=np.array([milestone_numeric_values_rgba[i]]),
alpha=alpha_milestones, edgecolors='None', label=mode_label)
else:
ax[c].scatter(layout[i, 0], layout[i, 1], s=size_scatter_scaled,
c=np.array([milestone_numeric_values_rgba[i]]),
alpha=alpha_milestones, edgecolors='None', label=mode_label)
else:
ax[r, c].scatter(layout[i, 0], layout[i, 1], s=size_scatter_scaled,
c=np.array([milestone_numeric_values_rgba[i]]),
alpha=alpha_milestones, edgecolors='None', label=mode_label)
if fig_nrows == 1:
if fig_ncols == 1:
ax.text(layout[i, 0], layout[i, 1], mode_label, style='italic',
fontsize=fontsize_labels, color="black")
else:
ax[c].text(layout[i, 0], layout[i, 1], mode_label, style='italic',
fontsize=fontsize_labels, color="black")
else:
ax[r, c].text(layout[i, 0], layout[i, 1], mode_label, style='italic',
fontsize=fontsize_labels, color="black")
time = datetime.now()
time = time.strftime("%H:%M")
if len(lineage_pathway) == 0:
title_ = extra_title_text + ' n_milestones = ' + str(int(layout.shape[0])) # + ' time: ' + time
else:
title_ = 'lineage:' + str(lineage_pathway[counter_]) + '-' + str(majority_composition)
if fig_nrows == 1:
if fig_ncols == 1:
ax.axis('off')
ax.grid(False)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_facecolor(facecolor)
ax.set_title(label=title_, color='black', fontsize=fontsize_title)
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size='5%', pad=0.05)
if len(lineage_pathway) > 0:
cb = fig.colorbar(im, cax=cax, orientation='vertical', label='lineage likelihood')
else:
cb = fig.colorbar(im, cax=cax, orientation='vertical', label='pseudotime')
ax_cb = cb.ax
text = ax_cb.yaxis.label
font = matplotlib.font_manager.FontProperties(
size=fontsize_title) # family='times new roman', style='italic',
text.set_font_properties(font)
ax_cb.tick_params(labelsize=int(fontsize_title * 0.8))
cb.outline.set_visible(False)
else:
ax[c].axis('off')
ax[c].grid(False)
ax[c].spines['top'].set_visible(False)
ax[c].spines['right'].set_visible(False)
ax[c].set_facecolor(facecolor)
ax[c].set_title(label=title_, color='black', fontsize=fontsize_title)
divider = make_axes_locatable(ax[c])
cax = divider.append_axes('right', size='5%', pad=0.05)
if (len(lineage_pathway)) > 0:
colorbar_legend = 'lineage likelihood'
cb = fig.colorbar(im, cax=cax, orientation='vertical', label=colorbar_legend)
else:
cb = fig.colorbar(im, cax=cax, orientation='vertical', label=colorbar_legend)
ax_cb = cb.ax
text = ax_cb.yaxis.label
font = matplotlib.font_manager.FontProperties(
size=fontsize_title) # family='times new roman', style='italic',
text.set_font_properties(font)
ax_cb.tick_params(labelsize=int(fontsize_title * 0.8))
cb.outline.set_visible(False)
else:
ax[r, c].axis('off')
ax[r, c].grid(False)
ax[r, c].spines['top'].set_visible(False)
ax[r, c].spines['right'].set_visible(False)
ax[r, c].set_facecolor(facecolor)
ax[r, c].set_title(label=title_, color='black', fontsize=fontsize_title)
divider = make_axes_locatable(ax[r, c])
cax = divider.append_axes('right', size='5%', pad=0.05)
if len(lineage_pathway) > 0:
cb = fig.colorbar(im, cax=cax, orientation='vertical', label='Lineage likelihood')
else:
cb = fig.colorbar(im, cax=cax, orientation='vertical', label='pseudotime')
ax_cb = cb.ax
text = ax_cb.yaxis.label
font = matplotlib.font_manager.FontProperties(
size=fontsize_title) # family='times new roman', style='italic',
text.set_font_properties(font)
ax_cb.tick_params(labelsize=int(fontsize_title * 0.8))
cb.outline.set_visible(False)
counter_ += 1
else:
if fig_nrows == 1:
if fig_ncols == 1:
ax.axis('off')
ax.grid(False)
else:
ax[c].axis('off')
ax[c].grid(False)
else:
ax[r, c].axis('off')
ax[r, c].grid(False)
return fig, ax
[docs]def animate_atlas(hammerbundle_dict=None, via_object=None, linewidth_bundle=2, frame_interval: int = 10,
n_milestones: int = None, facecolor: str = 'white', cmap: str = 'plasma_r', extra_title_text='',
size_scatter: int = 1, alpha_scatter: float = 0.2,
saveto='/home/user/Trajectory/Datasets/animation_default.gif', time_series_labels: list = None, lineage_pathway = [],
sc_labels_numeric: list = None, show_sc_embedding:bool=False, sc_emb=None, sc_size_scatter:float=10, sc_alpha_scatter:float=0.2, n_intervals:int = 50, n_repeat:int = 2):
'''
:param ax: axis to plot on
:param hammer_bundle: hammerbundle object with coordinates of all the edges to draw
:param layout: coords of cluster nodes and optionally also contains the numeric value associated with each cluster (such as time-stamp) layout[['x','y','numeric label']] sc/cluster/milestone level
:param CSM: cosine similarity matrix. cosine similarity between the RNA velocity between neighbors and the change in gene expression between these neighbors. Only used when available
:param velocity_weight: percentage weightage given to the RNA velocity based transition matrix
:param pt: cluster-level pseudotime
:param alpha_bundle: alpha when drawing lines
:param linewidth_bundle: linewidth of bundled lines
:param edge_color:
:param frame_interval: smaller number, faster refresh and video
:param facecolor: default = white
:param headwidth_bundle: headwidth of arrows used in bundled edges
:param arrow_frequency: min dist between arrows (bundled edges otherwise have overcrowding of arrows)
:param show_direction: True will draw arrows along the lines to indicate direction
:param milestone_edges: pandas DataFrame milestone_edges[['source','target']]
:param t_diff_factor: scaling of the average time interval (0.25 means that for each frame, the time is progressed by 0.25 * mean time difference between adjacent times; only used when sc_labels_numeric are directly passed instead of using pseudotime)
:param show_sc_embedding: plot the single cell embedding under the edges
:param sc_emb: numpy array of single cell embedding (ncells x 2)
:param sc_alpha_scatter: alpha transparency value of points of single cells (1 is opaque, 0 is fully transparent)
:param sc_size_scatter: size of scatter points of single cells
:param n_repeat: number of times you repeat the whole process
:return: axis with bundled edges plotted
'''
import tqdm
cmap = matplotlib.cm.get_cmap(cmap)
if show_sc_embedding:
if sc_emb is None:
sc_emb= via_object.embedding
if sc_emb is None:
print('please provide a single cell embedding as an array')
return
if hammerbundle_dict is None:
if via_object is None:
print(
f'{datetime.now()}\tERROR: Hammerbundle_dict needs to be provided either through via_object or by running make_edgebundle_milestone()')
else:
hammerbundle_dict = via_object.hammerbundle_milestone_dict
if hammerbundle_dict is None:
if n_milestones is None: n_milestones = min(via_object.nsamples, 150)
if sc_labels_numeric is None:
if via_object.time_series_labels is not None:
sc_labels_numeric = via_object.time_series_labels
else:
sc_labels_numeric = via_object.single_cell_pt_markov
hammerbundle_dict = make_edgebundle_milestone(via_object=via_object,
embedding=via_object.embedding,
sc_graph=via_object.ig_full_graph,
n_milestones=n_milestones,
sc_labels_numeric=sc_labels_numeric,
initial_bandwidth=0.02, decay=0.7, weighted=True)
hammer_bundle = hammerbundle_dict['hammerbundle']
layout = hammerbundle_dict['milestone_embedding'][['x', 'y']].values
milestone_edges = hammerbundle_dict['edges']
milestone_numeric_values = hammerbundle_dict['milestone_embedding']['numeric label']
else:
hammer_bundle = hammerbundle_dict['hammerbundle']
layout = hammerbundle_dict['milestone_embedding'][['x', 'y']].values
milestone_edges = hammerbundle_dict['edges']
milestone_numeric_values = hammerbundle_dict['milestone_embedding']['numeric label']
fig, ax = plt.subplots(facecolor=facecolor, figsize=(15, 12))
n_milestones = len(milestone_numeric_values)
if len(lineage_pathway) > 0:
milestone_lin_values = hammerbundle_dict['milestone_embedding'][
'sc_lineage_probability_' + str(lineage_pathway[0])]
p1_sc_bp = np.nan_to_num(via_object.single_cell_bp, nan=0.0, posinf=0.0,
neginf=0.0) # single cell lineage probabilities sc pb
# row normalize
row_sums = p1_sc_bp.sum(axis=1)
p1_sc_bp = p1_sc_bp / row_sums[:,
np.newaxis] # make rowsums a column vector where i'th entry is sum of i'th row in p1-sc-bp
ts_cluster_number = lineage_pathway[0]
ts_array_original = np.asarray(via_object.terminal_clusters)
loc_ts_current = np.where(ts_array_original == ts_cluster_number)[0][0]
print(f'{datetime.now()}\tlocation of {lineage_pathway[0]} is at {np.where(ts_array_original == ts_cluster_number)[0]} and {loc_ts_current}')
p1_sc_bp = p1_sc_bp[:, loc_ts_current]
rgba_lineage_sc = []
rgba_lineage_milestone = []
min_p1_sc_pb = min(p1_sc_bp)
max_p1_sc_pb = max(p1_sc_bp)
min_milestone_lin_values = min(milestone_lin_values)
max_milestone_lin_values = max(milestone_lin_values)
print(f"{datetime.now()}\t making rgba_lineage_sc")
for i in p1_sc_bp:
rgba_lineage_sc_ = cmap((i - min_p1_sc_pb) / (max_p1_sc_pb - min_p1_sc_pb))
rgba_lineage_sc.append(rgba_lineage_sc_)
print(f"{datetime.now()}\t making rgba_lineage_sc")
for i in milestone_lin_values:
rgba_lineage_milestone_ = cmap((i - min_milestone_lin_values) / (max_milestone_lin_values - min_milestone_lin_values))
rgba_lineage_milestone.append(rgba_lineage_milestone_)
# ax.set_facecolor(facecolor)
ax.grid(False)
x_ = [l[0] for l in layout]
y_ = [l[1] for l in layout]
layout = np.asarray(layout)
# make a knn so we can find which clustergraph nodes the segments start and end at
# get each segment. these are separated by nans.
hbnp = hammer_bundle.to_numpy()
splits = (np.isnan(hbnp[:, 0])).nonzero()[0] # location of each nan values
edgelist_segments = []
start = 0
segments = []
arrow_coords = []
seg_len = [] # length of a segment
for stop in splits:
seg = hbnp[start:stop, :]
segments.append(seg)
seg_len.append(seg.shape[0])
start = stop
min_seg_length = min(seg_len)
max_seg_length = max(seg_len)
# mean_seg_length = sum(seg_len)/len(seg_len)
seg_len = np.asarray(seg_len)
seg_len = np.clip(seg_len, a_min=np.percentile(seg_len, 10),
a_max=np.percentile(seg_len, 90))
if milestone_numeric_values is not None:
max_numerical_value = max(milestone_numeric_values)
min_numerical_value = min(milestone_numeric_values)
seg_count = 0
i_sorted_numeric_values = np.argsort(milestone_numeric_values)
ee = int(n_milestones / n_intervals)
print('ee',ee)
loc_time_thresh = i_sorted_numeric_values[0:ee]
for ll in loc_time_thresh:
print('sorted numeric milestone',milestone_numeric_values[ll])
# print('loc time thres', loc_time_thresh)
milestone_edges['source_thresh'] = milestone_edges['source'].isin(
loc_time_thresh) # apply(lambda x: any([k in x for k in loc_time_thresh]))
# print(milestone_edges[0:10])
idx = milestone_edges.index[milestone_edges['source_thresh']].tolist()
# print('loc time thres', time_thresh, loc_time_thresh)
for i in idx:
seg = segments[i]
source_milestone = milestone_edges['source'].values[i]
target_milestone = milestone_edges['target'].values[i]
# seg_weight = max(0.3, math.log(1+seg[-1,2])) seg[-1,2] column index 2 has the weight information
seg_weight = seg[-1, 2] * seg_len[i] / (
max_seg_length - min_seg_length) ##seg.shape[0] / (max_seg_length - min_seg_length)
# cant' quite decide yet if sigmoid is desirable
# seg_weight=sigmoid_scalar(seg.shape[0] / (max_seg_length - min_seg_length), scale=5, shift=mean_seg_length / (max_seg_length - min_seg_length))
alpha_bundle_firstsegments = max(seg_weight, 0.1)
if alpha_bundle_firstsegments > 1:
alpha_bundle_firstsegments = 1
if len(lineage_pathway)>0:
#alpha_bundle_firstsegments = milestone_lin_values[loc_time_thresh[0]] #the alpha should be propotional to the lineage_pb of these segments
alpha_bundle_firstsegments = milestone_lin_values[source_milestone]
if alpha_bundle_firstsegments < 0.7: alpha_bundle_firstsegments *= alpha_bundle_firstsegments
if milestone_numeric_values is not None:
if len(lineage_pathway) > 0:
source_milestone_numerical_value = milestone_lin_values[source_milestone]
target_milestone_numerical_value = milestone_lin_values[target_milestone]
rgba_milestone_value = min(source_milestone_numerical_value, target_milestone_numerical_value)
rgba = cmap((rgba_milestone_value - min_milestone_lin_values) / (max_milestone_lin_values - min_milestone_lin_values))
else:
source_milestone_numerical_value = milestone_numeric_values[source_milestone]
target_milestone_numerical_value = milestone_numeric_values[target_milestone]
rgba_milestone_value = min(source_milestone_numerical_value, target_milestone_numerical_value)
rgba = cmap((rgba_milestone_value - min_numerical_value) / (max_numerical_value - min_numerical_value))
else:
rgba = cmap(min(seg_weight, 0.95)) # cmap(seg.shape[0]/(max_seg_length-min_seg_length))
# if seg_weight>0.05: seg_weight=0.1
if seg_count % 10000 == 0: print('seg weight', seg_weight)
seg = seg[:, 0:2].reshape(-1, 2)
seg_p = seg[~np.isnan(seg)].reshape((-1, 2))
ax.plot(seg_p[:, 0], seg_p[:, 1], linewidth=linewidth_bundle * seg_weight, alpha=alpha_bundle_firstsegments,
color=rgba) # edge_color )
seg_count += 1
milestone_numeric_values_rgba = []
print(f'{datetime.now()}\there1 in animate()')
if milestone_numeric_values is not None:
for i in milestone_numeric_values:
rgba_ = cmap((i - min_numerical_value) / (max_numerical_value - min_numerical_value))
milestone_numeric_values_rgba.append(rgba_)
if show_sc_embedding:
if len(lineage_pathway)>0:
weighted_alpha = [sc_alpha_scatter*i for i in p1_sc_bp]
weighted_alpha = [0.5*sc_alpha_scatter if i<= 0.5*sc_alpha_scatter else i for i in weighted_alpha]
ax.scatter(sc_emb[:, 0], sc_emb[:, 1], s=sc_size_scatter,
c=p1_sc_bp, alpha=weighted_alpha, cmap=cmap)
else: ax.scatter(sc_emb[:, 0], sc_emb[:, 1], s=size_scatter, c='blue', alpha=0)
if len(lineage_pathway) > 0:
ax.scatter(layout[loc_time_thresh, 0], layout[loc_time_thresh, 1], s=size_scatter,
c=np.asarray(rgba_lineage_milestone)[loc_time_thresh], alpha=alpha_scatter)
else: ax.scatter(layout[loc_time_thresh, 0], layout[loc_time_thresh, 1], s=size_scatter,
c=np.asarray(milestone_numeric_values_rgba)[loc_time_thresh], alpha=alpha_scatter)
# if we dont plot all the points, then the size of axis changes and the location of the graph moves/changes as more points are added
ax.scatter(layout[:, 0], layout[:, 1], s=size_scatter,
c=np.asarray(milestone_numeric_values_rgba), alpha=0)
else:
ax.scatter(layout[:, 0], layout[:, 1], s=size_scatter, c='red', alpha=alpha_scatter)
print('here2 in animate()')
ax.set_facecolor(facecolor)
ax.axis('off')
time = datetime.now()
time = time.strftime("%H:%M")
title_ = 'n_milestones = ' + str(int(layout.shape[0])) + ' time: ' + time + ' ' + extra_title_text
ax.set_title(label=title_, color='black')
print(f"{datetime.now()}\tFinished plotting edge bundle")
if time_series_labels is None: #over-ride via_object's saved time_series_labels and/or pseudotime
if via_object is not None:
time_series_labels = via_object.time_series_labels
if time_series_labels is None:
time_series_labels = via_object.single_cell_pt_markov
cycles = n_intervals
min_time_series_labels = min(time_series_labels)
max_time_series_labels = max(time_series_labels)
sc_rgba = []
for i in time_series_labels:
sc_rgba_ = cmap((i - min_time_series_labels) / (max_time_series_labels - min_time_series_labels))
sc_rgba.append(sc_rgba_)
if show_sc_embedding:
print(f"{datetime.now()}\tdoing argsort of single cell time_series_labels")
i_sorted_sc_time = np.argsort(time_series_labels)
print(f"{datetime.now()}\tfinish argsort of single cell time_series_labels")
n_cells = len(time_series_labels)
def update_edgebundle(frame_no):
    # Frame callback for FuncAnimation: reveal the hammer-bundled edges and the
    # milestone (and optionally single-cell) scatter points whose source
    # milestones fall inside the pseudotime interval addressed by this frame.
    # All data (segments, layout, milestone_edges, cmap, ...) is captured by
    # closure from the enclosing animation function.
    print(frame_no, 'out of', n_intervals, 'cycles')
    # position of this frame within the current cycle of n_intervals frames
    rem = (frame_no % n_intervals)
    if (frame_no % n_intervals)==0:
        # start of a new cycle: show the earliest interval of milestones/cells
        loc_time_thresh = i_sorted_numeric_values[0:int(n_milestones/n_intervals)]
        if show_sc_embedding:
            sc_loc_time_thresh = i_sorted_sc_time[0:int(n_milestones / n_intervals)]
    else:
        # later frames: slide the window to the rem-th pseudotime interval
        loc_time_thresh = i_sorted_numeric_values[rem*int(n_milestones/n_intervals):(rem+1)*int(n_milestones/n_intervals)]
        if show_sc_embedding:
            sc_loc_time_thresh = i_sorted_sc_time[rem*int(n_cells/n_intervals):(rem+1)*int(n_cells/n_intervals)]
    #sc_loc_time_thresh = np.where((np.asarray(time_series_labels) <= time_thresh) & ( np.asarray(time_series_labels) > time_thresh - t_diff_mean))[0].tolist()
    # flag the edges whose source milestone lies in the current interval
    milestone_edges['source_thresh'] = milestone_edges['source'].isin(
        loc_time_thresh)  # apply(lambda x: any([k in x for k in loc_time_thresh]))
    idx = milestone_edges.index[milestone_edges['source_thresh']].tolist()
    for i in idx:
        seg = segments[i]
        source_milestone = milestone_edges['source'].values[i]
        # seg_weight = max(0.3, math.log(1+seg[-1,2])) seg[-1,2] column index 2 has the weight information
        # edge weight scaled by the (clipped) segment length, normalized by the length range
        seg_weight = seg[-1, 2] * seg_len[i] / (
                max_seg_length - min_seg_length)  ##seg.shape[0] / (max_seg_length - min_seg_length)
        # cant' quite decide yet if sigmoid is desirable
        # seg_weight=sigmoid_scalar(seg.shape[0] / (max_seg_length - min_seg_length), scale=5, shift=mean_seg_length / (max_seg_length - min_seg_length))
        alpha_bundle = max(seg_weight, 0.1)  # max(0.1, math.log(1 + seg[-1, 2]))
        if alpha_bundle > 1: alpha_bundle = 1
        if milestone_numeric_values is not None:
            source_milestone_numerical_value = milestone_numeric_values[source_milestone]
            if len(lineage_pathway)==0:
                # color by the (normalized) numeric value of the source milestone
                rgba = cmap((source_milestone_numerical_value - min_numerical_value) / ( max_numerical_value - min_numerical_value))
            else:
                # lineage mode: use the precomputed per-milestone lineage color
                rgba = list(rgba_lineage_milestone[source_milestone])
                #print('pre alpha modified', rgba)
                #rgba[3] = milestone_lin_values[source_milestone] #source_milestone
                #rgba = tuple(rgba)
                #print('alpha modified',rgba)
        else:
            rgba = cmap(min(seg_weight, 0.95))  # cmap(seg.shape[0]/(max_seg_length-min_seg_length))
        # if seg_weight>0.05: seg_weight=0.1
        seg = seg[:, 0:2].reshape(-1, 2)
        # drop the NaN separator rows inserted by hammer bundling before plotting
        seg_p = seg[~np.isnan(seg)].reshape((-1, 2))
        if len(lineage_pathway)>0:
            # fade low lineage-probability segments by squaring their alpha
            squared_alpha = milestone_lin_values[source_milestone]
            if squared_alpha<0.6: squared_alpha*=squared_alpha
            #print('squared_alpha',squared_alpha)
            ax.plot(seg_p[:, 0], seg_p[:, 1], linewidth=linewidth_bundle * seg_weight, color=rgba, alpha=squared_alpha)#milestone_lin_values[source_milestone])
        else: ax.plot(seg_p[:, 0], seg_p[:, 1], linewidth=linewidth_bundle * seg_weight, color=rgba,alpha=alpha_bundle)
    milestone_numeric_values_rgba = []
    if milestone_numeric_values is not None:
        for i in milestone_numeric_values:
            rgba_ = cmap((i - min_numerical_value) / (max_numerical_value - min_numerical_value))
            milestone_numeric_values_rgba.append(rgba_)
        if ((frame_no > n_repeat*n_intervals) and (rem ==0)):  # by using > rather than >= sign, two complete cycles run before the axis is cleared and the animation is restarted
            ax.clear()
            ax.axis('off')
        else:
            # plot all milestones invisibly so the axis limits stay fixed while points accumulate
            ax.scatter(layout[:, 0], layout[:, 1], s=size_scatter,
                       c=np.asarray(milestone_numeric_values_rgba), alpha=0)
            if len(lineage_pathway) > 0:
                ax.scatter(layout[loc_time_thresh, 0], layout[loc_time_thresh, 1], s=size_scatter,
                           c=np.asarray(rgba_lineage_milestone)[loc_time_thresh], alpha=alpha_scatter)
            else: ax.scatter(layout[loc_time_thresh, 0], layout[loc_time_thresh, 1], s=size_scatter,
                             c=np.asarray(milestone_numeric_values_rgba)[loc_time_thresh], alpha=alpha_scatter)
            if show_sc_embedding:
                if len(lineage_pathway)>0:
                    ax.scatter(sc_emb[sc_loc_time_thresh, 0], sc_emb[sc_loc_time_thresh, 1], s=sc_size_scatter, edgecolors = None,
                               c=np.asarray(rgba_lineage_sc)[sc_loc_time_thresh], alpha=[p1_sc_bp[sc_loc_time_thresh]]) #no *sc_alpha_scatter
                else: ax.scatter(sc_emb[sc_loc_time_thresh, 0], sc_emb[sc_loc_time_thresh, 1], s=sc_size_scatter,
                                 c=np.asarray(sc_rgba)[sc_loc_time_thresh], alpha=sc_alpha_scatter)
    else:
        # no numeric values available: show the active milestones in a single color
        ax.scatter(layout[:, 0], layout[:, 1], s=size_scatter, c='red', alpha=alpha_scatter)
    # pbar.update()
frame_no = int(int(cycles)*n_repeat)
animation = FuncAnimation(fig, update_edgebundle, frames=frame_no, interval=frame_interval, repeat=False) # 100
# pbar = tqdm.tqdm(total=frame_no)
# pbar.close()
print('complete animate')
animation.save(saveto, writer='imagemagick') # , fps=30)
print('saved animation')
plt.show()
return
def animate_streamplot(via_object, embedding, density_grid=1,
                       linewidth=0.5, min_mass=1, cutoff_perc=None, scatter_size=500, scatter_alpha=0.2,
                       marker_edgewidth=0.1, smooth_transition=1, smooth_grid=0.5, color_scheme='annotation',
                       other_labels=None, b_bias=20, n_neighbors_velocity_grid=None, fontsize=8, alpha_animate=0.7,
                       cmap_scatter='rainbow', cmap_stream='Blues', segment_length=1,
                       saveto='/home/shobi/Trajectory/Datasets/animation.gif', use_sequentially_augmented=False,
                       facecolor_='white', random_seed=0):
    '''
    Draw animated vector plots. The .gif written to `saveto` is the best way to view the animation,
    as rendering the returned fig/ax interactively can be slow.

    :param via_object: via object
    :param embedding: ndarray (nsamples, 2) umap, tsne, via-umap, via-mds. Falls back to via_object.embedding when None
    :param density_grid: density of the velocity grid
    :param linewidth: stroke width used for the text-label outline
    :param min_mass: minimum cell mass for a grid point to receive a velocity vector
    :param cutoff_perc: percentile cutoff used when adjusting the grid for streamlines
    :param scatter_size: size of scatter points
    :param scatter_alpha: transparency of scatter points
    :param marker_edgewidth: width of the outline around each scatter point
    :param smooth_transition: smoothing of the single-cell transition matrix
    :param smooth_grid: smoothing of the velocity grid
    :param color_scheme: 'annotation', 'cluster', 'other' or 'time'
    :param other_labels: labels used when color_scheme == 'other'. Default (None) is an empty list
    :param b_bias: higher value makes the forward bias of pseudotime stronger
    :param n_neighbors_velocity_grid: number of neighbors used for grid interpolation
    :param fontsize: fontsize of the (currently empty) group labels
    :param alpha_animate: transparency of the animated streamlines
    :param cmap_scatter: colormap for the scatter points
    :param cmap_stream: colormap for streamlines, default 'Blues' (dark blue lines). Consider 'Blues_r' for white lines OR 'Greys/_r' 'gist_yard/_r'
    :param segment_length: length-scale divisor for streamline segments
    :param saveto: file path the .gif animation is written to
    :param use_sequentially_augmented: passed to the single-cell transition-matrix computation
    :param facecolor_: background color of the axes
    :param random_seed: seed for the random phase offset of each streamline
    :return: fig, ax
    :raises ValueError: if no embedding is provided and via_object.embedding is None
    '''
    import tqdm
    import matplotlib.pyplot as plt
    from matplotlib.animation import FuncAnimation
    from matplotlib.collections import LineCollection
    from pyVIA.windmap import Streamlines
    import matplotlib.patheffects as PathEffects

    # avoid a mutable default argument (a shared list across calls)
    if other_labels is None: other_labels = []

    print(f'{datetime.now()}\tStep1 velocity embedding')
    if embedding is None:
        embedding = via_object.embedding
        if embedding is None:
            # fail fast instead of crashing later inside the velocity computation
            raise ValueError('please provide input parameter embedding of ndarray with shape (nsamples, 2)')

    V_emb = _pl_velocity_embedding(via_object, embedding, smooth_transition, b=b_bias,
                                   use_sequentially_augmented=use_sequentially_augmented)
    V_emb *= 10  # scale the (n_samples x 2) velocities for visibility

    print(f'{datetime.now()}\tStep2 interpolate')
    # interpolate the velocity on grid points based on the velocities of the samples in V_emb
    X_grid, V_grid = compute_velocity_on_grid(
        X_emb=embedding,
        V_emb=V_emb,
        density=density_grid,
        smooth=smooth_grid,
        min_mass=min_mass,
        autoscale=False,
        adjust_for_stream=True,
        cutoff_perc=cutoff_perc, n_neighbors=n_neighbors_velocity_grid)
    print(f'{datetime.now()}\tInside animated. File will be saved to location {saveto}')

    fig = plt.figure(figsize=(10, 8))
    ax = plt.subplot(1, 1, 1)
    fig.patch.set_visible(False)

    # --- static background scatter of the cells ---
    if color_scheme == 'time':
        ax.scatter(embedding[:, 0], embedding[:, 1], c=via_object.single_cell_pt_markov, alpha=scatter_alpha, zorder=0,
                   s=scatter_size, linewidths=marker_edgewidth, cmap=cmap_scatter)
    else:
        # default to the true labels so an unrecognized color_scheme cannot leave
        # color_labels unbound (the original code raised NameError in that case)
        color_labels = via_object.true_label
        if color_scheme == 'cluster': color_labels = via_object.labels
        if color_scheme == 'other': color_labels = other_labels
        n_true = len(set(color_labels))
        lin_col = np.linspace(0, 1, n_true)
        col = 0
        cmap = matplotlib.cm.get_cmap(cmap_scatter)  # 'twilight' is nice too
        cmap = cmap(np.linspace(0.01, 0.80, n_true))  # truncate the extremes of the colormap
        for color, group in zip(lin_col, sorted(set(color_labels))):
            color_ = np.asarray(cmap[col]).reshape(-1, 4)
            color_[0, 3] = scatter_alpha
            where = np.where(np.array(color_labels) == group)[0]
            ax.scatter(embedding[where, 0], embedding[where, 1], label=group,
                       c=color_,
                       alpha=scatter_alpha, zorder=0, s=scatter_size,
                       linewidths=marker_edgewidth)
            x_mean = embedding[where, 0].mean()
            y_mean = embedding[where, 1].mean()
            # text is intentionally empty; change '' to str(group) to label groups
            ax.text(x_mean, y_mean, '', fontsize=fontsize, zorder=4,
                    path_effects=[PathEffects.withStroke(linewidth=linewidth, foreground='w')],
                    weight='bold')
            col += 1
    ax.set_facecolor(facecolor_)

    # --- build one LineCollection per streamline, with a random phase offset ---
    lengths = []
    colors = []
    lines = []
    linewidths = []
    count = 0
    s = Streamlines(X_grid[0], X_grid[1], V_grid[0], V_grid[1])
    for streamline in s.streamlines:
        random_seed += 1
        count += 1
        x, y = streamline
        # interpolate over NaNs so each polyline is continuous
        x_ = np.array(x)
        nans, func_ = nan_helper(x_)
        x_[nans] = np.interp(func_(nans), func_(~nans), x_[~nans])
        y_ = np.array(y)
        nans, func_ = nan_helper(y_)
        y_[nans] = np.interp(func_(nans), func_(~nans), y_[~nans])
        points = np.array([x_, y_]).T.reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)  # (n_seg, 2, 2)
        n_seg = len(segments)
        # cumulative arc length drives the color/width phase along the line
        D = np.sqrt(((points[1:] - points[:-1]) ** 2).sum(axis=-1)) / segment_length
        np.random.seed(random_seed)
        L = D.cumsum().reshape(n_seg, 1) + np.random.uniform(0, 1)
        C = np.zeros((n_seg, 4))
        C[::-1] = (L * 1.5) % 1
        C[:, 3] = alpha_animate
        lw = L.flatten().tolist()
        line = LineCollection(segments, color=C, linewidth=1)
        lengths.append(L)
        colors.append(C)
        linewidths.append(lw)
        lines.append(line)
        ax.add_collection(line)

    print('total number of stream lines', count)
    ax.set_xlim(min(X_grid[0]), max(X_grid[0]))
    ax.set_xticks([])
    ax.set_ylim(min(X_grid[1]), max(X_grid[1]))
    ax.set_yticks([])
    plt.tight_layout()

    def update(frame_no):
        # advance the phase of every streamline and refresh its colors/linewidths
        cmap = matplotlib.cm.get_cmap(cmap_stream)
        cmap = cmap(np.linspace(0.8, 0.9, 100))  # lighter portion of the colormap
        for i in range(len(lines)):
            lengths[i] -= 0.05
            # the 0.1 offset and clipping avoid "hard blacks" in the gradient
            colors[i][::-1] = np.clip(0.1 + (lengths[i] * 1.5) % 1, 0.2, 0.9)
            colors[i][:, 3] = alpha_animate
            temp = (lengths[i] * 1) % 2
            linewidths[i] = temp.flatten().tolist()
            # look up the truncated cmap by the scalar stored in the red channel
            cmap_colors = [cmap[int(j * 100)] for j in colors[i][:, 0]]
            for row in range(colors[i].shape[0]):
                colors[i][row, :] = cmap_colors[row][0:4]
            colors[i][:, 3] = alpha_animate
            lines[i].set_linewidth(linewidths[i])
            lines[i].set_color(colors[i])
        pbar.update()

    n = 250  # number of animation frames
    # create the progress bar before saving and close it after: the original code
    # closed it before animation.save(), turning every pbar.update() into a no-op
    pbar = tqdm.tqdm(total=n)
    animation = FuncAnimation(fig, update, frames=n, interval=40)
    animation.save(saveto, writer='imagemagick', fps=25)
    pbar.close()
    plt.show()
    return fig, ax
def via_streamplot(via_object, embedding: ndarray = None, density_grid: float = 0.5, arrow_size: float = 0.7,
                   arrow_color: str = 'k', color_dict: dict = None,
                   arrow_style="-|>", max_length: int = 4, linewidth: float = 1, min_mass=1, cutoff_perc: int = 5,
                   scatter_size: int = 500, scatter_alpha: float = 0.5, marker_edgewidth: float = 0.1,
                   density_stream: int = 2, smooth_transition: int = 1, smooth_grid: float = 0.5,
                   color_scheme: str = 'annotation', add_outline_clusters: bool = False,
                   cluster_outline_edgewidth=0.001, gp_color='white', bg_color='black', dpi=300, title='Streamplot',
                   b_bias=20, n_neighbors_velocity_grid=None, labels: list = None, use_sequentially_augmented=False,
                   cmap: str = 'rainbow', show_text_labels: bool = True):
    '''
    Construct a vector streamplot on the embedding to show a fine-grained view of inferred directions in the trajectory.

    :param via_object:
    :param embedding: np.ndarray of shape (n_samples, 2) umap or other 2-d embedding on which to project the directionality of cells. Falls back to via_object.embedding when None
    :param density_grid: density of the velocity grid
    :param arrow_size:
    :param arrow_color:
    :param color_dict: optional {label: color} mapping used when `labels` are categorical
    :param arrow_style:
    :param max_length: maximum streamline length
    :param linewidth: width of lines in streamplot, default = 1
    :param min_mass:
    :param cutoff_perc:
    :param scatter_size: size of scatter points, default = 500
    :param scatter_alpha: transparency of scatter points
    :param marker_edgewidth: width of outline around each scatter point, default = 0.1
    :param density_stream: density of drawn streamlines
    :param smooth_transition: smoothing of the single-cell transition matrix
    :param smooth_grid: smoothing of the velocity grid
    :param color_scheme: str, default = 'annotation' corresponds to via_object.true_label. Other options are 'time' (single-cell pseudotime), 'cluster' (via cluster graph) and 'other'. Alternatively provide `labels` as a list
    :param add_outline_clusters: draw a black outer / white inner rim around each cell
    :param cluster_outline_edgewidth:
    :param gp_color: inner-rim color of the cluster outline
    :param bg_color: outer-rim color of the cluster outline
    :param dpi:
    :param title:
    :param b_bias: default = 20. higher value makes the forward bias of pseudotime stronger
    :param n_neighbors_velocity_grid:
    :param labels: list used for the color scheme; if a color_dict is provided these labels should match its keys
    :param use_sequentially_augmented:
    :param cmap:
    :param show_text_labels: annotate each group at its centroid
    :return: fig, ax
    :raises ValueError: if no embedding is provided and via_object.embedding is None
    '''
    import matplotlib.patheffects as PathEffects

    if embedding is None:
        embedding = via_object.embedding
        if embedding is None:
            # fail fast: the original code only printed a warning (with a typo) and then crashed downstream
            raise ValueError(
                f'{datetime.now()}\tWARNING: please assign embedding attribute to via_object as via_object.embedding = ndarray of [n_cells x 2]')

    V_emb = via_object._velocity_embedding(embedding, smooth_transition, b=b_bias,
                                           use_sequentially_augmented=use_sequentially_augmented)
    V_emb *= 20  # scale velocities for visibility

    X_grid, V_grid = compute_velocity_on_grid(
        X_emb=embedding,
        V_emb=V_emb,
        density=density_grid,
        smooth=smooth_grid,
        min_mass=min_mass,
        autoscale=False,
        adjust_for_stream=True,
        cutoff_perc=cutoff_perc, n_neighbors=n_neighbors_velocity_grid)

    # adapted from: https://github.com/theislab/scvelo/blob/1805ab4a72d3f34496f0ef246500a159f619d3a2/scvelo/plotting/velocity_embedding_grid.py#L27
    # scale linewidth by the local velocity magnitude
    lengths = np.sqrt((V_grid ** 2).sum(0))
    linewidth = 1 if linewidth is None else linewidth
    linewidth *= 2 * lengths / lengths[~np.isnan(lengths)].max()

    fig, ax = plt.subplots(dpi=dpi)
    ax.grid(False)
    ax.streamplot(X_grid[0], X_grid[1], V_grid[0], V_grid[1], color=arrow_color, arrowsize=arrow_size,
                  arrowstyle=arrow_style, zorder=3, linewidth=linewidth, density=density_stream, maxlength=max_length)

    if add_outline_clusters:
        # add a black outline to outer cells and a white inner rim
        # adapted from scanpy (scVelo utils adapts this from scanpy)
        gp_size = (2 * (scatter_size * cluster_outline_edgewidth * .1) + 0.1 * scatter_size) ** 2
        bg_size = (2 * (scatter_size * cluster_outline_edgewidth) + math.sqrt(gp_size)) ** 2
        ax.scatter(embedding[:, 0], embedding[:, 1], s=bg_size, marker=".", c=bg_color, zorder=-2)
        ax.scatter(embedding[:, 0], embedding[:, 1], s=gp_size, marker=".", c=gp_color, zorder=-1)

    if labels is None:
        if color_scheme == 'time':
            ax.scatter(embedding[:, 0], embedding[:, 1], c=via_object.single_cell_pt_markov, alpha=scatter_alpha,
                       zorder=0, s=scatter_size, linewidths=marker_edgewidth, cmap='viridis_r')
        else:
            # default to the true labels so an unrecognized color_scheme (e.g. 'other'
            # without labels) cannot leave color_labels unbound (previously NameError)
            color_labels = via_object.true_label
            if color_scheme == 'cluster': color_labels = via_object.labels
            cmap_ = plt.get_cmap(cmap)
            line = np.linspace(0, 1, len(set(color_labels)))
            for color, group in zip(line, sorted(set(color_labels))):
                where = np.where(np.array(color_labels) == group)[0]
                ax.scatter(embedding[where, 0], embedding[where, 1], label=group,
                           c=np.asarray(cmap_(color)).reshape(-1, 4),
                           alpha=scatter_alpha, zorder=0, s=scatter_size, linewidths=marker_edgewidth)
                if show_text_labels:
                    x_mean = embedding[where, 0].mean()
                    y_mean = embedding[where, 1].mean()
                    ax.text(x_mean, y_mean, str(group), fontsize=5, zorder=4,
                            path_effects=[PathEffects.withStroke(linewidth=1, foreground='w')], weight='bold')
    else:
        if isinstance(labels[0], str):  # labels are categorical
            if color_dict is not None:
                for key in color_dict:
                    loc_key = np.where(np.asarray(labels) == key)[0]
                    ax.scatter(embedding[loc_key, 0], embedding[loc_key, 1], color=color_dict[key], label=key,
                               s=scatter_size,
                               alpha=scatter_alpha, zorder=0, linewidths=marker_edgewidth)
                    x_mean = embedding[loc_key, 0].mean()
                    y_mean = embedding[loc_key, 1].mean()
                    if show_text_labels: ax.text(x_mean, y_mean, key, style='italic', fontsize=10,
                                                 color="black")
            else:  # no color_dict but labels are categorical
                cmap_ = plt.get_cmap(cmap)
                line = np.linspace(0, 1, len(set(labels)))
                for color, group in zip(line, sorted(set(labels))):
                    where = np.where(np.array(labels) == group)[0]
                    ax.scatter(embedding[where, 0], embedding[where, 1], label=group,
                               c=np.asarray(cmap_(color)).reshape(-1, 4),
                               alpha=scatter_alpha, zorder=0, s=scatter_size, linewidths=marker_edgewidth)
                    if show_text_labels:
                        x_mean = embedding[where, 0].mean()
                        y_mean = embedding[where, 1].mean()
                        ax.text(x_mean, y_mean, str(group), fontsize=5, zorder=4,
                                path_effects=[PathEffects.withStroke(linewidth=1, foreground='w')], weight='bold')
        else:  # labels are numeric: color directly through the colormap
            ax.scatter(embedding[:, 0], embedding[:, 1], c=labels, alpha=scatter_alpha, zorder=0,
                       s=scatter_size, linewidths=marker_edgewidth, cmap=cmap)

    fig.patch.set_visible(False)
    ax.axis('off')
    ax.set_title(title)
    return fig, ax
def get_gene_expression(via_object, gene_exp: pd.DataFrame, cmap: str = 'jet', dpi: int = 150, marker_genes: list = None,
                        linewidth: float = 2.0, n_splines: int = 10, spline_order: int = 4, fontsize_: int = 8,
                        marker_lineages=None, optional_title_text: str = '', cmap_dict: dict = None):
    '''
    Plot a GAM-smoothed expression trend along pseudotime for each gene and each terminal lineage.

    :param via_object: via object
    :param gene_exp: dataframe where columns are features (gene) and rows are single cells
    :param cmap: default: 'jet'
    :param dpi: default: 150
    :param marker_genes: Default (None) uses all genes in gene_exp. Otherwise provide a list of marker genes that will be used from gene_exp.
    :param linewidth: default: 2
    :param n_splines: default: 10. Note n_splines must be > spline_order.
    :param spline_order: default: 4. n_splines must be > spline_order.
    :param fontsize_: font size for titles and tick labels
    :param marker_lineages: Default (None) uses all lineage pathways. Otherwise provide a list of lineage numbers (terminal cluster numbers).
    :param optional_title_text: appended to each subplot title
    :param cmap_dict: {lineage number: 'color'}
    :return: fig, axs
    '''
    # avoid mutable default arguments (shared lists across calls)
    marker_genes = [] if marker_genes is None else marker_genes
    marker_lineages = [] if marker_lineages is None else marker_lineages

    sc_bp_original = via_object.single_cell_bp
    if len(marker_lineages) == 0:
        marker_lineages = via_object.terminal_clusters
    n_terminal_states = len(marker_lineages)

    if len(marker_genes) > 0: gene_exp = gene_exp[marker_genes]
    sc_pt = via_object.single_cell_pt_markov
    if cmap_dict is None:
        palette = cm.get_cmap(cmap, n_terminal_states)
        cmap_ = palette(range(n_terminal_states))
    else:
        cmap_ = cmap_dict

    n_genes = gene_exp.shape[1]
    fig_ncols = 4
    # one row per 4 genes, plus a partial final row if needed
    fig_nrows, mod = divmod(n_genes, fig_ncols)
    if mod != 0: fig_nrows += 1
    fig, axs = plt.subplots(fig_nrows, fig_ncols, dpi=dpi)
    fig.patch.set_visible(False)

    sc_pt_arr = np.asarray(sc_pt)
    i_gene = 0  # counter for number of genes
    for r in range(fig_nrows):
        for c in range(fig_ncols):
            # plt.subplots returns a 1-D array when fig_nrows == 1
            ax_i = axs[r, c] if fig_nrows > 1 else axs[c]
            if i_gene >= n_genes:
                # unused grid cell: hide it
                ax_i.axis('off')
                ax_i.grid(False)
                continue
            gene_i = gene_exp.columns[i_gene]
            for enum_i_lineage, i_lineage in enumerate(marker_lineages):
                # map the requested lineage (terminal cluster number) to its column in sc_bp
                i_terminal_arr = np.where(np.asarray(via_object.terminal_clusters) == i_lineage)[0]
                if len(i_terminal_arr) == 0:
                    continue  # not a terminal cluster of this via_object
                i_terminal = i_terminal_arr[0]
                sc_bp = sc_bp_original.copy()
                loc_i = np.where(sc_bp[:, i_terminal] > 0.9)[0]
                if len(loc_i) == 0:
                    continue  # this terminal state cannot be reached (branch probabilities all ~0)
                cluster_i_loc = \
                    np.where(np.asarray(via_object.labels) == via_object.terminal_clusters[i_terminal])[0]
                majority_true = via_object.func_mode(list(np.asarray(via_object.true_label)[cluster_i_loc]))
                # restrict the fit to cells with nonzero branch probability and
                # pseudotime not beyond the lineage's high-confidence cells
                max_val_pt = sc_pt_arr[loc_i].max()
                loc_i_bp = np.where(sc_bp[:, i_terminal] > 0.000)[0]  # 0.001
                loc_i_sc = np.where(sc_pt_arr <= max_val_pt)[0]
                loc_ = np.intersect1d(loc_i_bp, loc_i_sc)
                if len(loc_) <= 1:
                    print(
                        f'{datetime.now()}\tLineage {i_terminal} cannot be reached. Exclude this lineage in trend plotting')
                    # FIX: the original fell through and plotted the stale xval/yg
                    # of the previous lineage (or raised NameError on the first one)
                    continue
                x = sc_pt_arr[loc_].reshape(-1, 1)
                y = np.asarray(gene_exp[gene_i])[loc_].reshape(-1, 1)
                # weight each cell by its branch probability toward this lineage
                weights = np.asarray(sc_bp[:, i_terminal])[loc_].reshape(-1, 1)
                geneGAM = pg.LinearGAM(n_splines=n_splines, spline_order=spline_order, lam=10).fit(x, y,
                                                                                                   weights=weights)
                xval = np.linspace(min(sc_pt), max_val_pt, 100 * 2)
                yg = geneGAM.predict(X=xval)
                color_ = cmap_[enum_i_lineage] if cmap_dict is None else cmap_[i_lineage]
                ax_i.plot(xval, yg, color=color_, linewidth=linewidth, zorder=3,
                          label=f"Lineage:{majority_true} {via_object.terminal_clusters[i_terminal]}")
                ax_i.set_title(gene_i + optional_title_text, fontsize=fontsize_)
                for tick_label in (ax_i.get_xticklabels() + ax_i.get_yticklabels()):
                    tick_label.set_fontsize(fontsize_ - 1)
                if i_gene == n_genes - 1:
                    # legend and axis labels only on the last gene's subplot
                    ax_i.legend(frameon=False, fontsize=fontsize_)
                    ax_i.set_xlabel('Time', fontsize=fontsize_)
                    ax_i.set_ylabel('Intensity', fontsize=fontsize_)
                ax_i.spines['top'].set_visible(False)
                ax_i.spines['right'].set_visible(False)
                ax_i.grid(False)
            i_gene += 1
    return fig, axs
[docs]def plot_trajectory_curves(via_object, embedding: ndarray = None, idx: Optional[list] = None,
title_str: str = "Pseudotime", draw_all_curves: bool = True,
arrow_width_scale_factor: float = 15.0,
scatter_size: float = 50, scatter_alpha: float = 0.5,
linewidth: float = 1.5, marker_edgewidth: float = 1, cmap_pseudotime: str = 'viridis_r',
dpi: int = 150, highlight_terminal_states: bool = True, use_maxout_edgelist: bool = False):
'''
projects the graph based coarse trajectory onto a umap/tsne embedding
:param via_object: via object
:param embedding: 2d array [n_samples x 2] with x and y coordinates of all n_samples. Umap, tsne, pca OR use the via computed embedding via_object.embedding
:param idx: default: None. Or List. if you had previously computed a umap/tsne (embedding) only on a subset of the total n_samples (subsampled as per idx), then the via objects and results will be indexed according to idx too
:param title_str: title of figure
:param draw_all_curves: if the clustergraph has too many edges to project in a visually interpretable way, set this to False to get a simplified view of the graph pathways
:param arrow_width_scale_factor:
:param scatter_size:
:param scatter_alpha:
:param linewidth:
:param marker_edgewidth:
:param cmap_pseudotime:
:param dpi: int default = 150. Use 300 for paper figures
:param highlight_terminal_states: whether or not to highlight/distinguish the clusters which are detected as the terminal states by via
:return: f, ax1, ax2
'''
if embedding is None:
embedding = via_object.embedding
if embedding is None: print(
f'{datetime.now()}\t ERROR please provide an embedding or compute using via_mds() or via_umap()')
from mpl_toolkits.axes_grid1 import make_axes_locatable
if idx is None: idx = np.arange(0, via_object.nsamples)
cluster_labels = list(np.asarray(via_object.labels)[idx])
super_cluster_labels = list(np.asarray(via_object.labels)[idx])
super_edgelist = via_object.edgelist
if use_maxout_edgelist == True:
super_edgelist = via_object.edgelist_maxout
true_label = list(np.asarray(via_object.true_label)[idx])
knn = via_object.knn
ncomp = via_object.ncomp
final_super_terminal = via_object.terminal_clusters
sub_terminal_clusters = via_object.terminal_clusters
sc_pt_markov = list(np.asarray(via_object.single_cell_pt_markov)[idx])
super_root = via_object.root[0]
sc_supercluster_nn = sc_loc_ofsuperCluster_PCAspace(via_object, np.arange(0, len(cluster_labels)))
# draw_all_curves. True draws all the curves in the piegraph, False simplifies the number of edges
# arrow_width_scale_factor: size of the arrow head
X_dimred = embedding * 1. / np.max(embedding, axis=0)
x = X_dimred[:, 0]
y = X_dimred[:, 1]
max_x = np.percentile(x, 90)
noise0 = max_x / 1000
df = pd.DataFrame({'x': x, 'y': y, 'cluster': cluster_labels, 'super_cluster': super_cluster_labels,
'projected_sc_pt': sc_pt_markov},
columns=['x', 'y', 'cluster', 'super_cluster', 'projected_sc_pt'])
df_mean = df.groupby('cluster', as_index=False).mean()
sub_cluster_isin_supercluster = df_mean[['cluster', 'super_cluster']]
sub_cluster_isin_supercluster = sub_cluster_isin_supercluster.sort_values(by='cluster')
sub_cluster_isin_supercluster['int_supercluster'] = sub_cluster_isin_supercluster['super_cluster'].round(0).astype(
int)
df_super_mean = df.groupby('super_cluster', as_index=False).mean()
pt = df_super_mean['projected_sc_pt'].values
f, (ax1, ax2) = plt.subplots(1, 2, sharey=True, figsize=[20, 10], dpi=dpi)
num_true_group = len(set(true_label))
num_cluster = len(set(super_cluster_labels))
line = np.linspace(0, 1, num_true_group)
for color, group in zip(line, sorted(set(true_label))):
where = np.where(np.array(true_label) == group)[0]
ax1.scatter(X_dimred[where, 0], X_dimred[where, 1], label=group,
c=np.asarray(plt.cm.rainbow(color)).reshape(-1, 4),
alpha=scatter_alpha, s=scatter_size, linewidths=marker_edgewidth * .1) # 10 # 0.5 and 4
ax1.legend(fontsize=6, frameon=False)
ax1.set_title('True Labels: ncomps:' + str(ncomp) + '. knn:' + str(knn))
G_orange = ig.Graph(n=num_cluster, edges=super_edgelist)
ll_ = [] # this can be activated if you intend to simplify the curves
for fst_i in final_super_terminal:
path_orange = G_orange.get_shortest_paths(super_root, to=fst_i)[0]
len_path_orange = len(path_orange)
for enum_edge, edge_fst in enumerate(path_orange):
if enum_edge < (len_path_orange - 1):
ll_.append((edge_fst, path_orange[enum_edge + 1]))
edges_to_draw = super_edgelist if draw_all_curves else list(set(ll_))
for e_i, (start, end) in enumerate(edges_to_draw):
if pt[start] >= pt[end]:
start, end = end, start
x_i_start = df[df['super_cluster'] == start]['x'].values
y_i_start = df[df['super_cluster'] == start]['y'].values
x_i_end = df[df['super_cluster'] == end]['x'].values
y_i_end = df[df['super_cluster'] == end]['y'].values
super_start_x = X_dimred[sc_supercluster_nn[start], 0]
super_end_x = X_dimred[sc_supercluster_nn[end], 0]
super_start_y = X_dimred[sc_supercluster_nn[start], 1]
super_end_y = X_dimred[sc_supercluster_nn[end], 1]
direction_arrow = -1 if super_start_x > super_end_x else 1
minx = min(super_start_x, super_end_x)
maxx = max(super_start_x, super_end_x)
miny = min(super_start_y, super_end_y)
maxy = max(super_start_y, super_end_y)
x_val = np.concatenate([x_i_start, x_i_end])
y_val = np.concatenate([y_i_start, y_i_end])
idx_keep = np.where((x_val <= maxx) & (x_val >= minx))[0]
idy_keep = np.where((y_val <= maxy) & (y_val >= miny))[0]
idx_keep = np.intersect1d(idy_keep, idx_keep)
x_val = x_val[idx_keep]
y_val = y_val[idx_keep]
super_mid_x = (super_start_x + super_end_x) / 2
super_mid_y = (super_start_y + super_end_y) / 2
from scipy.spatial import distance
very_straight = False
straight_level = 3
noise = noise0
x_super = np.array(
[super_start_x, super_end_x, super_start_x, super_end_x, super_start_x, super_end_x, super_start_x,
super_end_x, super_start_x + noise, super_end_x + noise,
super_start_x - noise, super_end_x - noise])
y_super = np.array(
[super_start_y, super_end_y, super_start_y, super_end_y, super_start_y, super_end_y, super_start_y,
super_end_y, super_start_y + noise, super_end_y + noise,
super_start_y - noise, super_end_y - noise])
if abs(minx - maxx) <= 1:
very_straight = True
straight_level = 10
x_super = np.append(x_super, super_mid_x)
y_super = np.append(y_super, super_mid_y)
for i in range(straight_level): # DO THE SAME FOR A MIDPOINT TOO
y_super = np.concatenate([y_super, y_super])
x_super = np.concatenate([x_super, x_super])
list_selected_clus = list(zip(x_val, y_val))
if len(list_selected_clus) >= 1 & very_straight:
dist = distance.cdist([(super_mid_x, super_mid_y)], list_selected_clus, 'euclidean')
k = min(2, len(list_selected_clus))
midpoint_loc = dist[0].argsort()[:k]
midpoint_xy = []
for i in range(k):
midpoint_xy.append(list_selected_clus[midpoint_loc[i]])
noise = noise0 * 2
if k == 1:
mid_x = np.array([midpoint_xy[0][0], midpoint_xy[0][0] + noise, midpoint_xy[0][0] - noise])
mid_y = np.array([midpoint_xy[0][1], midpoint_xy[0][1] + noise, midpoint_xy[0][1] - noise])
if k == 2:
mid_x = np.array(
[midpoint_xy[0][0], midpoint_xy[0][0] + noise, midpoint_xy[0][0] - noise, midpoint_xy[1][0],
midpoint_xy[1][0] + noise, midpoint_xy[1][0] - noise])
mid_y = np.array(
[midpoint_xy[0][1], midpoint_xy[0][1] + noise, midpoint_xy[0][1] - noise, midpoint_xy[1][1],
midpoint_xy[1][1] + noise, midpoint_xy[1][1] - noise])
for i in range(3):
mid_x = np.concatenate([mid_x, mid_x])
mid_y = np.concatenate([mid_y, mid_y])
x_super = np.concatenate([x_super, mid_x])
y_super = np.concatenate([y_super, mid_y])
x_val = np.concatenate([x_val, x_super])
y_val = np.concatenate([y_val, y_super])
x_val = x_val.reshape((len(x_val), -1))
y_val = y_val.reshape((len(y_val), -1))
xp = np.linspace(minx, maxx, 500)
gam50 = pg.LinearGAM(n_splines=4, spline_order=3, lam=10).gridsearch(x_val, y_val)
XX = gam50.generate_X_grid(term=0, n=500)
preds = gam50.predict(XX)
idx_keep = np.where((xp <= (maxx)) & (xp >= (minx)))[0]
ax2.plot(XX, preds, linewidth=linewidth, c='#323538') # 3.5#1.5
mean_temp = np.mean(xp[idx_keep])
closest_val = xp[idx_keep][0]
closest_loc = idx_keep[0]
for i, xp_val in enumerate(xp[idx_keep]):
if abs(xp_val - mean_temp) < abs(closest_val - mean_temp):
closest_val = xp_val
closest_loc = idx_keep[i]
step = 1
head_width = noise * arrow_width_scale_factor # arrow_width needs to be adjusted sometimes # 40#30 ##0.2 #0.05 for mESC #0.00001 (#for 2MORGAN and others) # 0.5#1
if direction_arrow == 1:
ax2.arrow(xp[closest_loc], preds[closest_loc], xp[closest_loc + step] - xp[closest_loc],
preds[closest_loc + step] - preds[closest_loc], shape='full', lw=0, length_includes_head=False,
head_width=head_width, color='#323538')
else:
ax2.arrow(xp[closest_loc], preds[closest_loc], xp[closest_loc - step] - xp[closest_loc],
preds[closest_loc - step] - preds[closest_loc], shape='full', lw=0, length_includes_head=False,
head_width=head_width, color='#323538')
c_edge = []
width_edge = []
pen_color = []
super_cluster_label = []
terminal_count_ = 0
dot_size = []
for i in sc_supercluster_nn:
if i in final_super_terminal:
print(f'{datetime.now()}\tSuper cluster {i} is a super terminal with sub_terminal cluster',
sub_terminal_clusters[terminal_count_])
c_edge.append('yellow') # ('yellow')
if highlight_terminal_states == True:
width_edge.append(2)
super_cluster_label.append('TS' + str(sub_terminal_clusters[terminal_count_]))
else:
width_edge.append(0)
super_cluster_label.append('')
pen_color.append('black')
# super_cluster_label.append('TS' + str(i)) # +'('+str(i)+')')
# +'('+str(i)+')')
dot_size.append(60) # 60
terminal_count_ = terminal_count_ + 1
else:
width_edge.append(0)
c_edge.append('black')
pen_color.append('red')
super_cluster_label.append(str(' ')) # i or ' '
dot_size.append(00) # 20
ax2.set_title(title_str)
im2 = ax2.scatter(X_dimred[:, 0], X_dimred[:, 1], c=sc_pt_markov, cmap=cmap_pseudotime, s=0.01)
divider = make_axes_locatable(ax2)
cax = divider.append_axes('right', size='5%', pad=0.05)
f.colorbar(im2, cax=cax, orientation='vertical',
label='pseudotime') # to avoid lines drawn on the colorbar we need an image instance without alpha variable
ax2.scatter(X_dimred[:, 0], X_dimred[:, 1], c=sc_pt_markov, cmap=cmap_pseudotime, alpha=scatter_alpha,
s=scatter_size, linewidths=marker_edgewidth * .1)
count_ = 0
loci = [sc_supercluster_nn[key] for key in sc_supercluster_nn]
for i, c, w, pc, dsz, lab in zip(loci, c_edge, width_edge, pen_color, dot_size,
super_cluster_label): # sc_supercluster_nn
ax2.scatter(X_dimred[i, 0], X_dimred[i, 1], c='black', s=dsz, edgecolors=c, linewidth=w)
ax2.annotate(str(lab), xy=(X_dimred[i, 0], X_dimred[i, 1]))
count_ = count_ + 1
ax1.grid(False)
ax2.grid(False)
f.patch.set_visible(False)
ax1.axis('off')
ax2.axis('off')
return f, ax1, ax2
def plot_viagraph_(ax=None, hammer_bundle=None, layout: ndarray = None, CSM: ndarray = None,
                   velocity_weight: float = None, pt: list = None, alpha_bundle=1, linewidth_bundle=2,
                   edge_color='darkblue', headwidth_bundle=0.1, arrow_frequency=0.05, show_direction=True,
                   ax_text: bool = True, title: str = '', plot_clusters: bool = False, cmap: str = 'viridis',
                   via_object=None, fontsize: float = 9, dpi: int = 300, tune_edges: bool = False,
                   initial_bandwidth=0.05, decay=0.9, edgebundle_pruning=0.5):
    '''
    this plots the edgebundles on the via clustergraph level and also adds the relevant arrow directions based on the TI directionality

    :param ax: axis to plot on. If None a new figure+axis is created and (fig, ax) is returned
    :param hammer_bundle: hammerbundle object with coordinates of all the edges to draw. self.hammer
    :param layout: coords of cluster nodes
    :param CSM: cosine similarity matrix. cosine similarity between the RNA velocity between neighbors and the change in gene expression between these neighbors. Only used when available
    :param velocity_weight: percentage weightage given to the RNA velocity based transition matrix
    :param pt: cluster-level pseudotime (or other intensity level of features at average-cluster level)
    :param alpha_bundle: alpha when drawing lines
    :param linewidth_bundle: linewidth of bundled lines
    :param edge_color: color of the bundled edges
    :param headwidth_bundle: headwidth of arrows used in bundled edges
    :param arrow_frequency: min dist between arrows (bundled edges otherwise have overcrowding of arrows)
    :param show_direction: bool default True. will draw arrows along the lines to indicate direction.
        NOTE(review): this flag is never referenced in the body — arrows are always drawn; confirm intent
    :param ax_text: bool default True. Show labels of the clusters with the cluster population and PARC cluster label
    :param title: title set on the axis
    :param plot_clusters: bool default False. When this function is called on its own (and not from within draw_piechart_graph() then via_object must be provided
    :param cmap: colormap name used to color the cluster nodes by pt (only used when plot_clusters=True)
    :param via_object: VIA object; supplies hammer_bundle/layout/CSM/velocity_weight/pt when those are not explicitly given
    :param fontsize: float default 9 Font size of labels
    :param dpi: dpi of the figure created when ax is None
    :param tune_edges: bool default False. recompute the edge bundle using initial_bandwidth/decay/edgebundle_pruning
    :param initial_bandwidth: edge-bundling bandwidth; only used when tune_edges=True
    :param decay: edge-bundling decay; only used when tune_edges=True
    :param edgebundle_pruning: value in 0-1 controlling edge pruning; only used when tune_edges=True
    :return: (fig, ax) if ax was None, otherwise ax, with bundled edges plotted
    '''
    return_fig_ax = False  # return only the ax
    if ax == None:
        fig, ax = plt.subplots(dpi=dpi)
        ax.set_facecolor('white')
        fig.patch.set_visible(False)
        return_fig_ax = True
    if (plot_clusters == True) and (via_object is None):
        print('Warning: please provide a via object in order to plot the clusters on the graph')
    if via_object is not None:
        # fall back to the attributes stored on the via object for anything not explicitly supplied
        if hammer_bundle is None: hammer_bundle = via_object.hammerbundle_cluster
        if layout is None: layout = via_object.graph_node_pos
        if CSM is None: CSM = via_object.CSM
        if velocity_weight is None: velocity_weight = via_object.velo_weight
        if pt is None: pt = via_object.scaled_hitting_times
        if tune_edges:
            print('make new edgebundle')
            hammer_bundle, layout = make_edgebundle_viagraph(via_object=via_object, layout=via_object.layout, decay=decay,
                                                             initial_bandwidth=initial_bandwidth,
                                                             edgebundle_pruning=edgebundle_pruning)  # hold the layout fixed. only change the edges
    x_ = [l[0] for l in layout]
    y_ = [l[1] for l in layout]
    # min_x, max_x = min(x_), max(x_)
    # min_y, max_y = min(y_), max(y_)
    # extent of the layout; used below to space arrows relative to plot size
    delta_x = max(x_) - min(x_)
    delta_y = max(y_) - min(y_)
    layout = np.asarray(layout)
    # make a knn so we can find which clustergraph nodes the segments start and end at
    neigh = NearestNeighbors(n_neighbors=1)
    neigh.fit(layout)
    # get each segment. these are separated by nans.
    hbnp = hammer_bundle.to_numpy()
    splits = (np.isnan(hbnp[:, 0])).nonzero()[0]  # location of each nan values
    edgelist_segments = []
    start = 0
    segments = []
    arrow_coords = []
    for stop in splits:
        seg = hbnp[start:stop, :]
        segments.append(seg)
        start = stop  # each subsequent segment begins on the nan separator row; it is filtered out below
    n = 1  # every nth segment is plotted
    step = 1  # arrow points from the segment midpoint towards midpoint +/- step
    for seg in segments[::n]:
        do_arrow = True
        # seg_weight = max(0.3, math.log(1+seg[-1,2]))
        # third column holds the edge weight; log-compress it to scale the linewidth
        seg_weight = max(0.05, math.log(1 + seg[-1, 2]))
        # print('seg weight', seg_weight)
        seg = seg[:, 0:2].reshape(-1, 2)
        seg_p = seg[~np.isnan(seg)].reshape((-1, 2))  # drop nan separator entries
        # nearest cluster node to each end of this bundled segment
        start = neigh.kneighbors(seg_p[0, :].reshape(1, -1), return_distance=False)[0][0]
        end = neigh.kneighbors(seg_p[-1, :].reshape(1, -1), return_distance=False)[0][0]
        # print('start,end',[start, end])
        # only one arrow per (undirected) node pair
        if ([start, end] in edgelist_segments) | ([end, start] in edgelist_segments):
            do_arrow = False
        edgelist_segments.append([start, end])
        direction_ = infer_direction_piegraph(start_node=start, end_node=end, CSM=CSM, velocity_weight=velocity_weight,
                                              pt=pt)
        direction = -1 if direction_ < 0 else 1
        ax.plot(seg_p[:, 0], seg_p[:, 1], linewidth=linewidth_bundle * seg_weight, alpha=alpha_bundle, color=edge_color)
        mid_point = math.floor(seg_p.shape[0] / 2)
        if len(arrow_coords) > 0:  # dont draw arrows in overlapping segments
            for v1 in arrow_coords:
                dist_ = dist_points(v1, v2=[seg_p[mid_point, 0], seg_p[mid_point, 1]])
                # print('dist between points', dist_)
                if dist_ < arrow_frequency * delta_x: do_arrow = False
                if dist_ < arrow_frequency * delta_y: do_arrow = False
        if (do_arrow == True) & (seg_p.shape[0] > 3):
            ax.arrow(seg_p[mid_point, 0], seg_p[mid_point, 1],
                     seg_p[mid_point + (direction * step), 0] - seg_p[mid_point, 0],
                     seg_p[mid_point + (direction * step), 1] - seg_p[mid_point, 1],
                     lw=0, length_includes_head=False, head_width=headwidth_bundle, color=edge_color, shape='full',
                     alpha=0.6, zorder=5)
            arrow_coords.append([seg_p[mid_point, 0], seg_p[mid_point, 1]])
    if plot_clusters == True:
        group_pop = np.ones([layout.shape[0], 1])
        if via_object is not None:
            for group_i in set(via_object.labels):
                # n_groups = len(set(via_object.labels))
                loc_i = np.where(via_object.labels == group_i)[0]
                group_pop[group_i] = len(loc_i)
        gp_scaling = 1000 / max(group_pop)  # 500 / max(group_pop)
        group_pop_scale = group_pop * gp_scaling * 0.5
        c_edge, l_width = [], []
        if via_object is not None:
            terminal_clusters_placeholder = via_object.terminal_clusters
        else:
            terminal_clusters_placeholder = []
        # red outline marks terminal clusters; others get no visible edge
        for ei, pti in enumerate(pt):
            if ei in terminal_clusters_placeholder:
                c_edge.append('red')
                l_width.append(1.5)
            else:
                c_edge.append('gray')
                l_width.append(0.0)
        ax.scatter(layout[:, 0], layout[:, 1], s=group_pop_scale, c=pt, cmap=cmap,
                   edgecolors=c_edge,
                   alpha=1, zorder=3, linewidth=l_width)
        if ax_text:
            # offset labels by ~1% of the layout extent so they don't sit on the nodes
            x_max_range = np.amax(layout[:, 0]) / 100
            y_max_range = np.amax(layout[:, 1]) / 100
            for ii in range(layout.shape[0]):
                ax.text(layout[ii, 0] + max(x_max_range, y_max_range),
                        layout[ii, 1] + min(x_max_range, y_max_range),
                        'C' + str(ii) + 'pop' + str(int(group_pop[ii][0])),
                        color='black', zorder=4, fontsize=fontsize)
    ax.set_title(title)
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis('off')
    ax.set_facecolor('white')
    if return_fig_ax == True:
        return fig, ax
    else:
        return ax
def _slow_sklearn_mds(via_graph: csr_matrix, X_pca: ndarray, t_diff_op: int = 1):
    '''
    Diffuse the PCA coordinates along the single-cell via graph and embed the resulting
    pairwise distances in 2D with metric MDS.

    :param via_graph: via_object.csr_full_graph — single cell knn graph representation based on hnsw,
        with affinity edge weights
    :param X_pca: ndarray adata_counts.obsm['X_pca'][:, 0:ncomps]
    :param t_diff_op: number of diffusion steps (matrix power of the row-stochastic transition matrix)
    :return: ndarray (n_cells, 2) MDS embedding
    '''
    from sklearn.preprocessing import normalize
    from sklearn import manifold
    # clip outlier affinities so a few extreme edges do not dominate the diffusion
    lower, upper = np.percentile(via_graph.data, 10), np.percentile(via_graph.data, 90)
    via_graph.data = np.clip(via_graph.data, lower, upper)
    # row-normalise affinities into a transition matrix, then apply the requested diffusion level
    transition = normalize(via_graph, norm='l1', axis=1) ** t_diff_op
    # diffuse the PCA coordinates (sparse matrix multiplication) and take all pairwise distances
    diffused = transition * csr_matrix(X_pca)
    dissimilarity = squareform(pdist(diffused.todense()))
    mds_model = manifold.MDS(
        n_components=2,
        max_iter=3000,
        eps=1e-9,
        random_state=0,
        dissimilarity="precomputed",
        n_jobs=2,
    )
    # X_mds = squareform(pdist(adata_counts.obsm['X_pca'][:, 0:ncomps+20])) #no diffusion makes is less streamlined and compact. more fuzzy
    return mds_model.fit(dissimilarity).embedding_
def plot_piechart_only_viagraph(via_object, type_data='pt', gene_exp: list = [], cmap_piechart: str = 'rainbow', title='',
                                cmap: str = None, ax_text=True, dpi=150, headwidth_arrow=0.1, alpha_edge=0.4,
                                linewidth_edge=2, edge_color='darkblue', reference_labels=None, show_legend: bool = True,
                                pie_size_scale: float = 0.8, fontsize: float = 8, pt_visual_threshold: int = 99,
                                highlight_terminal_clusters: bool = True, size_node_notpiechart: float = 1,
                                tune_edges: bool = False, initial_bandwidth=0.05, decay=0.9, edgebundle_pruning=0.5):
    '''
    plot a single-panel clustergraph representation of the viagraph where each cluster node is drawn as a
    piechart showing its reference-label composition, over bundled, direction-annotated edges

    :param via_object: is class VIA (the same function also exists as a method of the class and an external plotting function
    :param type_data: string default 'pt' for pseudotime colored nodes. or 'gene'
    :param gene_exp: list of values (or column of dataframe) corresponding to feature or gene expression to be used to color nodes at CLUSTER level
    :param cmap_piechart: str cmap for piechart categories
    :param title: string
    :param cmap: default None. automatically chooses coolwarm for gene expression or viridis_r for pseudotime.
        NOTE(review): in this single-panel variant cmap is assigned but not used afterwards — confirm intent
    :param ax_text: Bool default= True. Annotates each node with cluster number and majority reference label
    :param dpi: int default = 150
    :param headwidth_arrow: default = 0.1. width of arrowhead used to directed edges
    :param reference_labels: None or list. list of categorical (str) labels for cluster composition of the piecharts, length = n_samples.
    :param show_legend: bool default True. show a legend of the reference labels
    :param pie_size_scale: float default=0.8 scaling factor of the piechart nodes
    :param fontsize: font size of node labels
    :param pt_visual_threshold: int (percentage) default = 99 corresponding to rescaling the visual color scale by clipping outlier cluster pseudotimes
    :param highlight_terminal_clusters: bool = True.
        NOTE(review): not referenced in this variant's body — confirm intent
    :param size_node_notpiechart: scaling factor for node size of the viagraph (not the piechart part).
        NOTE(review): not referenced in this variant's body — confirm intent
    :param tune_edges: bool default False. recompute the edge bundle with the parameters below
    :param initial_bandwidth: (float = 0.05) increasing bw increases merging of minor edges. Only used when tune_edges = True
    :param decay: (decay = 0.9) increasing decay increases merging of minor edges . Only used when tune_edges = True
    :param edgebundle_pruning: (float = 0.5). takes on values between 0-1. smaller value means more pruning away edges that can be visualised. Only used when tune_edges = True
    :return: f, ax
    '''
    f, ax = plt.subplots( dpi=dpi)
    node_pos = via_object.graph_node_pos
    node_pos = np.asarray(node_pos)
    if cmap is None: cmap = 'coolwarm' if type_data == 'gene' else 'viridis_r'
    if type_data == 'pt':
        pt = via_object.markov_hitting_times  # via_object.scaled_hitting_times
        # clip outlier pseudotimes so the visual scale is not dominated by a few extreme clusters
        threshold_high = np.percentile(pt, pt_visual_threshold)
        pt_subset = [x for x in pt if x < threshold_high]  # remove high outliers
        new_upper_pt = np.percentile(pt_subset, pt_visual_threshold)  # 'true' upper percentile after removing outliers
        pt = [x if x < new_upper_pt else new_upper_pt for x in pt]
        title_ax1 = "Pseudotime " + title  # NOTE(review): title_ax1 is unused in this single-panel variant
    if (type_data == 'gene') | (len(gene_exp) > 0):
        pt = gene_exp
        title_ax1 = title
    if reference_labels is None: reference_labels = via_object.true_label
    n_groups = len(set(via_object.labels))
    n_truegroups = len(set(reference_labels))
    group_pop = np.zeros([n_groups, 1])
    # group_frac: rows = via clusters, columns = sorted reference labels; holds per-cluster label counts
    if type(reference_labels[0]) == int or type(reference_labels[0]) == float:
        sorted_col_ = sorted(list(set(reference_labels)))
        group_frac = pd.DataFrame(np.zeros([n_groups, n_truegroups]), columns=sorted_col_)
    else:
        sorted_col_ = list(set(reference_labels))
        sorted_col_.sort()
        group_frac = pd.DataFrame(np.zeros([n_groups, n_truegroups]),
                                  columns=sorted_col_)  # list(set(reference_labels))
    via_object.cluster_population_dict = {}
    set_labels = list(set(via_object.labels))
    set_labels.sort()
    for group_i in set_labels:
        loc_i = np.where(via_object.labels == group_i)[0]
        group_pop[group_i] = len(loc_i)  # np.sum(loc_i) / 1000 + 1
        via_object.cluster_population_dict[group_i] = len(loc_i)
        true_label_in_group_i = list(np.asarray(reference_labels)[loc_i])
        ll_temp = list(set(true_label_in_group_i))
        for ii in ll_temp:
            group_frac[ii][group_i] = true_label_in_group_i.count(ii)
    # one color per reference label, sampled evenly from the piechart colormap
    line_true = np.linspace(0, 1, n_truegroups)
    cmap_piechart_ = plt.get_cmap(cmap_piechart)
    color_true_list = [cmap_piechart_(color) for color in line_true]  # plt.cm.rainbow(color)
    # white scatter sets up the axis data limits before the graph edges are drawn
    sct = ax.scatter(node_pos[:, 0], node_pos[:, 1],
                     c='white', edgecolors='face', s=group_pop, cmap=cmap_piechart)
    bboxes = getbb(sct, ax)
    # draw the bundled, direction-annotated clustergraph edges
    ax = plot_viagraph_(ax, via_object= via_object, pt=pt, headwidth_bundle=headwidth_arrow,
                        alpha_bundle=alpha_edge, linewidth_bundle=linewidth_edge, edge_color=edge_color,
                        tune_edges=tune_edges, initial_bandwidth=initial_bandwidth, decay=decay,
                        edgebundle_pruning=edgebundle_pruning)
    # transforms used to place each piechart (its own small axes) over the node position
    trans = ax.transData.transform
    bbox = ax.get_position().get_points()
    ax_x_min = bbox[0, 0]
    ax_x_max = bbox[1, 0]
    ax_y_min = bbox[0, 1]
    ax_y_max = bbox[1, 1]
    ax_len_x = ax_x_max - ax_x_min
    ax_len_y = ax_y_max - ax_y_min
    trans2 = ax.transAxes.inverted().transform
    pie_axs = []
    # piechart radius scales with min-max normalised cluster population
    pie_size_ar = ((group_pop - np.min(group_pop)) / (np.max(group_pop) - np.min(group_pop)) + 0.5) / 10  # 10
    for node_i in range(n_groups):
        cluster_i_loc = np.where(np.asarray(via_object.labels) == node_i)[0]
        majority_true = via_object.func_mode(list(np.asarray(reference_labels)[cluster_i_loc]))
        pie_size = pie_size_ar[node_i][0] * pie_size_scale
        x1, y1 = trans(node_pos[node_i])  # data coordinates
        xa, ya = trans2((x1, y1))  # axis coordinates
        xa = ax_x_min + (xa - pie_size / 2) * ax_len_x
        ya = ax_y_min + (ya - pie_size / 2) * ax_len_y
        # clip, the fruchterman layout sometimes places below figure
        if ya < 0: ya = 0
        if xa < 0: xa = 0
        rect = [xa, ya, pie_size * ax_len_x, pie_size * ax_len_y]
        frac = np.asarray([ff for ff in group_frac.iloc[node_i].values])
        pie_axs.append(plt.axes(rect, frameon=False))
        pie_axs[node_i].pie(frac, wedgeprops={'linewidth': 0.0}, colors=color_true_list)
        pie_axs[node_i].set_xticks([])
        pie_axs[node_i].set_yticks([])
        pie_axs[node_i].set_aspect('equal')
        # pie_axs[node_i].text(0.5, 0.5, graph_node_label[node_i])
        if ax_text == True:
            pie_axs[node_i].text(0.5, 0.5, str(majority_true)+'_c'+str(node_i), fontsize=fontsize)
        #pie_axs[node_i].text(0.5, 0.5, 'c' + str(node_i), fontsize=fontsize)
        # second pie call returns the wedge patches used for the shared legend below
        patches, texts = pie_axs[node_i].pie(frac, wedgeprops={'linewidth': 0.0}, colors=color_true_list)
    labels = list(set(reference_labels))
    labels.sort()
    if show_legend == True: plt.legend(patches, labels, loc=(-5, -5), fontsize=6, frameon=False)
    # figure title records the key graph-construction parameters
    if via_object.time_series == True:
        ti = 'Cluster Composition. K=' + str(via_object.knn) + '. ncomp = ' + str(via_object.ncomp) + 'knnseq_' + str(
            via_object.knn_sequential)
    elif via_object.do_spatial_knn == True:
        ti = 'Cluster Composition. K=' + str(via_object.knn) + '. ncomp = ' + str(via_object.ncomp) + 'SpatKnn_' + str(
            via_object.spatial_knn)
    else:
        ti = 'Cluster Composition. K=' + str(via_object.knn) + '. ncomp = ' + str(via_object.ncomp)
    ax.set_title(ti)
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    f.patch.set_visible(False)
    ax.axis('off')
    ax.set_facecolor('white')
    return f, ax
def plot_piechart_viagraph(via_object, type_data='pt', gene_exp: list = [], cmap_piechart: str = 'rainbow', title='',
                           cmap: str = None, ax_text=True, dpi=150, headwidth_arrow=0.1, alpha_edge=0.4,
                           linewidth_edge=2, edge_color='darkblue', reference_labels=None, show_legend: bool = True,
                           pie_size_scale: float = 0.8, fontsize: float = 8, pt_visual_threshold: int = 99,
                           highlight_terminal_clusters: bool = True, size_node_notpiechart: float = 1,
                           tune_edges: bool = False, initial_bandwidth=0.05, decay=0.9, edgebundle_pruning=0.5):
    '''
    plot two subplots with a clustergraph level representation of the viagraph showing true-label composition (lhs) and pseudotime/gene expression (rhs)

    Returns matplotlib figure with two axes that plot the clustergraph using edge bundling.
    left axis shows the clustergraph with each node colored by annotated ground truth membership.
    right axis shows the same clustergraph with each node colored by the pseudotime or gene expression

    :param via_object: is class VIA (the same function also exists as a method of the class and an external plotting function
    :param type_data: string default 'pt' for pseudotime colored nodes. or 'gene'
    :param gene_exp: list of values (or column of dataframe) corresponding to feature or gene expression to be used to color nodes at CLUSTER level
    :param cmap_piechart: str cmap for piechart categories
    :param title: string
    :param cmap: default None. automatically chooses coolwarm for gene expression or viridis_r for pseudotime
    :param ax_text: Bool default= True. Annotates each node with cluster number and population of membership
    :param dpi: int default = 150
    :param headwidth_arrow: default = 0.1. width of arrowhead used to directed edges
    :param reference_labels: None or list. list of categorical (str) labels for cluster composition of the piecharts (LHS subplot) length = n_samples.
    :param show_legend: bool default True. show a legend of the reference labels
    :param pie_size_scale: float default=0.8 scaling factor of the piechart nodes
    :param fontsize: font size of node labels
    :param pt_visual_threshold: int (percentage) default = 99 corresponding to rescaling the visual color scale by clipping outlier cluster pseudotimes
    :param highlight_terminal_clusters: bool = True (red border around terminal clusters)
    :param size_node_notpiechart: scaling factor for node size of the viagraph (not the piechart part)
    :param tune_edges: bool default False. recompute the edge bundle with the parameters below
    :param initial_bandwidth: (float = 0.05) increasing bw increases merging of minor edges. Only used when tune_edges = True
    :param decay: (decay = 0.9) increasing decay increases merging of minor edges . Only used when tune_edges = True
    :param edgebundle_pruning: (float = 0.5). takes on values between 0-1. smaller value means more pruning away edges that can be visualised. Only used when tune_edges = True
    :return: f, ax, ax1
    '''
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    f, ((ax, ax1)) = plt.subplots(1, 2, sharey=True, dpi=dpi)
    node_pos = np.asarray(via_object.graph_node_pos)
    if cmap is None: cmap = 'coolwarm' if type_data == 'gene' else 'viridis_r'
    # bug fix: initialise title_ax1 so it is always defined even when type_data is neither
    # 'pt' nor 'gene' and gene_exp is empty (previously a NameError at title_list below)
    title_ax1 = title
    if type_data == 'pt':
        pt = via_object.markov_hitting_times  # via_object.scaled_hitting_times
        # clip outlier pseudotimes for the visual color scale only
        threshold_high = np.percentile(pt, pt_visual_threshold)
        pt_subset = [x for x in pt if x < threshold_high]  # remove high outliers
        new_upper_pt = np.percentile(pt_subset, pt_visual_threshold)  # 'true' upper percentile after removing outliers
        pt = [x if x < new_upper_pt else new_upper_pt for x in pt]
        title_ax1 = "Pseudotime " + title
    if (type_data == 'gene') | (len(gene_exp) > 0):
        pt = gene_exp
        title_ax1 = title
    if reference_labels is None: reference_labels = via_object.true_label
    n_groups = len(set(via_object.labels))
    n_truegroups = len(set(reference_labels))
    group_pop = np.zeros([n_groups, 1])
    # group_frac: rows = via clusters, columns = sorted reference labels; per-cluster label counts
    if type(reference_labels[0]) == int or type(reference_labels[0]) == float:
        sorted_col_ = sorted(list(set(reference_labels)))
    else:
        sorted_col_ = list(set(reference_labels))
        sorted_col_.sort()
    group_frac = pd.DataFrame(np.zeros([n_groups, n_truegroups]), columns=sorted_col_)
    via_object.cluster_population_dict = {}
    set_labels = list(set(via_object.labels))
    set_labels.sort()
    for group_i in set_labels:
        loc_i = np.where(via_object.labels == group_i)[0]
        group_pop[group_i] = len(loc_i)
        via_object.cluster_population_dict[group_i] = len(loc_i)
        true_label_in_group_i = list(np.asarray(reference_labels)[loc_i])
        for ii in set(true_label_in_group_i):
            # .at instead of chained group_frac[ii][group_i]: chained assignment does not
            # write back under pandas copy-on-write
            group_frac.at[group_i, ii] = true_label_in_group_i.count(ii)
    # one color per reference label, sampled evenly from the piechart colormap
    line_true = np.linspace(0, 1, n_truegroups)
    cmap_piechart_ = plt.get_cmap(cmap_piechart)
    color_true_list = [cmap_piechart_(color) for color in line_true]
    # white scatter sets up the axis data limits before the graph edges are drawn
    sct = ax.scatter(node_pos[:, 0], node_pos[:, 1],
                     c='white', edgecolors='face', s=group_pop, cmap=cmap_piechart)
    bboxes = getbb(sct, ax)
    print('tune edges', tune_edges)
    # draw the bundled, direction-annotated clustergraph edges on the lhs axis
    ax = plot_viagraph_(ax, via_object=via_object, pt=pt, headwidth_bundle=headwidth_arrow,
                        alpha_bundle=alpha_edge, linewidth_bundle=linewidth_edge, edge_color=edge_color,
                        tune_edges=tune_edges, initial_bandwidth=initial_bandwidth, decay=decay,
                        edgebundle_pruning=edgebundle_pruning)
    # transforms used to place each piechart (its own small axes) over the node position
    trans = ax.transData.transform
    bbox = ax.get_position().get_points()
    ax_x_min = bbox[0, 0]
    ax_x_max = bbox[1, 0]
    ax_y_min = bbox[0, 1]
    ax_y_max = bbox[1, 1]
    ax_len_x = ax_x_max - ax_x_min
    ax_len_y = ax_y_max - ax_y_min
    trans2 = ax.transAxes.inverted().transform
    pie_axs = []
    # piechart radius scales with min-max normalised cluster population
    pie_size_ar = ((group_pop - np.min(group_pop)) / (np.max(group_pop) - np.min(group_pop)) + 0.5) / 10
    for node_i in range(n_groups):
        cluster_i_loc = np.where(np.asarray(via_object.labels) == node_i)[0]
        majority_true = via_object.func_mode(list(np.asarray(reference_labels)[cluster_i_loc]))
        pie_size = pie_size_ar[node_i][0] * pie_size_scale
        x1, y1 = trans(node_pos[node_i])  # data coordinates
        xa, ya = trans2((x1, y1))  # axis coordinates
        xa = ax_x_min + (xa - pie_size / 2) * ax_len_x
        ya = ax_y_min + (ya - pie_size / 2) * ax_len_y
        # clip, the fruchterman layout sometimes places below figure
        if ya < 0: ya = 0
        if xa < 0: xa = 0
        rect = [xa, ya, pie_size * ax_len_x, pie_size * ax_len_y]
        frac = np.asarray([ff for ff in group_frac.iloc[node_i].values])
        pie_axs.append(plt.axes(rect, frameon=False))
        pie_axs[node_i].pie(frac, wedgeprops={'linewidth': 0.0}, colors=color_true_list)
        pie_axs[node_i].set_xticks([])
        pie_axs[node_i].set_yticks([])
        pie_axs[node_i].set_aspect('equal')
        if ax_text == True:
            pie_axs[node_i].text(0.5, 0.5, str(majority_true) + '_c' + str(node_i), fontsize=fontsize)
        # second pie call returns the wedge patches used for the shared legend below
        patches, texts = pie_axs[node_i].pie(frac, wedgeprops={'linewidth': 0.0}, colors=color_true_list)
    labels = list(set(reference_labels))
    labels.sort()
    if show_legend == True: plt.legend(patches, labels, loc=(-5, -5), fontsize=6, frameon=False)
    # lhs title records the key graph-construction parameters
    if via_object.time_series == True:
        ti = 'Cluster Composition. K=' + str(via_object.knn) + '. ncomp = ' + str(via_object.ncomp) + 'knnseq_' + str(
            via_object.knn_sequential)
    elif via_object.do_spatial_knn == True:
        ti = 'Cluster Composition. K=' + str(via_object.knn) + '. ncomp = ' + str(via_object.ncomp) + 'SpatKnn_' + str(
            via_object.spatial_knn)
    else:
        ti = 'Cluster Composition. K=' + str(via_object.knn) + '. ncomp = ' + str(via_object.ncomp)
    ax.set_title(ti)
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    # rhs axis: same clustergraph, nodes colored by pseudotime/gene expression
    title_list = [title_ax1]
    for i, ax_i in enumerate([ax1]):
        c_edge, l_width = [], []
        # red outline marks terminal clusters (suppressed when highlight_terminal_clusters is False)
        for ei, pti in enumerate(pt):
            if ei in via_object.terminal_clusters:
                c_edge.append('red')
                if not highlight_terminal_clusters:
                    l_width.append(0)
                else:
                    l_width.append(1.5)
            else:
                c_edge.append('gray')
                l_width.append(0.0)
        gp_scaling = 1000 / max(group_pop)
        group_pop_scale = group_pop * gp_scaling * 0.5
        # bug fix: bind the returned axis to ax_i (was bound to `ax`, clobbering the lhs axis
        # handle so the final ax.axis('off')/ax.set_facecolor never applied to the left panel)
        ax_i = plot_viagraph_(ax_i, via_object=via_object, pt=pt, headwidth_bundle=headwidth_arrow,
                              alpha_bundle=alpha_edge, linewidth_bundle=linewidth_edge, edge_color=edge_color,
                              tune_edges=tune_edges, initial_bandwidth=initial_bandwidth, decay=decay,
                              edgebundle_pruning=edgebundle_pruning)
        im1 = ax_i.scatter(node_pos[:, 0], node_pos[:, 1], s=group_pop_scale * size_node_notpiechart, c=pt, cmap=cmap,
                           edgecolors=c_edge,
                           alpha=1, zorder=3, linewidth=l_width)
        if ax_text:
            # offset labels by ~1% of the layout extent so they don't sit on the nodes
            x_max_range = np.amax(node_pos[:, 0]) / 100
            y_max_range = np.amax(node_pos[:, 1]) / 100
            for ii in range(node_pos.shape[0]):
                ax_i.text(node_pos[ii, 0] + max(x_max_range, y_max_range),
                          node_pos[ii, 1] + min(x_max_range, y_max_range),
                          'C' + str(ii) + 'pop' + str(int(group_pop[ii][0])),
                          color='black', zorder=4, fontsize=fontsize)
        ax_i.set_title(title_list[i])
        ax_i.grid(False)
        ax_i.set_xticks([])
        ax_i.set_yticks([])
    # colorbar attached to the rhs axis; an image instance without alpha avoids lines on the colorbar
    divider = make_axes_locatable(ax1)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    if type_data == 'pt':
        f.colorbar(im1, cax=cax, orientation='vertical', label='pseudotime')
    else:
        f.colorbar(im1, cax=cax, orientation='vertical', label='Gene expression')
    f.patch.set_visible(False)
    ax1.axis('off')
    ax.axis('off')
    ax.set_facecolor('white')
    ax1.set_facecolor('white')
    return f, ax, ax1
[docs]def plot_clusters_spatial(spatial_coords, clusters=[], via_labels= [], title_sup='', fontsize_=6,color='green', s:int=5, alpha=0.5, xlim_max=None, ylim_max=None,xlim_min=None, ylim_min=None, reference_labels:list=[], reference_labels2:list = [],equal_axes_lim: bool = True):
'''
:param spatial_coords: ndarray of spatial coords ncellsx2 dims
:param clusters: the clusters in via_object.labels which you want to plot (usually a subset of the total number of clusters)
:param via_labels: via_object.labels (cluster level labels, list of n_cells length)
:param title_sup: title of the overall figure
:param fontsize_: fontsize for legend
:param color: color of scatter points
:param s: size of scatter points
:param alpha: float alpha transparency of scatter (0 fully transporent, 1 is opaque)
:param xlim_max: limits of axes
:param ylim_max: limits of axes
:param xlim_min: limits of axes
:param ylim_min: limits of axes
:param reference_labels: optional list of single-cell labels (e.g. time, annotation). this will be used in the title of each subplot to note the majority cell (ref2) type for each cluster
:param reference_labels2: optional list of single-cell labels (e.g. time, annotation). this will be used in the title of each subplot to note the majority cell (ref2) type for each cluster
:return: fig, axs
'''
if xlim_max is None: xlim_max = np.max(spatial_coords[:, 0])
if ylim_max is None: ylim_max = np.max(spatial_coords[:, 1])
if xlim_min is None: xlim_min = np.min(spatial_coords[:, 0])
if ylim_min is None: ylim_min = np.min(spatial_coords[:, 1])
n_clusters = len(clusters)
col_init = min(4,n_clusters)
fig_nrows, mod = divmod(n_clusters, col_init)
if mod == 0: fig_nrows = fig_nrows
if mod != 0: fig_nrows += 1
fig_ncols = col_init
fig, axs = plt.subplots(fig_nrows, fig_ncols)
fig.patch.set_visible(False)
i_gene = 0 # counter for number of genes
i_terminal = 0 # counter for terminal cluster
# for i in range(n_terminal_states): #[0]
for r in range(fig_nrows):
for c in range(fig_ncols):
if (i_gene < n_clusters):
cluster_i = clusters[i_gene]
df = pd.DataFrame(spatial_coords, columns=['x', 'y'])
df['v0'] = via_labels
if len(reference_labels)>0:
df['reference'] = reference_labels
if len(reference_labels2) > 0:
df['reference2'] = reference_labels2
df = df[df.v0 == cluster_i]
df_coords_majref = pd.DataFrame(spatial_coords, columns=['x', 'y'])
majority_reference = ''
majority_reference2 = ''
if len(reference_labels)>0:
reference_labels_sub = list(df['reference'])
majority_reference = func_mode(reference_labels_sub)
if len(reference_labels2) > 0:
reference2_labels_sub = list(df['reference2'])
majority_reference2 = func_mode(reference2_labels_sub)
df_coords_majref['reference'] = reference_labels
df_coords_majref = df_coords_majref[df_coords_majref.reference == majority_reference]
df_coords_majref = df_coords_majref[['x', 'y']].values
else: df_coords_majref = df_coords_majref.values
emb = df[['x', 'y']].values
if not equal_axes_lim:
xlim_max = np.max(emb[:, 0]) *1.2
ylim_max = np.max(emb[:, 1]) *1.2
xlim_min = np.min(emb[:, 0]) *1.2
ylim_min = np.min(emb[:, 1]) *1.2
#color = cmap_[i_gene]
if fig_nrows > 1:
axs[r, c].scatter(df_coords_majref[:, 0], df_coords_majref[:, 1], c='gray', s=1, alpha=0.2)
axs[r, c].scatter(emb[:, 0], emb[:, 1], c=color, s=s, alpha=alpha )
if len(reference_labels)>0: axs[r,c].set_title('c:' + str(cluster_i)+'_'+str(majority_reference)+'_'+str(majority_reference2))
else: axs[r,c].set_title('c:' + str(cluster_i))#
axs[r,c].set_xlim([xlim_min, xlim_max]) #axis limits
axs[r, c].set_ylim([ylim_min, ylim_max]) # axis limits
# Set tick font size
for label in (axs[r, c].get_xticklabels() + axs[r, c].get_yticklabels()):
label.set_fontsize(fontsize_ - 1)
if i_gene == n_clusters - 1:
axs[r, c].legend(frameon=False, fontsize=fontsize_)
axs[r, c].spines['top'].set_visible(False)
axs[r, c].spines['right'].set_visible(False)
axs[r, c].grid(False)
elif fig_ncols ==1:
axs.scatter(df_coords_majref[:, 0], df_coords_majref[:, 1], c='gray', s=s, alpha=alpha*0.5)
axs.scatter(emb[:, 0], emb[:, 1], c=color, s=s,alpha=alpha )
if len(reference_labels)>0: axs.set_title('c:' + str(cluster_i)+'_'+str(majority_reference)+'_'+str(majority_reference2))
else: axs.set_title('c:' + str(cluster_i))#
axs.set_xlim([xlim_min, xlim_max]) # axis limits
axs.set_ylim([ylim_min, ylim_max]) # axis limits
# Set tick font size
for label in (axs.get_xticklabels() + axs.get_yticklabels()):
label.set_fontsize(fontsize_ - 1)
if i_gene == n_clusters - 1:
axs.legend(frameon=False, fontsize=fontsize_)
axs.spines['top'].set_visible(False)
axs.spines['right'].set_visible(False)
axs.grid(False)
else:
print('df_coords_majref shape', df_coords_majref.shape)
axs[c].scatter(df_coords_majref[:, 0], df_coords_majref[:, 1], c='gray', s=s, alpha=0.1)
axs[c].scatter(emb[:, 0], emb[:, 1], c=color, s=s,alpha=alpha )
if len(reference_labels)>0: axs[c].set_title('c:' + str(cluster_i)+'_'+str(majority_reference)+'_'+str(majority_reference2))
else: axs[c].set_title('c:' + str(cluster_i))#
axs[ c].set_xlim([xlim_min, xlim_max]) # axis limits
axs[ c].set_ylim([ylim_min, ylim_max]) # axis limits
# Set tick font size
for label in (axs[c].get_xticklabels() + axs[c].get_yticklabels()):
label.set_fontsize(fontsize_ - 1)
if i_gene == n_clusters - 1:
axs[c].legend(frameon=False, fontsize=fontsize_)
axs[c].spines['top'].set_visible(False)
axs[c].spines['right'].set_visible(False)
axs[c].grid(False)
i_gene += 1
else:
if fig_nrows > 1:
axs[r, c].axis('off')
axs[r, c].grid(False)
else:
axs[c].axis('off')
axs[c].grid(False)
fig.suptitle(title_sup, fontsize=8)
return fig, axs
def make_dict_of_clusters_for_each_celltype(via_labels: list = None, true_label: list = None, verbose: bool = False):
    '''
    Group VIA clusters by the majority cell type of their member cells.

    :param via_labels: usually set to via_object.labels. list of length n_cells of cluster membership
    :param true_label: cell type labels (list of length n_cells)
    :param verbose: if True, print the intermediate cluster-to-majority mapping and the final dict
    :return: dict mapping each majority cell type to the list of via clusters whose dominant population is that cell type
    '''
    # None-sentinels instead of mutable [] defaults (a shared list default persists across calls)
    if via_labels is None: via_labels = []
    if true_label is None: true_label = []
    df_mode = pd.DataFrame()
    df_mode['cluster'] = via_labels
    df_mode['class_str'] = true_label
    # for each cluster take the modal (most frequent) cell-type label; ties resolved by pd.Series.mode order
    majority_cluster_population_dict = df_mode.groupby(['cluster'])['class_str'].agg(
        lambda x: pd.Series.mode(x)[0])
    majority_cluster_population_dict = majority_cluster_population_dict.to_dict()
    if verbose:
        print(f'dict cluster to majority pop: {majority_cluster_population_dict}')
    # invert the mapping: majority cell type -> list of clusters with that majority
    class_to_cluster_dict = collect_dictionary(majority_cluster_population_dict)
    if verbose:
        print('list of clusters for each majority', class_to_cluster_dict)
    return class_to_cluster_dict
def plot_all_spatial_clusters(spatial_coords, true_label, via_labels, save_to: str = '', color_dict: dict = None,
                              cmap: str = 'rainbow', alpha=0.4, s=5, verbose: bool = False,
                              reference_labels: list = None, reference_labels2: list = None):
    '''
    For each cell type that is the majority population of at least one VIA cluster, plot all clusters
    with that majority on spatial coordinates (one figure per cell type, saved as a png).

    :param spatial_coords: ndarray of x,y coords of tissue location of cells (ncells x2)
    :param true_label: categorial labels (list of length n_cells)
    :param via_labels: cluster membership labels (list of length n_cells)
    :param save_to: path prefix for the saved figures ('cluster_<celltype>.png' is appended)
    :param color_dict: optional dict with keys corresponding to true_label type. e.g. {true_label_celltype1: 'green',true_label_celltype2: 'red'}
    :param cmap: string default = rainbow
    :param alpha: transparency of cluster scatter points
    :param s: size of cluster scatter points
    :param verbose: if True, print the cluster-per-celltype dictionary
    :param reference_labels: optional list of single-cell labels (e.g. time, annotation). Used to selectively provide a grey background to cells not in the cluster being inspected. If you have multipe time points, then set reference_labels to the time_points. All cells in the most prevalent timepoint seen in the cluster of interest will be plotted as a background
    :param reference_labels2: optional list of single-cell labels (e.g. time, annotation). this will be used in the title of each subplot to note the majority cell (ref2) type for each cluster
    :return: list of lists of [[fig1, axs_set1], [fig2, axs_set2],...]
    '''
    # None-sentinels for mutable defaults: the original color_dict={} default was MUTATED below,
    # so auto-generated colors leaked from one call into the next
    if color_dict is None: color_dict = {}
    if reference_labels is None: reference_labels = []
    if reference_labels2 is None: reference_labels2 = []
    clusters_for_each_celltype_dict = make_dict_of_clusters_for_each_celltype(via_labels, true_label)
    if verbose: print(clusters_for_each_celltype_dict)
    keys = list(sorted(list(clusters_for_each_celltype_dict.keys())))
    print('keys', keys)
    potential_majority_celltype_keys = list(sorted(list(set(true_label))))
    list_of_figs = []
    if len(color_dict) == 0:
        # no user-supplied colors: assign one colormap entry per cell type, in sorted order
        set_labels = list(set(true_label))
        set_labels.sort(reverse=False)
        palette = cm.get_cmap(cmap, len(set_labels))
        cmap_ = palette(range(len(set_labels)))
        for index, value in enumerate(set_labels):
            color_dict[value] = cmap_[index]
    for i, keyi in enumerate(potential_majority_celltype_keys):
        if keyi in keys:
            color = color_dict[keyi]
            clusters_list = clusters_for_each_celltype_dict[keyi]
            f, axs = plot_clusters_spatial(spatial_coords, clusters=clusters_list, via_labels=via_labels,
                                           title_sup=keyi, color=color, s=s,
                                           alpha=alpha, reference_labels2=reference_labels2,
                                           reference_labels=reference_labels)
            list_of_figs.append([f, axs])
            # figure height scales with the number of 4-column rows needed for the clusters
            fig_nrows, mod = divmod(len(clusters_list), 4)
            if mod != 0: fig_nrows += 1
            f.set_size_inches(10, 2 * fig_nrows)
            f.savefig(save_to + 'cluster_' + keyi[0:4] + '.png')
        else:
            print('No cluster has a majority population of ', keyi)
    return list_of_figs
def animate_atlas_old(hammerbundle_dict=None, via_object=None, linewidth_bundle=2, frame_interval: int = 10,
                      n_milestones: int = None, facecolor: str = 'white', cmap: str = 'plasma_r', extra_title_text='',
                      size_scatter: int = 1, alpha_scatter: float = 0.2,
                      saveto='/home/user/Trajectory/Datasets/animation_default.gif', time_series_labels: list = None,
                      lineage_pathway=[],
                      sc_labels_numeric: list = None, t_diff_factor: float = 0.25, show_sc_embedding: bool = False,
                      sc_emb=None, sc_size_scatter: float = 10, sc_alpha_scatter: float = 0.2, n_intervals: int = 50):
    '''
    Animate the VIA milestone edge-bundle so that edges and milestones appear progressively in order of
    their numeric label (time or pseudotime), saving the result as a gif.

    :param hammerbundle_dict: dict with keys 'hammerbundle', 'milestone_embedding' and 'edges'. If None it is
        taken from (or computed on) via_object
    :param via_object: via object; required if hammerbundle_dict is not provided
    :param linewidth_bundle: linewidth of bundled edges
    :param frame_interval: delay between frames in ms (smaller number, faster refresh and video)
    :param n_milestones: number of milestones to use if the edge bundle must be computed here
    :param facecolor: figure background color, default = white
    :param cmap: name of colormap used for time/lineage coloring
    :param extra_title_text: appended to the figure title
    :param size_scatter: size of milestone scatter points
    :param alpha_scatter: alpha of milestone scatter points
    :param saveto: output path of the .gif (written with the imagemagick writer)
    :param time_series_labels: single-cell level list (n_cells) of numerical values that form a discrete set,
        i.e. not continuous like pseudotime; overrides via_object's saved labels/pseudotime
    :param lineage_pathway: list whose first element is a terminal cluster number; when given, edges and
        cells are colored/weighted by that lineage's probability
    :param sc_labels_numeric: single-cell numeric labels used if the edge bundle must be computed here
    :param t_diff_factor: scaling of the average time interval (0.25 means that for each frame, time is
        progressed by 0.25 * mean time difference between adjacent time points)
    :param show_sc_embedding: plot the single cell embedding under the edges
    :param sc_emb: numpy array of single cell embedding (ncells x 2); defaults to via_object.embedding
    :param sc_size_scatter: size of scatter points of single cells
    :param sc_alpha_scatter: alpha transparency of single-cell points (1 opaque, 0 fully transparent)
    :param n_intervals: number of time intervals used when no discrete time labels are available
    :return: None (the animation is saved to `saveto` and shown)
    '''
    # NOTE(review): removed `import tqdm` — it was unused (progress-bar code below is commented out)
    if show_sc_embedding:
        if sc_emb is None:
            sc_emb = via_object.embedding
        if sc_emb is None:
            print('please provide a single cell embedding as an array')
            return
    if hammerbundle_dict is None:
        if via_object is None:
            print(
                f'{datetime.now()}\tERROR: Hammerbundle_dict needs to be provided either through via_object or by running make_edgebundle_milestone()')
        else:
            hammerbundle_dict = via_object.hammerbundle_milestone_dict
            if hammerbundle_dict is None:
                if n_milestones is None: n_milestones = min(via_object.nsamples, 150)
                if sc_labels_numeric is None:
                    if via_object.time_series_labels is not None:
                        sc_labels_numeric = via_object.time_series_labels
                    else:
                        sc_labels_numeric = via_object.single_cell_pt_markov
                hammerbundle_dict = make_edgebundle_milestone(via_object=via_object,
                                                              embedding=via_object.embedding,
                                                              sc_graph=via_object.ig_full_graph,
                                                              n_milestones=n_milestones,
                                                              sc_labels_numeric=sc_labels_numeric,
                                                              initial_bandwidth=0.02, decay=0.7, weighted=True)
            hammer_bundle = hammerbundle_dict['hammerbundle']
            layout = hammerbundle_dict['milestone_embedding'][['x', 'y']].values
            milestone_edges = hammerbundle_dict['edges']
            milestone_numeric_values = hammerbundle_dict['milestone_embedding']['numeric label']
            milestone_pt = hammerbundle_dict['milestone_embedding']['pt']  # used when plotting arrows
    else:
        hammer_bundle = hammerbundle_dict['hammerbundle']
        layout = hammerbundle_dict['milestone_embedding'][['x', 'y']].values
        milestone_edges = hammerbundle_dict['edges']
        milestone_numeric_values = hammerbundle_dict['milestone_embedding']['numeric label']
        milestone_pt = hammerbundle_dict['milestone_embedding']['pt']  # used when plotting arrows
    fig, ax = plt.subplots(facecolor=facecolor, figsize=(15, 12))
    time_thresh = min(milestone_numeric_values)
    ax.grid(False)
    x_ = [l[0] for l in layout]
    y_ = [l[1] for l in layout]
    delta_x = max(x_) - min(x_)
    delta_y = max(y_) - min(y_)
    layout = np.asarray(layout)
    # get each edge segment from the hammerbundle; segments are separated by nan rows
    hbnp = hammer_bundle.to_numpy()
    splits = (np.isnan(hbnp[:, 0])).nonzero()[0]  # location of each nan value
    edgelist_segments = []
    start = 0
    segments = []
    arrow_coords = []
    seg_len = []  # length of a segment
    for stop in splits:
        seg = hbnp[start:stop, :]
        segments.append(seg)
        seg_len.append(seg.shape[0])
        start = stop
    min_seg_length = min(seg_len)
    max_seg_length = max(seg_len)
    seg_len = np.asarray(seg_len)
    # clip extreme segment lengths so derived line widths stay in a reasonable range
    seg_len = np.clip(seg_len, a_min=np.percentile(seg_len, 10),
                      a_max=np.percentile(seg_len, 90))
    step = 1  # every step'th segment is plotted
    cmap = matplotlib.cm.get_cmap(cmap)
    if milestone_numeric_values is not None:
        max_numerical_value = max(milestone_numeric_values)
        min_numerical_value = min(milestone_numeric_values)
    seg_count = 0
    loc_time_thresh = np.where(np.asarray(milestone_numeric_values) <= time_thresh)[0].tolist()
    i_sorted_numeric_values = np.argsort(milestone_numeric_values)
    ee = int(len(milestone_numeric_values) / n_intervals)
    print('ee', ee)
    # initial (static) frame: the earliest 1/n_intervals fraction of milestones
    loc_time_thresh = i_sorted_numeric_values[0:ee]
    for ll in loc_time_thresh:
        print('sorted numeric milestone', milestone_numeric_values[ll])
    milestone_edges['source_thresh'] = milestone_edges['source'].isin(loc_time_thresh)
    idx = milestone_edges.index[milestone_edges['source_thresh']].tolist()
    for i in idx:
        seg = segments[i]
        source_milestone = milestone_edges['source'].values[i]
        target_milestone = milestone_edges['target'].values[i]
        # edge weight (column 2 of the last segment row) scaled by the clipped segment length
        seg_weight = seg[-1, 2] * seg_len[i] / (max_seg_length - min_seg_length)
        alpha_bundle = max(seg_weight, 0.1)
        if alpha_bundle > 1: alpha_bundle = 1
        if milestone_numeric_values is not None:
            source_milestone_numerical_value = milestone_numeric_values[source_milestone]
            target_milestone_numerical_value = milestone_numeric_values[target_milestone]
            # color each edge by the earlier of its two endpoint times
            rgba_milestone_value = min(source_milestone_numerical_value, target_milestone_numerical_value)
            rgba = cmap((rgba_milestone_value - min_numerical_value) / (max_numerical_value - min_numerical_value))
        else:
            rgba = cmap(min(seg_weight, 0.95))
        if seg_count % 10000 == 0: print('seg weight', seg_weight)
        seg = seg[:, 0:2].reshape(-1, 2)
        seg_p = seg[~np.isnan(seg)].reshape((-1, 2))
        ax.plot(seg_p[:, 0], seg_p[:, 1], linewidth=linewidth_bundle * seg_weight, alpha=alpha_bundle,
                color=rgba)
        seg_count += 1
    milestone_numeric_values_rgba = []
    if len(lineage_pathway) > 0:
        milestone_lin_values = hammerbundle_dict['milestone_embedding'][
            'sc_lineage_probability_' + str(lineage_pathway[0])]
        # single cell lineage probabilities: clean nan/inf then row-normalize
        p1_sc_bp = np.nan_to_num(via_object.single_cell_bp, nan=0.0, posinf=0.0, neginf=0.0)
        row_sums = p1_sc_bp.sum(axis=1)
        p1_sc_bp = p1_sc_bp / row_sums[:, np.newaxis]
        ts_cluster_number = lineage_pathway[0]
        ts_array_original = np.asarray(via_object.terminal_clusters)
        loc_ts_current = np.where(ts_array_original == ts_cluster_number)[0][0]
        print(
            f'location of {lineage_pathway[0]} is at {np.where(ts_array_original == ts_cluster_number)[0]} and {loc_ts_current}')
        p1_sc_bp = p1_sc_bp[:, loc_ts_current]
        rgba_lineage_sc = []
        rgba_lineage_milestone = []
        for i in p1_sc_bp:
            rgba_lineage_sc_ = cmap((i - min(p1_sc_bp)) / (max(p1_sc_bp) - min(p1_sc_bp)))
            rgba_lineage_sc.append(rgba_lineage_sc_)
        for i in milestone_lin_values:
            rgba_lineage_milestone_ = cmap(
                (i - min(milestone_lin_values)) / (max(milestone_lin_values) - min(milestone_lin_values)))
            rgba_lineage_milestone.append(rgba_lineage_milestone_)
    print('here1 in animate()')
    if milestone_numeric_values is not None:
        for i in milestone_numeric_values:
            rgba_ = cmap((i - min_numerical_value) / (max_numerical_value - min_numerical_value))
            milestone_numeric_values_rgba.append(rgba_)
        ax.scatter(layout[loc_time_thresh, 0], layout[loc_time_thresh, 1], s=size_scatter,
                   c=np.asarray(milestone_numeric_values_rgba)[loc_time_thresh], alpha=alpha_scatter)
        # plot ALL points fully transparent so the axis limits stay fixed as more points are added
        ax.scatter(layout[:, 0], layout[:, 1], s=size_scatter,
                   c=np.asarray(milestone_numeric_values_rgba), alpha=0)
        if show_sc_embedding:
            if len(lineage_pathway) > 0: ax.scatter(sc_emb[:, 0], sc_emb[:, 1], s=sc_size_scatter,
                                                    c=p1_sc_bp, alpha=sc_alpha_scatter, cmap=cmap)
            ax.scatter(sc_emb[:, 0], sc_emb[:, 1], s=size_scatter,
                       c='blue', alpha=0)
    else:
        ax.scatter(layout[:, 0], layout[:, 1], s=size_scatter, c='red', alpha=alpha_scatter)
    print('here2 in animate()')
    ax.set_facecolor(facecolor)
    ax.axis('off')
    time = datetime.now()
    time = time.strftime("%H:%M")
    title_ = 'n_milestones = ' + str(int(layout.shape[0])) + ' time: ' + time + ' ' + extra_title_text
    ax.set_title(label=title_, color='black')
    print(f"{datetime.now()}\tFinished plotting edge bundle")
    if time_series_labels is not None:  # over-ride via_object's saved time_series_labels and/or pseudotime
        time_series_set_order = list(sorted(list(set(time_series_labels))))
        t_diff_mean = t_diff_factor * np.mean(
            np.array([int(abs(y - x)) for x, y in zip(time_series_set_order[:-1], time_series_set_order[1:])]))
        print('t_diff_mean:', t_diff_mean)
        cycles = (max_numerical_value - min_numerical_value) / t_diff_mean
        print('number cycles', cycles)
    else:
        if via_object is not None:
            time_series_labels = via_object.time_series_labels
            if time_series_labels is None:
                time_series_labels = via_object.single_cell_pt_markov
            # multiply continuous pseudotime by 10 to obtain a usable discrete set for "mean t_diff"
            time_series_labels_int = [int(i * 10) for i in time_series_labels]
            time_series_set_order = list(sorted(list(set(time_series_labels_int))))
            print('times series order set using pseudotime', time_series_set_order)
            t_diff_mean = t_diff_factor * np.mean(np.array([int(abs(y - x)) for x, y in
                                                            zip(time_series_set_order[:-1],
                                                                time_series_set_order[
                                                                1:])])) / 10  # divide by 10 because we multiplied the single_cell_pt_markov by 10
            cycles = (max_numerical_value - min_numerical_value) / (t_diff_mean)
            print('number cycles if no time_series labels given', cycles)
    min_time_series_labels = min(time_series_labels)
    max_time_series_labels = max(time_series_labels)
    sc_rgba = []
    for i in time_series_labels:
        sc_rgba_ = cmap((i - min_time_series_labels) / (max_time_series_labels - min_time_series_labels))
        sc_rgba.append(sc_rgba_)
    if show_sc_embedding:
        i_sorted_sc_time = np.argsort(time_series_labels)

    def update_edgebundle(frame_no):
        # one animation frame: draw edges/points whose time falls in (time_thresh - t_diff_mean, time_thresh]
        print('inside update', frame_no, 'out of', int(cycles), 'cycles')
        if len(time_series_labels) > 0:
            time_thresh = min_numerical_value + frame_no % (cycles + 1) * t_diff_mean
            print('time thresh', time_thresh)
        else:
            time_thresh = min_numerical_value + (frame_no % n_intervals) * (
                    max_numerical_value - min_numerical_value) / n_intervals
        # time-based loc_time_thresh
        loc_time_thresh = np.where((np.asarray(milestone_numeric_values) <= time_thresh) & (
                np.asarray(milestone_numeric_values) > time_thresh - t_diff_mean))[0].tolist()
        sc_loc_time_thresh = np.where((np.asarray(time_series_labels) <= time_thresh) & (
                np.asarray(time_series_labels) > time_thresh - t_diff_mean))[0].tolist()
        milestone_edges['source_thresh'] = milestone_edges['source'].isin(loc_time_thresh)
        idx = milestone_edges.index[milestone_edges['source_thresh']].tolist()
        # NOTE(review): the original also printed an undefined name `rem` here, which raised a
        # NameError on every frame; dropped from the print
        print('len of number of edges in this cycle', len(idx))
        for i in idx:
            seg = segments[i]
            source_milestone = milestone_edges['source'].values[i]
            # edge weight (column 2 of the last segment row) scaled by the clipped segment length
            seg_weight = seg[-1, 2] * seg_len[i] / (max_seg_length - min_seg_length)
            alpha_bundle = max(seg_weight, 0.1)
            if alpha_bundle > 1: alpha_bundle = 1
            if milestone_numeric_values is not None:
                source_milestone_numerical_value = milestone_numeric_values[source_milestone]
                if len(lineage_pathway) == 0:
                    rgba = cmap((source_milestone_numerical_value - min_numerical_value) / (
                            max_numerical_value - min_numerical_value))
                else:
                    # lineage mode: the alpha channel encodes the milestone's lineage probability
                    rgba = list(rgba_lineage_milestone[source_milestone])
                    rgba[3] = milestone_lin_values[source_milestone]
                    rgba = tuple(rgba)
            else:
                rgba = cmap(min(seg_weight, 0.95))
            seg = seg[:, 0:2].reshape(-1, 2)
            seg_p = seg[~np.isnan(seg)].reshape((-1, 2))
            if len(lineage_pathway) > 0:
                ax.plot(seg_p[:, 0], seg_p[:, 1], linewidth=linewidth_bundle * seg_weight, color=rgba)
            else:
                ax.plot(seg_p[:, 0], seg_p[:, 1], linewidth=linewidth_bundle * seg_weight, color=rgba,
                        alpha=alpha_bundle)
        milestone_numeric_values_rgba = []
        if milestone_numeric_values is not None:
            for i in milestone_numeric_values:
                rgba_ = cmap((i - min_numerical_value) / (max_numerical_value - min_numerical_value))
                milestone_numeric_values_rgba.append(rgba_)
            if time_thresh > 1.1 * max_numerical_value:
                ax.clear()
            else:
                # transparent full-layout scatter keeps axis limits fixed across frames
                ax.scatter(layout[:, 0], layout[:, 1], s=size_scatter,
                           c=np.asarray(milestone_numeric_values_rgba), alpha=0)
                ax.scatter(layout[loc_time_thresh, 0], layout[loc_time_thresh, 1], s=size_scatter,
                           c=np.asarray(milestone_numeric_values_rgba)[loc_time_thresh], alpha=alpha_scatter)
                if show_sc_embedding:
                    if len(lineage_pathway) > 0:
                        ax.scatter(sc_emb[sc_loc_time_thresh, 0], sc_emb[sc_loc_time_thresh, 1],
                                   s=sc_size_scatter,
                                   c=np.asarray(rgba_lineage_sc)[sc_loc_time_thresh],
                                   alpha=p1_sc_bp[sc_loc_time_thresh])
                    else:
                        ax.scatter(sc_emb[sc_loc_time_thresh, 0], sc_emb[sc_loc_time_thresh, 1],
                                   s=sc_size_scatter,
                                   c=np.asarray(sc_rgba)[sc_loc_time_thresh], alpha=sc_alpha_scatter)
        else:
            ax.scatter(layout[:, 0], layout[:, 1], s=size_scatter, c='red', alpha=alpha_scatter)

    frame_no = int(cycles) * 2
    animation = FuncAnimation(fig, update_edgebundle, frames=frame_no, interval=frame_interval,
                              repeat=False)
    print('complete animate')
    animation.save(saveto, writer='imagemagick')
    print('saved animation')
    plt.show()
    return