Source code for seam.meta_explainer

from __future__ import division  # Should be first import

# Standard libraries
import os
import sys
import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.stats import entropy

# Visualization libraries
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
import seaborn as sns
import matplotlib as mpl
import matplotlib.patches as patches
import itertools

# Bioinformatics libraries
from Bio import motifs  # For PWM/enrichment logos

# BatchLogo package imports
from .logomaker_batch.batch_logo import BatchLogo

# Local utilities
try:  # Try relative import first (for pip package)
    from . import utils
except ImportError:  # Fall back to direct import (for development/Colab)
    import utils

class MetaExplainer:
    """A class for analyzing and visualizing attribution map clusters.

    This class builds on the Clusterer class to provide detailed analysis and
    visualization of attribution map clusters.

    Features
    --------
    Analysis
        - Mechanism Summary Matrix (MSM) generation
        - Sequence logos and attribution logos
        - Cluster membership tracking
        - Background separation and noise reduction of attribution maps

    Visualization
        - DNN score distributions per cluster
        - Sequence logos (PWM and enrichment)
        - Attribution logos (fixed and adaptive scaling)
        - Mechanism Summary Matrices
        - Cluster profile plots

    Requirements
    ------------
    - All requirements from Clusterer class
    - Biopython
    - Logomaker
    - Seaborn
    - SQUID-NN
    """
def __init__(self, clusterer, mave_df, attributions, ref_idx=0,
             background_separation=False, mut_rate=0.10, sort_method='median',
             alphabet=None):
    """Initialize MetaExplainer with clusterer and data.

    Parameters
    ----------
    clusterer : Clusterer
        Clusterer object whose ``cluster_labels`` attribute holds one
        label per row of ``mave_df``.
    mave_df : pandas.DataFrame
        DataFrame containing sequences and their scores. Must have columns:
        - 'Sequence': DNA/RNA sequences
        - 'DNN': Model predictions
        A 'Cluster' column is added (or overwritten) on an internal copy
        from ``clusterer.cluster_labels``.
    attributions : numpy.ndarray
        Attribution maps for sequences. Shape should be
        (n_sequences, seq_length, n_characters).
    ref_idx : int, default=0
        Index of reference sequence in mave_df.
    background_separation : bool, default=False
        Accepted for backward compatibility; the constructor does not use
        it. Pass ``background_separation`` to ``generate_logos()`` instead.
    mut_rate : float, default=0.10
        Mutation rate used for background entropy calculations.
    sort_method : {'median', None}, default='median'
        How to sort clusters in all visualizations and analyses.
        - 'median': Sort by median DNN score
        - None: Use original cluster indices
        Orderings that need an MSM cannot be requested here because no
        MSM exists yet at construction time; ``get_cluster_order()``
        raises ``ValueError`` for them.
    alphabet : list of str, optional
        Characters to use in sequence logos. If None, the alphabet is
        inferred from the sequences during input validation.

    Raises
    ------
    ValueError
        If the inputs fail validation or ``sort_method`` cannot be
        resolved at construction time.
    """
    # Store inputs
    self.clusterer = clusterer  # clusterer.cluster_labels should contain labels_n
    self.mave = mave_df
    self.attributions = attributions
    self.sort_method = sort_method
    self.ref_idx = ref_idx
    self.mut_rate = mut_rate
    # Default DNA alphabet unless the caller supplied one
    self.alphabet = list(alphabet) if alphabet is not None else ['A', 'C', 'G', 'T']
    self.background_logos = None

    # Initialize other attributes
    self.msm = None
    self.cluster_background = None
    self.consensus_df = None
    self.membership_df = None

    # Validate and process inputs
    self._validate_inputs()
    if alphabet is not None:
        # Validation re-derives the alphabet from the sequences; an
        # explicitly supplied alphabet must take precedence over that.
        self.alphabet = list(alphabet)
    self._process_inputs()

    # Get the cluster ordering once at initialization
    if self.sort_method:
        self.cluster_order = self.get_cluster_order(sort_method=self.sort_method)
    else:
        self.cluster_order = None
def _validate_inputs(self): """Validate input data and parameters.""" # Ensure mave_df has required columns required_cols = {'Sequence', 'DNN'} if not required_cols.issubset(self.mave.columns): raise ValueError(f"mave_df must contain columns: {required_cols}") # Validate cluster labels exist if not hasattr(self.clusterer, 'cluster_labels') or self.clusterer.cluster_labels is None: raise ValueError("Clusterer must have valid cluster_labels. Did you run clustering?") # Get reference sequence from index self.ref_seq = self.mave['Sequence'].iloc[self.ref_idx] # Determine alphabet from sequences self.alphabet = sorted(list(set(self.mave['Sequence'][0:100].apply(list).sum()))) def _process_inputs(self): """Process inputs and initialize derived data structures.""" # Create membership tracking DataFrame self.membership_df = pd.DataFrame({ 'Cluster': self.clusterer.cluster_labels, 'Original_Index': range(len(self.mave)) }) # Add cluster assignments to mave DataFrame self.mave = self.mave.copy() # Create a copy to avoid modifying original self.mave['Cluster'] = self.clusterer.cluster_labels # Initialize cluster indices from unique cluster labels self.cluster_indices = np.unique(self.clusterer.cluster_labels)
def get_cluster_order(self, sort_method='median', sort_indices=None):
    """Get cluster ordering based on specified method.

    Parameters
    ----------
    sort_method : {'median', 'hierarchical', 'visual', 'predefined', None}
        - 'median': ascending median 'DNN' score of each cluster
        - 'hierarchical' / 'visual': leaf order of a Ward hierarchical
          clustering of the MSM 'Entropy' matrix (requires an MSM)
        - 'predefined': use ``sort_indices`` as given
        - None: the unsorted ``self.cluster_indices``
    sort_indices : array-like, optional
        Explicit ordering; required when ``sort_method='predefined'``.

    Returns
    -------
    numpy.ndarray
        Cluster labels in the requested order.

    Raises
    ------
    ValueError
        If the method is unknown, 'predefined' is requested without
        ``sort_indices``, or an MSM-based ordering is requested before an
        MSM exists.
    """
    if sort_method is None:
        return self.cluster_indices  # Return actual indices instead of range

    if sort_method == 'predefined':
        if sort_indices is None:
            raise ValueError("sort_indices must be provided when sort_method='predefined'.")
        return np.array(sort_indices)

    # 'visual' is the name used in the constructor documentation for the
    # MSM-based hierarchical ordering, so accept it as an alias.
    if sort_method in ('hierarchical', 'visual'):
        if getattr(self, 'msm', None) is None:
            raise ValueError("MSM required for hierarchical sorting. Call generate_msm() first.")

        from scipy.cluster import hierarchy
        from scipy.spatial import distance

        matrix_data = self.msm.pivot(columns='Position', index='Cluster', values='Entropy')
        row_labels = matrix_data.index.to_numpy()
        if len(row_labels) < 2:
            # Linkage needs at least two observations; nothing to reorder.
            return row_labels
        linkage = hierarchy.linkage(distance.pdist(matrix_data), method='ward')
        dendro = hierarchy.dendrogram(linkage, no_plot=True, color_threshold=-np.inf)
        # Leaves are row positions of matrix_data; map back to cluster labels
        return row_labels[dendro['leaves']]

    if sort_method == 'median':
        # Median DNN score per cluster, aligned with self.cluster_indices
        medians = self.mave.groupby('Cluster')['DNN'].median()
        cluster_medians = medians.reindex(self.cluster_indices).to_numpy()
        return self.cluster_indices[np.argsort(cluster_medians)]

    raise ValueError(f"Unknown sort_method: {sort_method}")
def plot_cluster_stats(self, plot_type='box', metric='prediction', save_path=None,
                       show_ref=True, show_fliers=False, compact=False,
                       fontsize=8, dpi=200, figsize=None, file_format='png'):
    """Plot cluster statistics with various visualization options.

    Parameters
    ----------
    plot_type : {'box', 'bar'}
        Type of visualization:
        - 'box': Show distribution as box plots (predictions only)
        - 'bar': Show bar plot of predictions or counts
    metric : {'prediction', 'counts'}
        What to visualize (only used for bar plots):
        - 'prediction': DNN prediction scores
        - 'counts': cluster occupancy/size
    save_path : str, optional
        Directory to save the figure in. If None, display instead.
    show_ref : bool
        If True and a reference sequence exists, highlight its cluster.
    show_fliers : bool
        If True and plot_type='box', show outlier points.
    compact : bool
        If False, shows full boxplots (default). If True, uses a compact
        representation with a dot for the median and a line for the IQR.
    fontsize : int
        Font size for tick labels.
    dpi : int
        DPI for saved figure.
    figsize : tuple, optional
        Figure size (width, height) in inches.
    file_format : str, optional
        Format for saved figure (default: 'png').
        Common formats: 'png', 'pdf', 'svg', 'eps'

    Notes
    -----
    When a cluster ordering is active, this method also writes a
    'Cluster_Sorted' column into ``self.membership_df``.
    """
    # Use the clusters actually present in the data
    actual_clusters = np.sort(self.mave['Cluster'].unique())
    cluster_to_idx = {k: i for i, k in enumerate(actual_clusters)}

    # Box plots always show scores; bar plots show scores or counts
    use_scores = plot_type == 'box' or metric == 'prediction'
    boxplot_data = []
    for k in actual_clusters:
        k_idxs = self.mave.loc[self.mave['Cluster'] == k].index
        if use_scores:
            boxplot_data.append(self.mave.loc[k_idxs, 'DNN'])
        else:
            boxplot_data.append([len(k_idxs)])

    # position_of maps a cluster label to its row in boxplot_data
    position_of = cluster_to_idx
    if self.cluster_order is not None:
        boxplot_data = [boxplot_data[cluster_to_idx[k]] for k in self.cluster_order]
        position_of = {old_k: new_k for new_k, old_k in enumerate(self.cluster_order)}
        # Update membership tracking
        self.membership_df['Cluster_Sorted'] = self.membership_df['Cluster'].map(position_of)

    # Resolve the reference cluster to a row position. The raw label is not
    # a valid list position unless labels happen to be 0..n-1, so it is
    # always translated through position_of.
    ref_position = None
    if show_ref and self.ref_seq is not None:
        ref_cluster = self.membership_df.loc[self.ref_idx, 'Cluster']
        ref_position = position_of[ref_cluster]

    n_rows = len(boxplot_data)

    if plot_type == 'box':
        iqr_values = [np.percentile(data, 75) - np.percentile(data, 25)
                      for data in boxplot_data if len(data) > 0]
        average_iqr = np.mean(iqr_values) if iqr_values else 0

        plt.figure(figsize=figsize if figsize is not None else (6.4, 4.8))

        # Data is drawn bottom-to-top, so reverse it to put row 0 on top
        if not compact:
            plt.boxplot(boxplot_data[::-1], vert=False,
                        showfliers=show_fliers,
                        medianprops={'color': 'black'})
        else:
            for pos, values in enumerate(boxplot_data[::-1]):
                values = np.array(values)
                median = np.median(values)
                q1 = np.percentile(values, 25)
                q3 = np.percentile(values, 75)
                plt.plot([q1, q3], [pos + 1, pos + 1], color='gray', lw=.5)  # IQR line
                plt.plot(median, pos + 1, 'o', color='k', markersize=1, zorder=100)  # median

        plt.yticks(range(1, n_rows + 1)[::10],
                   range(n_rows)[::-1][::10],
                   fontsize=fontsize)
        plt.ylabel('Clusters')
        plt.xlabel('DNN')
        plt.gca().xaxis.set_major_locator(plt.MaxNLocator(integer=True))
        plt.title(f'Average IQR: {average_iqr:.2f}')

        if ref_position is not None:
            ref_data = boxplot_data[ref_position]
            if len(ref_data) > 0:
                plt.axvline(np.median(ref_data), c='red', label='Ref', zorder=-100)
                plt.legend(loc='best')
    else:  # bar plot
        plt.figure(figsize=figsize if figsize is not None else (1.5, 5))

        y_positions = np.arange(n_rows)
        values = [np.median(data) if metric == 'prediction' else data[0]
                  for data in boxplot_data]
        height = 1.0

        if ref_position is not None:
            colors = ['red' if i == ref_position else 'C0' for i in range(len(values))]
            plt.barh(y_positions, values, height=height, color=colors)
        else:
            plt.barh(y_positions, values, height=height)

        plt.yticks(y_positions[::10], y_positions[::10], fontsize=fontsize)
        plt.ylabel('Cluster')
        plt.xlabel('DNN' if metric == 'prediction' else 'Count')
        plt.gca().invert_yaxis()
        plt.axvline(x=0, color='black', linewidth=0.5, zorder=100)

    plt.tight_layout()
    if save_path:
        filename = 'cluster_%s_%s.%s' % (metric, plot_type, file_format)
        plt.savefig(os.path.join(save_path, filename),
                    facecolor='w', dpi=dpi, bbox_inches='tight')
        plt.close()
    else:
        plt.show()
def generate_msm(self, n_seqs=1000, batch_size=50, gpu=False):
    """Generate a Mechanism Summary Matrix (MSM) from cluster sequences.

    For every cluster and sequence position this computes, from the
    sequences assigned to that cluster:

    - 'Entropy': Shannon entropy (bits) of the character frequencies
    - 'Consensus': percentage of sequences matching the cluster consensus
    - 'Reference': percentage of sequences mismatching the reference
      sequence (left as NaN when there is no reference sequence)

    Parameters
    ----------
    n_seqs : int, default=1000
        Currently unused; kept for interface compatibility.
    batch_size : int, default=50
        Currently unused; kept for interface compatibility.
    gpu : bool, default=False
        If True, TensorFlow is imported and the selected device is
        printed. The computation itself runs in NumPy either way.

    Returns
    -------
    pandas.DataFrame
        Long-format table with one row per (cluster, position) and columns
        'Cluster', 'Position', 'Reference', 'Consensus', 'Entropy'. The
        same object is stored as ``self.msm``.
    """
    from concurrent.futures import ThreadPoolExecutor

    # Get sequence length from first sequence
    seq_length = len(self.mave['Sequence'].iloc[0])

    if gpu:
        import tensorflow as tf
        device = '/GPU:0' if tf.test.is_built_with_cuda() else '/CPU:0'
        print(f"Using device: {device}")

    # Character matrix (n_sequences, seq_length) for vectorized comparisons
    sequences = np.array([list(seq) for seq in self.mave['Sequence']])
    cluster_labels = self.mave['Cluster'].to_numpy()
    alphabet_arr = np.array(self.alphabet)

    # Initialize MSM DataFrame; rows are grouped by cluster, positions ascending
    self.msm = pd.DataFrame({
        'Cluster': np.repeat(self.cluster_indices, seq_length),
        'Position': np.tile(np.arange(seq_length), len(self.cluster_indices)),
        'Reference': np.nan,
        'Consensus': np.nan,
        'Entropy': np.nan
    })

    # One-hot reference, transposed to (n_chars, seq_length) to line up with
    # the counts matrix. Skipped when there is no reference sequence.
    # NOTE(review): assumes utils.seq2oh returns (seq_length, n_chars), which
    # the original dot product with counts also required.
    ref_oh_t = None
    if self.ref_seq is not None:
        ref_oh_t = np.asarray(utils.seq2oh(self.ref_seq, self.alphabet)).T

    def process_cluster(k):
        """Return (k, entropy, consensus match %, reference mismatch %)."""
        seqs_k = sequences[cluster_labels == k]
        n_members = len(seqs_k)

        # Position-wise character counts, shape (n_chars, seq_length)
        counts = np.zeros((len(self.alphabet), seq_length))
        for i, base in enumerate(self.alphabet):
            counts[i] = (seqs_k == base).sum(axis=0)

        freqs = counts / n_members

        # Shannon entropy in bits; the epsilon keeps log2 finite at zero
        with np.errstate(divide='ignore', invalid='ignore'):
            pos_entropy = -np.sum(freqs * np.log2(freqs + 1e-10), axis=0)
        pos_entropy = np.nan_to_num(pos_entropy)

        # Consensus character per position and how often it occurs. The
        # element-wise product selects the diagonal of one_hot @ counts
        # without materializing a seq_length x seq_length matrix.
        consensus_seq = alphabet_arr[np.argmax(counts, axis=0)]
        consensus_oh_t = np.asarray(utils.seq2oh(consensus_seq, self.alphabet)).T
        consensus_matches = (consensus_oh_t * counts).sum(axis=0) / n_members * 100

        if ref_oh_t is not None:
            ref_matches = (ref_oh_t * counts).sum(axis=0) / n_members * 100
            ref_mismatches = 100 - ref_matches
        else:
            ref_mismatches = np.full(seq_length, np.nan)

        return k, pos_entropy, consensus_matches, ref_mismatches

    # Process clusters in parallel; executor.map preserves input order
    with ThreadPoolExecutor() as executor:
        results = list(tqdm(
            executor.map(process_cluster, self.cluster_indices),
            total=len(self.cluster_indices),
            desc="Processing clusters"
        ))

    # Fill MSM with results
    for k, pos_entropy, consensus_pct, reference_pct in results:
        mask = self.msm['Cluster'] == k
        self.msm.loc[mask, 'Entropy'] = pos_entropy
        self.msm.loc[mask, 'Consensus'] = consensus_pct
        if self.ref_seq is not None:
            self.msm.loc[mask, 'Reference'] = reference_pct

    return self.msm
def plot_msm(self, column='Entropy', delta_entropy=False, square_cells=False,
             view_window=None, show_tfbs_clusters=False, tfbs_clusters=None,
             entropy_multiplier=0.5, cov_matrix=None, row_order=None, revels=None,
             save_path=None, dpi=200, figsize=None, file_format='png',
             gui=False, gui_figure=None):
    """Visualize the Mechanism Summary Matrix (MSM) as a heatmap.

    Parameters
    ----------
    column : str
        Which MSM metric to visualize:
        - 'Entropy': Shannon entropy of characters at each position per cluster
        - 'Reference': Percentage of mismatches to reference sequence
        - 'Consensus': Percentage of matches to cluster consensus sequence
    delta_entropy : bool
        If True and column='Entropy', show change in entropy from the
        background expectation (based on mutation rate).
    square_cells : bool
        If True, set cells in MSM to be perfectly square.
    view_window : list of [start, end], optional
        If provided, crop the x-axis to this window of positions.
    show_tfbs_clusters : bool
        Whether to show TFBS cluster rectangles (default: False).
    tfbs_clusters : dict, optional
        Dictionary mapping cluster IDs to lists of positions. Rectangles
        are only drawn when this is given and show_tfbs_clusters is True.
    entropy_multiplier : float, optional
        Multiplier for entropy threshold when identifying background
        (default: 0.5).
    cov_matrix : pandas.DataFrame, optional
        Covariance matrix for TFBS cluster plotting (its ``index`` is
        read). Required if show_tfbs_clusters is True.
    row_order : sequence of int, optional
        Order of rows in cov_matrix. Required if show_tfbs_clusters is True.
    revels : pandas.DataFrame, optional
        Matrix used for the per-cluster entropy test. Required if
        show_tfbs_clusters is True.
    save_path : str, optional
        Directory to save the figure in. If None, display instead.
    dpi : int
        DPI for saved figure.
    figsize : tuple, optional
        Figure size (width, height) in inches (default: (10, 6)).
    file_format : str, optional
        Format for saved figure (default: 'png').
        Common formats: 'png', 'pdf', 'svg', 'eps'
    gui : bool
        If True, draw the heatmap and return its components for GUI
        processing without TFBS rectangles, saving, or showing.
    gui_figure : matplotlib.figure.Figure, optional
        Existing figure to plot on when gui=True. It is cleared first.
        If None, a new figure is created.

    Returns
    -------
    tuple
        ``(main_ax, cbar_ax, cluster_order, matrix_data)`` when ``gui`` is
        True, otherwise ``(fig, main_ax)``.

    Raises
    ------
    ValueError
        If no MSM has been generated, or TFBS plotting is requested
        without cov_matrix, row_order and revels.
    """
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    # Every code path below reads self.msm, so check it up front
    if getattr(self, 'msm', None) is None:
        raise ValueError("MSM not generated. Call generate_msm() first.")

    if show_tfbs_clusters:
        if any(x is None for x in [cov_matrix, row_order, revels]):
            raise ValueError("cov_matrix, row_order, and revels required for TFBS cluster plotting")
        self.cov_matrix = cov_matrix
        self.row_order = row_order
        self.revels = revels

    # Prepare data matrix (rows: clusters, columns: positions)
    matrix_data = self.msm.pivot(columns='Position', index='Cluster', values=column)

    # Apply view window if specified
    if view_window is not None:
        start, end = view_window
        matrix_data = matrix_data.iloc[:, start:end]

    cluster_order = (self.cluster_order if self.cluster_order is not None
                     else np.sort(self.mave['Cluster'].unique()))
    matrix_data = matrix_data.reindex(cluster_order)

    # Tick layout must follow what is actually drawn, not the largest label
    n_clusters, n_positions = matrix_data.shape

    # Use the GUI's figure when provided, otherwise create a new one
    if gui and gui_figure is not None:
        fig = gui_figure
        fig.clear()
    else:
        fig = plt.figure(figsize=figsize if figsize is not None else (10, 6))
    main_ax = fig.add_subplot(111)

    # Get colormap settings
    cmap_settings = self._get_colormap_settings(column, delta_entropy, matrix_data)
    if delta_entropy and column == 'Entropy':
        matrix_data -= cmap_settings.pop('bg_entropy', 0)

    # Create heatmap
    heatmap = main_ax.pcolormesh(matrix_data,
                                 cmap=cmap_settings['cmap'],
                                 norm=cmap_settings['norm'])

    # Ensure consistent color scaling
    if column in ['Reference', 'Consensus']:
        heatmap.set_clim(0, 100)
    elif column == 'Entropy' and not delta_entropy:
        heatmap.set_clim(0, 2)

    if gui:
        divider = make_axes_locatable(main_ax)
        cbar_ax = divider.append_axes('right', size='2%', pad=0.05)
        cbar = fig.colorbar(heatmap, cax=cbar_ax, orientation='vertical')

        main_ax.set_xlabel('Position', fontsize=8)
        main_ax.set_ylabel('Cluster', fontsize=8)
        main_ax.invert_yaxis()
        self._configure_matrix_ticks(main_ax, n_positions, n_clusters, cluster_order)

        cbar.ax.set_ylabel(cmap_settings['label'], rotation=270, fontsize=8, labelpad=10)
        cbar.ax.tick_params(labelsize=6)

        if square_cells:
            main_ax.set_aspect('equal')

        # Lay out the figure we drew on; pyplot's "current figure" is not
        # necessarily the GUI-supplied one.
        fig.tight_layout()
        return main_ax, cbar_ax, cluster_order, matrix_data

    # Add TFBS cluster rectangles if requested
    if show_tfbs_clusters and tfbs_clusters is not None:
        # Entropy threshold for "active" regions, from the instance mut_rate
        null_rate = 1 - self.mut_rate
        background_entropy = entropy(
            [null_rate, (1 - null_rate) / 3, (1 - null_rate) / 3, (1 - null_rate) / 3],
            base=2)
        entropy_threshold = background_entropy * entropy_multiplier

        original_indices = self.cov_matrix.index.tolist()
        row_order_list = list(self.row_order)  # .index() needs a list
        use_sorted = bool(self.sort_method) and self.cluster_order is not None

        active_clusters_by_tfbs = {}
        for cluster, positions in tfbs_clusters.items():
            reordered_positions = [original_indices[row_order_list.index(pos)]
                                   for pos in positions]
            if not reordered_positions:
                continue

            tfbs_start = min(reordered_positions)
            tfbs_end = max(reordered_positions)

            # Find all displayed rows where this TFBS is active
            active_clusters = []
            for cluster_idx in range(len(matrix_data)):
                # Look up revels by the original cluster index if sorting is used
                original_cluster_idx = (self.cluster_order[cluster_idx]
                                        if use_sorted else cluster_idx)
                cluster_entropy = self.revels.iloc[original_cluster_idx,
                                                   tfbs_start:tfbs_end + 1]
                if cluster_entropy.mean() < entropy_threshold:
                    # Store the displayed row position for rectangle drawing
                    active_clusters.append(cluster_idx)

            active_clusters_by_tfbs[cluster] = active_clusters

            # Group consecutive rows: index minus value is constant in a run
            for _, run in itertools.groupby(enumerate(active_clusters),
                                            lambda x: x[0] - x[1]):
                group = [row for _, row in run]
                if not group:
                    continue

                plot_start = tfbs_start
                plot_end = tfbs_end
                if view_window:
                    if tfbs_end < view_window[0] or tfbs_start > view_window[1]:
                        continue  # rectangle outside view window
                    # Clip to the window and shift into window coordinates
                    plot_start = max(tfbs_start, view_window[0]) - view_window[0]
                    plot_end = min(tfbs_end, view_window[1]) - view_window[0]

                main_ax.add_patch(patches.Rectangle(
                    (plot_start, min(group)),
                    plot_end - plot_start + 1,
                    len(group),
                    linewidth=1,
                    edgecolor='black',
                    facecolor='none'
                ))

        # Store active clusters information
        self.active_clusters_by_tfbs = active_clusters_by_tfbs

    # Set square cells if requested
    if square_cells:
        main_ax.set_aspect('equal')

    # Configure axes
    main_ax.set_xlabel('Position', fontsize=8)
    main_ax.set_ylabel('Cluster', fontsize=8)
    main_ax.invert_yaxis()
    self._configure_matrix_ticks(main_ax, n_positions, n_clusters, cluster_order)

    # Relabel x ticks with absolute positions when a window is shown
    if view_window is not None:
        start, end = view_window
        x_labels = [str(int(i + start)) for i in main_ax.get_xticks()]
        main_ax.set_xticklabels(x_labels)

    # Add colorbar
    divider = make_axes_locatable(main_ax)
    cbar_ax = divider.append_axes('right', size='2%', pad=0.05)
    cbar = fig.colorbar(heatmap, cax=cbar_ax, orientation='vertical')
    cbar.ax.set_ylabel(cmap_settings['label'], rotation=270, fontsize=8, labelpad=10)

    fig.tight_layout()

    if save_path:
        # Always write a file when a save path is given. The '_identified'
        # suffix is used for every case that previously produced it, plus
        # any plot on which TFBS rectangles were drawn.
        identified = tfbs_clusters is not None and (column == 'Entropy' or show_tfbs_clusters)
        suffix = '_identified' if identified else ''
        filename = f'msm_{column.lower()}{suffix}.{file_format}'
        plt.savefig(os.path.join(save_path, filename),
                    facecolor='w', dpi=dpi, bbox_inches='tight')
        plt.close()
    else:
        plt.show()

    return fig, main_ax
def _configure_matrix_ticks(self, ax, n_positions, n_clusters, cluster_order): """Configure tick marks and labels for MSM visualization. Parameters ---------- ax : matplotlib.axes.Axes The axes to configure. n_positions : int Number of sequence positions. n_clusters : int Number of clusters. cluster_order : array-like Order of clusters for y-axis labels. """ cluster_order = self.cluster_order if self.cluster_order is not None else np.sort(self.mave['Cluster'].unique()) # Set position (x-axis) ticks x_skip = 10 if n_positions > 100 else 20 if n_positions > 1000 else 1 x_ticks = np.arange(0.5, n_positions, x_skip) x_labels = [str(int(i-0.5)) for i in x_ticks] ax.set_xticks(x_ticks) ax.set_xticklabels(x_labels, rotation=0) # Set cluster (y-axis) ticks y_skip = 10 if n_clusters > 10 else 1 y_ticks = np.arange(0.5, n_clusters, y_skip) # Use sequential labels (0, 1, 2, ...) instead of original cluster numbers y_labels = [str(int(i-0.5)) for i in y_ticks] #y_labels = [str(cluster_order[int(i-0.5)]) for i in y_ticks] ax.set_yticks(y_ticks) ax.set_yticklabels(y_labels, rotation=0) # Set tick parameters ax.tick_params(axis='both', which='major', labelsize=6) def _get_colormap_settings(self, column, delta_entropy=False, matrix_data=None): """Get colormap settings for MSM visualization.""" if column == 'Entropy': if delta_entropy: r = self.mut_rate p = np.array([1-r] + [r/(len(self.alphabet)-1)] * (len(self.alphabet)-1)) bg_entropy = entropy(p, base=2) if matrix_data is not None: return { 'cmap': 'seismic', 'norm': TwoSlopeNorm( vmin=matrix_data.min().min(), vcenter=0, vmax=matrix_data.max().max() ), 'label': 'ΔH (bits)', 'bg_entropy': bg_entropy } return { 'cmap': 'seismic', 'norm': None, 'label': 'ΔH (bits)', 'bg_entropy': bg_entropy } return {'cmap': 'rocket_r', 'norm': None, 'label': 'Entropy (bits)'} return { 'cmap': 'viridis' if column == 'Reference' else 'rocket', #Blues_r for dark blue background 'norm': None, 'label': 'Percent mismatch' if column == 'Reference' else 
'Percent match' }
def generate_logos(self, logo_type='average', background_separation=False,
                   mut_rate=0.01, entropy_multiplier=0.5,
                   adaptive_background_scaling=False, figsize=(20, 2.5),
                   batch_size=50, font_name='sans', stack_order='big_on_top',
                   center_values=True, color_scheme='classic', font_weight=None,
                   fade_below=0.5, shade_below=0.5, width=0.9):
    """Generate sequence or attribution logos for each cluster.

    This method creates visualization logos that represent either the
    average attribution patterns or sequence patterns within each cluster.
    It can optionally remove background signal to highlight
    cluster-specific patterns.

    Parameters
    ----------
    logo_type : {'average', 'pwm', 'enrichment'}, default='average'
        Type of logo to generate:
        - 'average': Shows average attribution values (based on attribution maps)
        - 'pwm': Shows position weight matrix of nucleotide frequencies
          (based on sequence statistics)
        - 'enrichment': Shows nucleotide enrichment relative to background
          (based on sequence statistics)
    background_separation : bool, default=False
        Whether to remove background signal from logos. Only applies to
        'average' logos. When True, subtracts the background pattern
        computed by compute_background(), focused on highly variable
        positions. A background computed by an earlier call is reused.
    mut_rate : float, default=0.01
        Mutation rate for background entropy calculation. Only used if
        background_separation=True.
    entropy_multiplier : float, default=0.5
        Controls stringency of background position identification via a
        multiplier on the background entropy. Only used if
        background_separation=True.
    adaptive_background_scaling : bool, default=False
        If True and background_separation=True, uniformly scales the
        background pattern differently for each cluster based on the
        magnitude of its background signal.
    figsize : tuple, default=(20, 2.5)
        Figure size in inches (width, height).
    batch_size : int, default=50
        Number of logos to process in each batch.
    font_name : str, default='sans'
        Font name for logo text.
    stack_order : {'big_on_top', 'small_on_top', 'fixed'}, default='big_on_top'
        How to order nucleotides in each stack.
    center_values : bool, default=True
        Whether to center values in each position. Only applies to
        'average' logos; it is forced to False for the other types.
    color_scheme : str or dict, default='classic'
        Color scheme for logo characters.
    font_weight : str or int, optional
        Font weight for logo text.
    fade_below : float, default=0.5
        Controls alpha transparency for negative values.
    shade_below : float, default=0.5
        Controls color darkening for negative values.
    width : float, default=0.9
        Controls the horizontal width of each character.

    Returns
    -------
    BatchLogo
        The processed logo batch, also stored as ``self.batch_logos``.

    Raises
    ------
    ValueError
        If ``logo_type`` is not one of the supported values.
    """
    valid_logo_types = ('average', 'pwm', 'enrichment')
    if logo_type not in valid_logo_types:
        raise ValueError(
            f"Unknown logo_type: {logo_type!r}. Expected one of {valid_logo_types}."
        )

    is_average = logo_type == 'average'
    if not is_average:
        center_values = False  # centering only makes sense for attribution maps

    # Get sorted cluster order using class attribute
    cluster_order = self.get_cluster_order(sort_method=self.sort_method)

    # Compute background if needed (cached on the instance once computed)
    # NOTE(review): compute_background() iterates clusters in
    # get_cluster_order()'s default ('median') order, while the scaling
    # factors below are indexed by this method's cluster_order. The two
    # agree only when self.sort_method is 'median' — confirm.
    if background_separation and is_average:
        if not hasattr(self, 'background'):
            self.compute_background(mut_rate, entropy_multiplier, adaptive_background_scaling)

    # For enrichment logos, compute background PFM from all sequences once
    if logo_type == 'enrichment' and not hasattr(self, 'background_pfm'):
        all_seqs = self.mave['Sequence']
        self.background_pfm = motifs.create(all_seqs, alphabet=self.alphabet).counts

    cluster_labels = self.mave['Cluster']
    pseudocounts = 0.5
    cluster_matrices = []
    raw_maps = []  # un-separated averages, kept only when separating background

    for i, k in enumerate(tqdm(cluster_order, desc='Generating matrices')):
        k_idxs = cluster_labels == k

        if is_average:
            maps_avg = np.mean(self.attributions[k_idxs], axis=0)
            if background_separation:
                raw_maps.append(maps_avg.copy())
                # background_scaling is all ones unless adaptive scaling was used
                maps_avg -= self.background_scaling[i] * self.background
            cluster_matrices.append(maps_avg)
            continue

        # Position frequency matrix of the cluster's sequences
        seqs_k = self.mave.loc[k_idxs, 'Sequence']
        pfm = motifs.create(seqs_k, alphabet=self.alphabet).counts

        if logo_type == 'pwm':
            # Convert to PPM and scale each position by its information content
            ppm = pd.DataFrame(pfm.normalize(pseudocounts=pseudocounts))
            background = getattr(
                self, 'background_freqs',
                np.array([1.0 / len(self.alphabet)] * len(self.alphabet)))
            ppm += 1e-6  # Avoid log(0)
            info_content = np.sum(ppm * np.log2(ppm / background), axis=1)
            cluster_matrices.append(np.array(ppm.multiply(info_content, axis=0)))
        else:  # enrichment
            # Log-ratio of cluster counts to background counts
            enrichment = (pd.DataFrame(pfm) + pseudocounts) / \
                         (pd.DataFrame(self.background_pfm) + pseudocounts)
            cluster_matrices.append(np.log2(enrichment))

    # Stack matrices into 3D array (n_clusters, seq_length, n_chars)
    logo_array = np.stack(cluster_matrices)

    # Store both raw and background-separated maps
    y_min_max = None
    if is_average:
        if background_separation:
            self.cluster_maps = np.stack(raw_maps)
            self.cluster_maps_no_bg = logo_array.copy()
        else:
            self.cluster_maps = logo_array.copy()

        # Global y-limits so all attribution logos share one scale. Work on
        # a copy so centering never touches logo_array.
        matrices = logo_array.copy()
        if center_values:
            matrices = matrices - matrices.mean(axis=2, keepdims=True)

        # Stack heights: sum of positive / negative values at each position
        positive_sums = (matrices * (matrices > 0)).sum(axis=2)
        negative_sums = (matrices * (matrices < 0)).sum(axis=2)
        y_min_max = [negative_sums.min(), positive_sums.max()]

    batch_logos = BatchLogo(
        logo_array,
        alphabet=self.alphabet,
        figsize=figsize,
        batch_size=batch_size,
        font_name=font_name,
        stack_order=stack_order,
        center_values=center_values,
        y_min_max=y_min_max,
        color_scheme=color_scheme,
        font_weight=font_weight,
        fade_below=fade_below,
        shade_below=shade_below,
        width=width
    )
    batch_logos.process_all()

    self.batch_logos = batch_logos
    return batch_logos
def show_sequences(self, cluster_idx):
    """Return the sequences and scores of one cluster.

    Parameters
    ----------
    cluster_idx : int
        Cluster to look up. When a cluster ordering is active this is a
        position in that ordering (0 is the first cluster after sorting);
        otherwise it is compared directly against the 'Cluster' column.

    Returns
    -------
    pandas.DataFrame
        The 'Sequence' and 'DNN' columns of the rows in that cluster.
    """
    ordering = self.cluster_order
    # Translate a sorted position into the underlying cluster label
    label = cluster_idx if ordering is None else ordering[cluster_idx]
    in_cluster = self.mave['Cluster'] == label
    return self.mave.loc[in_cluster, ['Sequence', 'DNN']]
def plot_cluster_profiles(self, profiles, save_dir=None, dpi=200, figsize=None,
                          file_format='png'):
    """Plot overlay of profiles associated with each cluster.

    One figure is produced per cluster: every member profile in light
    gray with the cluster mean drawn on top in red.

    Parameters
    ----------
    profiles : np.ndarray
        Array of profile data whose first axis lines up row-for-row with
        the sequences in mave_df.
    save_dir : str, optional
        Directory to save profile plots; created if missing. If None,
        the plots are displayed instead.
    dpi : int
        DPI for saved figures.
    figsize : tuple, optional
        Figure size (width, height) in inches (default: (10, 5)).
    file_format : str, optional
        Format for saved figure (default: 'png').
        Common formats: 'png', 'pdf', 'svg', 'eps'
    """
    # Only touch the filesystem when saving was requested; save_dir=None
    # means "display" and must not be passed to os.path functions.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Select rows with a positional boolean mask so a non-default DataFrame
    # index cannot misalign with the profiles array.
    profiles = np.asarray(profiles)
    cluster_labels = self.mave['Cluster'].to_numpy()

    for k in self.cluster_indices:
        cluster_profiles = profiles[cluster_labels == k]

        plt.figure(figsize=figsize if figsize is not None else (10, 5))
        for profile in cluster_profiles:
            plt.plot(profile, alpha=0.1, color='gray')
        plt.plot(cluster_profiles.mean(axis=0), color='red', linewidth=2)
        plt.title(f'Cluster {k} Profiles')
        plt.xlabel('Position')
        plt.ylabel('Value')

        if save_dir:
            plt.savefig(os.path.join(save_dir, f'cluster_{k}_profiles.{file_format}'),
                        dpi=dpi, bbox_inches='tight')
            plt.close()
        else:
            plt.show()
def compute_background(self, mut_rate=0.01, entropy_multiplier=0.5,
                       adaptive_background_scaling=False, process_logos=True):
    """Compute background signal based on entropic positions.

    For every cluster, positions whose sequence entropy exceeds a
    threshold are treated as background positions, and the cluster's mean
    attribution at those positions is stored as that cluster's background.
    The global background is the mean of the per-cluster backgrounds.

    Parameters
    ----------
    mut_rate : float, default=0.01
        Mutation rate used to calculate the background entropy threshold.
        The threshold is the entropy of a position that keeps its
        reference character with probability ``1 - mut_rate`` and mutates
        uniformly to the other characters otherwise. Higher values raise
        the threshold, so fewer positions are identified as entropic.
    entropy_multiplier : float, default=0.5
        Factor the background entropy is multiplied by to obtain the
        threshold. Lower values lower the threshold, so more positions
        are identified as entropic; higher values are more stringent.
    adaptive_background_scaling : bool, default=False
        If True, computes a scaling factor for each cluster that matches
        the magnitude of the global background to that cluster's
        background. If False, every scaling factor is 1.
    process_logos : bool, default=True
        If True, creates and processes a BatchLogo instance for the
        cluster backgrounds. If False, skips logo processing.

    Returns
    -------
    BatchLogo or None
        The processed background logos, or None when ``process_logos``
        is False.

    Notes
    -----
    Sets the following attributes:

    - ``self.cluster_backgrounds``: array (n_clusters, seq_length, n_chars),
      rows ordered like ``self.get_cluster_order()``
    - ``self.background``: mean of ``cluster_backgrounds`` over clusters
    - ``self.background_scaling``: one scaling factor per cluster
    - ``self.background_logos``: same value as the return value

    The background computation process:

    1. Identifies entropic (highly variable) positions in each cluster
    2. Computes the average attribution pattern at these positions
    3. If adaptive_background_scaling is True, computes a scaling factor
       for each cluster from the positions that carry signal in both that
       cluster's background and the global background
    """
    # Entropy (bits) of a single position under the mutation model. The
    # non-reference probability mass is spread over the remaining
    # characters of the alphabet (three of them for DNA/RNA).
    n_chars = len(self.alphabet)
    null_rate = 1 - mut_rate
    if n_chars > 1:
        alt_rate = (1 - null_rate) / (n_chars - 1)
        background_probs = np.array([null_rate] + [alt_rate] * (n_chars - 1))
    else:
        background_probs = np.array([1.0])
    background_entropy = entropy(background_probs, base=2)
    entropy_threshold = background_entropy * entropy_multiplier

    # Initialize cluster background matrix
    n_clusters = len(self.mave['Cluster'].unique())
    seq_length = self.attributions.shape[1]
    cluster_background = np.zeros(shape=(n_clusters, seq_length, n_chars))

    # Compute background for each cluster
    for idx, k in enumerate(self.get_cluster_order()):
        k_idxs = self.mave['Cluster'] == k
        child_maps = self.attributions[k_idxs]

        entropic_positions = np.asarray(
            self._get_entropic_positions(k, entropy_threshold)
        )
        if len(entropic_positions) == 0:
            continue

        # Mean attribution over the cluster's maps, at entropic positions
        # only; every other position stays zero.
        cluster_background[idx, entropic_positions, :] = (
            child_maps[:, entropic_positions, :].sum(axis=0) / len(child_maps)
        )

    # Store cluster backgrounds and compute global background
    # NOTE(review): __init__ initializes `cluster_background` (singular);
    # this method has always stored `cluster_backgrounds` (plural).
    self.cluster_backgrounds = cluster_background
    self.background = np.mean(cluster_background, axis=0)

    # Scaling factors default to 1 (no rescaling)
    background_scaling = np.ones(len(self.cluster_backgrounds))

    if adaptive_background_scaling:
        # Positions with signal in the global background (loop-invariant)
        global_mask = np.any(self.background != 0, axis=1)
        for i, cluster_bg in enumerate(self.cluster_backgrounds):
            # Positions with signal in this cluster's background
            entropic_mask = np.any(cluster_bg != 0, axis=1)
            valid_positions = entropic_mask & global_mask
            if np.any(valid_positions):
                # Ratio of total magnitudes over the shared positions
                background_scaling[i] = (
                    np.sum(np.abs(cluster_bg[valid_positions]))
                    / np.sum(np.abs(self.background[valid_positions]))
                )

    self.background_scaling = background_scaling

    # Create BatchLogo instance for backgrounds only if requested
    if process_logos:
        self.background_logos = BatchLogo(
            cluster_background,
            alphabet=self.alphabet,
            figsize=[20, 2.5],
            batch_size=50,
        )
        self.background_logos.process_all()
    else:
        self.background_logos = None

    return self.background_logos
def _get_entropic_positions(self, cluster, threshold):
    """Get positions with entropy above threshold for given cluster.

    The per-position character frequencies of the cluster's sequences are
    computed over ``self.alphabet`` and converted to Shannon entropy in
    bits; positions whose entropy is strictly greater than ``threshold``
    are returned.

    Parameters
    ----------
    cluster : int
        Cluster label, as stored in the 'Cluster' column of ``self.mave``.
    threshold : float
        Entropy threshold value (bits).

    Returns
    -------
    np.ndarray
        Array of positions with entropy above threshold. Empty when the
        cluster has no sequences.

    Raises
    ------
    ValueError
        If a sequence contains a character that is not in ``self.alphabet``.
    """
    # Get sequences for this cluster
    k_idxs = self.mave['Cluster'] == cluster
    seqs = self.mave.loc[k_idxs, 'Sequence']

    # A cluster without members has no entropic positions.
    if len(seqs) == 0:
        return np.array([], dtype=np.intp)

    # Constant-time character lookup; setdefault keeps the first
    # occurrence, matching list.index() semantics.
    char_to_idx = {}
    for i, char in enumerate(self.alphabet):
        char_to_idx.setdefault(char, i)

    # Position-specific count matrix; the length is taken from the first
    # sequence of the cluster.
    seq_length = len(seqs.iloc[0])
    char_counts = np.zeros((seq_length, len(self.alphabet)))
    for seq in seqs:
        for pos, char in enumerate(seq):
            char_idx = char_to_idx.get(char)
            if char_idx is None:
                raise ValueError(
                    f"Character {char!r} at position {pos} is not in the "
                    f"alphabet {self.alphabet}"
                )
            char_counts[pos, char_idx] += 1

    # Convert to frequencies
    freqs = char_counts / len(seqs)

    # Shannon entropy per position. Zero frequencies are replaced by 1
    # inside the logarithm so they contribute 0 * log2(1) = 0 instead of
    # triggering log(0).
    safe_freqs = np.where(freqs > 0, freqs, 1.0)
    entropies = -np.sum(freqs * np.log2(safe_freqs), axis=1)

    # Return positions above threshold
    return np.where(entropies > threshold)[0]
def get_cluster_maps(self, cluster_idx):
    """Return the attribution maps of every sequence in one cluster.

    Parameters
    ----------
    cluster_idx : int
        Cluster to look up. When a cluster ordering is stored on the
        instance (``self.cluster_order`` is not None) this is a position
        in that sorted order (0 is the first cluster after sorting);
        otherwise it is used directly as the cluster label.

    Returns
    -------
    numpy.ndarray
        The entries of ``self.attributions`` selected by the boolean mask
        of ``self.mave`` rows whose 'Cluster' equals the resolved label.
    """
    # Translate a sorted position back into the original cluster label.
    order = self.cluster_order
    label = cluster_idx if order is None else order[cluster_idx]

    # Boolean membership mask over the rows of self.mave.
    member_mask = self.mave['Cluster'] == label
    return self.attributions[member_mask]
def plot_attribution_variation(self, scope='all', metric='std', save_path=None, view_window=None,
                               figsize=None, dpi=600, colors=None, xtick_spacing=5,
                               file_format='png'):
    """Visualize the variation in attribution values across attribution maps
    for each nucleotide position.

    Parameters
    ----------
    scope : {'all', 'clusters'}, default='all'
        Scope of variation calculation:
        - 'all': Use all individual attribution maps
        - 'clusters': Use cluster-averaged attribution maps
    metric : {'std', 'var'}, default='std'
        Metric to use for variation calculation:
        - 'std': Standard deviation
        - 'var': Variance
    save_path : str, optional
        Directory to save the figure in. The file is named
        ``attribution_variation_<metric>_<scope>.<file_format>``.
        If None, display instead.
    view_window : list of [start, end], optional
        If provided, crop the x-axis (and the returned array) to this
        window of positions.
    figsize : tuple, optional
        Figure size (width, height) in inches (default: None, uses (20, 1.5)).
    dpi : int, default=600
        DPI for saved figure.
    colors : dict, optional
        Dictionary mapping nucleotide indices to RGB colors.
        Default: {0: [0, .5, 0], 1: [0, 0, 1], 2: [1, .65, 0], 3: [1, 0, 0]}
        for A, C, G, T respectively.
    xtick_spacing : int, default=5
        Show x-axis labels every nth position. Set to 1 to show all positions.
    file_format : str, optional
        Format for saved figure (default: 'png').
        Common formats: 'png', 'pdf', 'svg', 'eps'

    Returns
    -------
    numpy.ndarray
        Array of variation values (std or var) for each position and
        nucleotide, cropped to ``view_window`` when one is given.

    Raises
    ------
    ValueError
        If no attribution maps are available, or if ``scope`` or
        ``metric`` is not one of the accepted values.
    """
    if not hasattr(self, 'attributions') or self.attributions is None:
        raise ValueError("No attribution maps found. Run compute() first.")

    if scope not in ['all', 'clusters']:
        raise ValueError("scope must be one of: 'all', 'clusters'")

    if metric not in ['std', 'var']:
        raise ValueError("metric must be one of: 'std', 'var'")

    # Get appropriate attribution maps based on scope
    if scope == 'all':
        maps_to_analyze = self.attributions
    else:  # clusters
        # get_cluster_maps() expects a position in the sorted order when a
        # cluster ordering exists (and maps it to a label itself), and a
        # raw label otherwise. Passing labels taken from cluster_order
        # would apply that mapping twice.
        if self.cluster_order is not None:
            cluster_keys = range(len(self.cluster_order))
        else:
            cluster_keys = np.sort(self.mave['Cluster'].unique())
        cluster_maps = [
            np.mean(self.get_cluster_maps(k), axis=0) for k in cluster_keys
        ]
        maps_to_analyze = np.stack(cluster_maps, axis=0)

    # Calculate variation metric across maps -> shape (L, A)
    if metric == 'std':
        variation = np.std(maps_to_analyze, axis=0)
    else:  # var
        variation = np.var(maps_to_analyze, axis=0)

    # Set default colors if not provided
    if colors is None:
        colors = {
            0: [0, .5, 0],   # A: green
            1: [0, 0, 1],    # C: blue
            2: [1, .65, 0],  # G: orange
            3: [1, 0, 0]     # T: red
        }

    # Create position indices and apply view window. Labels are derived
    # from the cropped length so they stay aligned with the bars even if
    # the window extends past the end of the sequence.
    if view_window is not None:
        start, end = view_window
        variation = variation[start:end]
        seq_positions = np.arange(start, start + len(variation))
    else:
        seq_positions = np.arange(len(variation))
    plot_positions = np.arange(len(variation))

    # Create plot
    plt.figure(figsize=figsize if figsize is not None else (20, 1.5))

    # Plot bars for each nucleotide, side by side within each position
    bar_width = 0.2
    for i, nuc in enumerate(self.alphabet):
        plt.bar(plot_positions + i * bar_width, variation[:, i],
                width=bar_width, color=colors[i], label=nuc)

    # Label every xtick_spacing-th position, centered under its bar group
    plt.xticks(plot_positions[::xtick_spacing] + 1.5 * bar_width,
               seq_positions[::xtick_spacing],
               rotation=90)

    # Set x-axis limits to show full range with small padding
    if len(plot_positions) > 0:
        plt.xlim(plot_positions[0] - 0.25, plot_positions[-1] + 0.85)

    # Add y-axis label based on metric
    plt.ylabel('Std Dev' if metric == 'std' else 'Variance')

    axes = plt.gca()
    for side in ('left', 'right', 'top'):
        axes.spines[side].set_visible(False)

    plt.tight_layout()

    if save_path:
        out_file = os.path.join(
            save_path,
            f'attribution_variation_{metric}_{scope}.{file_format}',
        )
        plt.savefig(out_file, dpi=dpi, bbox_inches='tight')
        plt.close()
    else:
        plt.show()

    return variation