1. Introduction
Three-dimensional (3D) spatial interaction data, spanning Hi-C contact matrices in genomics [1,2,3] and point clouds in computer vision [4,5,6,7], are characterized by extreme sparsity and high dimensionality. As shown in Figure 1a, a typical single-cell Hi-C contact matrix spans tens of thousands of spatial bins yet exhibits detectable contacts in only a tiny fraction of bin pairs; the corresponding 3D spatial architecture is visualized in Figure 1b.
Technologies that co-assay chromatin conformation with the transcriptome in individual cells, exemplified by HiRES [8] and scHiCAR [9], anchor sparse 3D spatial measurements with orthogonal RNA readouts. This multi-modal co-registration is analogous to how RGB imagery enriches LiDAR point clouds in autonomous driving [10,11], and has motivated deep learning architectures that jointly model 3D spatial interactions and orthogonal molecular readouts from the same cell.
Single-cell multi-modal profiling has motivated a range of deep learning architectures for joint representation learning, including VAE-based integration [12,13,14,15], graph-based methods [16], and attention-driven architectures [17,18,19]. Among these, HiGLUE [20] is the first model to simultaneously combine VAE encoders, graph neural networks (GNNs), transformer attention, and multi-branch encoder–decoder designs for jointly modeling paired single-cell Hi-C and RNA-seq data.
HiGLUE inherits a computational challenge that is structurally analogous to multi-modal 3D perception systems in autonomous driving [10]. In both settings, separate encoder–decoder branches process modalities of vastly different scales: in HiGLUE, the RNA modality consists of several thousand gene expression values while the Hi-C modality exceeds it by over one order of magnitude; in LiDAR-camera fusion, the point cloud branch similarly dominates computation over the camera branch. This asymmetry concentrates the majority of model parameters in the heavier spatial encoder and causes decoder activations to accumulate linearly across decomposition layers, creating a structural memory bottleneck whose severity is predictable directly from the data configuration. Lightweight configurations fit within typical GPU memory, whereas heavyweight configurations inevitably exhaust it, resulting in out-of-memory (OOM) failures regardless of batch size.
Existing distributed training methods apply generic cost models that treat all model components uniformly, overlooking modality-specific memory asymmetries. For heavyweight model instances, they recommend standard data parallelism, which replicates the structural memory bottleneck on every GPU instead of distributing it and therefore does not reduce per-device activation memory. In contrast, we find that partitioning the hidden dimension of the 3D spatial decoder targets this single bottleneck directly: it reduces per-device activation memory across all subsequent strata and keeps cumulative memory below device limits without modifying the encoder or graph embedding stages.
We argue that for domain-specific models with structurally predictable computational heterogeneity, the problem is not searching for an optimal parallel strategy in an infinite space but diagnosing the bottleneck from the model configuration and prescribing a deterministic remedy.
We introduce Automatic Modality-aware Parallelization (AMP), a framework that operates in three phases:
Auto-Profile. Extracts three bottleneck signals (input dimensionality ratio, attention decoder presence, and decomposition depth) from preprocessed data files without executing a forward pass.
Auto-Diagnose. Maps each detected signal independently to its corresponding strategy via a fixed signal-to-strategy mapping: HIGH_INPUT_DIMENSIONALITY activates feature-parallel reconstruction and the restructured decode path; ATTENTION_DECODER activates chunked local attention; HEAVY_DECOMPOSITION activates hidden-dimension tensor parallelism when the predicted peak memory exceeds GPU capacity.
Auto-Execute. Composes the activated strategies into a unified per-stratum computation under the data-parallel infrastructure: feature-parallel reconstruction determines the local feature partition, the restructured decode path applies key-first decoding using this partition, chunked local attention serves as a reusable operator, and hidden-dimension tensor parallelism internally invokes both chunked attention and key-first decoding in its sharded attention block.
Among the five strategies, the key innovation is hidden-dimension tensor parallelism, which partitions the 3D decoder’s hidden dimension across GPUs and transforms five non-standard operators into sharded forms, each proven equivalent through three shared mathematical properties (Section 3.4.5). Every transformation in the framework is mathematically proven to preserve single-GPU numerical results.
We evaluated AMP on the HiRES mouse brain dataset [8]. Lightweight configurations trained successfully on a single GPU; under four-GPU data parallelism, the 500 kb configuration achieved an approximately 1.5× speedup, whereas the 100 kb configuration slowed down because of load imbalance (Section 4.1). Heavyweight configurations, which are untrainable on a single GPU due to OOM, were correctly diagnosed by the framework and completed training successfully. When scaling from four to eight GPUs, the heavyweight 500 kb and 100 kb variants achieved 2.0× and 3.8× training speedups, respectively. A leave-one-out ablation confirmed that chunked attention and hidden-dimension tensor parallelism are each individually necessary for training to succeed.
The remainder of this paper is organized as follows. Section 2 reviews related work in multi-modal 3D biological models and distributed parallelization. Section 3 presents the AMP framework, detailing the Auto-Profile, Auto-Diagnose, and Auto-Execute phases along with the five strategies. Section 4 reports experimental results. Section 5 concludes with limitations and future directions.
3. Methods
This section presents the proposed framework, which operates through three deterministic phases: Auto-Profile (Section 3.2) extracts bottleneck signals from the data configuration, Auto-Diagnose (Section 3.3) maps them to a strategy prescription via a fixed signal-to-strategy mapping, and Auto-Execute (Section 3.5) implements the prescribed training pipeline. These three phases are organized into two functional layers: a lightweight routing layer (Section 3.2 and Section 3.3) and a strategy library (Section 3.4) containing five mathematically equivalent parallel strategies.
3.1. Problem Formulation and System Overview
We formulate the parallelization problem for multi-modal 3D biological models as follows. A model $M$ consists of a set of modalities $\mathcal{K}$, where each modality $k \in \mathcal{K}$ has an encoder $\mathrm{Enc}_k$ and decoder $\mathrm{Dec}_k$, and a shared guidance graph encoder $\mathrm{Enc}_G$ that processes a prior knowledge graph. Given $N$ GPUs with per-device available memory $M_{\text{gpu}}$, the goal is to determine a parallel strategy configuration $\pi \in \{0,1\}^5$, indicating which of the five parallel strategies are activated, such that the peak per-GPU memory consumption satisfies
$$\max_{1 \le i \le N} \mathrm{Mem}(M, \pi, \mathrm{GPU}_i) \le M_{\text{gpu}},$$
where $\mathrm{GPU}_i$ denotes the $i$-th GPU. Here, $\mathrm{Mem}(M, \pi, \mathrm{GPU}_i)$ is the peak GPU memory consumption on device $\mathrm{GPU}_i$ when model $M$ is executed under strategy configuration $\pi$. In practice this value is predicted analytically by the EPM estimator (Equation (2)) during the profiling phase without executing a forward pass. The secondary objective is to minimize the total training time $T(M, \pi)$. An additional constraint is that every parallel transformation must be mathematically equivalent to the single-GPU forward and backward computations.
Figure 2 illustrates the three phases of the proposed framework. The framework takes only a preprocessed data directory and the number of available GPUs as user input and proceeds through three deterministic phases: Auto-Profile, Auto-Diagnose, and Auto-Execute. The strategy configuration is determined by profiling rather than manual specification or cost-model search. The framework is designed around five principles:
Zero Intrusion. The original single-GPU training code is preserved unmodified.
Diagnosis-Driven. Strategy activation is determined by data-derived bottleneck signals rather than user configuration.
Mathematical Equivalence. Every parallel transformation is proven to preserve the single-GPU numerical results.
Determinism. Identical inputs always produce identical strategy configurations.
Extensibility. New bottleneck signals and corresponding strategies can be registered using a standard interface.
Architecturally, these three phases are organized into two functional layers. The upper layer (Section 3.2 and Section 3.3) provides a lightweight bottleneck-aware routing mechanism: Auto-Profile extracts structural signals from the data configuration, and Auto-Diagnose maps them to a strategy prescription using a fixed signal-to-strategy mapping. This layer is intentionally simple because, for multi-modal 3D biological models, the bottleneck is structurally knowable from a small set of data-derived signals. The lower layer (Section 3.4) constitutes the strategy library, which includes five parallel strategies, each with a formal proof of mathematical equivalence. This is where the principal technical depth of this study resides. The hidden-dimension tensor-parallel strategy in particular involves the complete tensor-parallel transformation of five core operators in the Hi-C decoder, each proven equivalent through three shared mathematical properties (Section 3.4.5).
3.2. Auto-Profile: Bottleneck Signature Extraction
In the first phase, modality signatures are automatically extracted from the pre-processed data files without executing a single forward pass. Given a data directory, the profiler scans the modality-specific subdirectories (RNA expression, Hi-C contact matrices, and optionally ATAC-seq or methylation data) and reads the associated structured data files to extract properties that are predictive of the downstream memory bottlenecks. This profiling procedure is deliberately lightweight. The key architectural claim of this study is that a small set of structurally encoded signals suffices to diagnose the bottleneck for this class of models. This eliminates the need for an elaborate cost-model search employed by general-purpose automatic parallelism frameworks.
For each modality, the profiler records: (1) the probability model type (e.g., Negative Binomial (NB) for RNA, Hi-C Zero-Inflated Negative Binomial (HiCZINB) for Hi-C, and Zero-Inflated Log-Normal (ZILN) for methylation), which determines the decoder architecture; (2) the input dimensionality and the number of cells; (3) for the 3D spatial modality specifically, whether the decoder contains attention mechanisms and, when applicable, the number of distance-dependent decomposition units (such as the number of strata in Hi-C’s stratified decoder, inferred from file naming conventions or feature name suffixes); and (4) the relative scale of computation across modalities, obtained by a lightweight estimation that accounts for the interaction between input size, decoder complexity, and training batch size. The guidance graph is profiled separately to record its node and edge counts.
The extracted measurements are mapped to a set of three bottleneck signals (Table 3). A modality is flagged with the HIGH_INPUT_DIMENSIONALITY signal when its input features exceed a domain-specific threshold relative to other modalities, indicating that its encoder projection and decoder reconstruction pathways dominate memory consumption. The ATTENTION_DECODER signal is raised when the modality’s decoder employs attention mechanisms, which produce large intermediate tensors (with trailing dimension D, the hidden dimension) during the query-key computation. The HEAVY_DECOMPOSITION signal (instantiated here for the Hi-C decoder) is raised when the decoder employs a multi-layer decomposed structure (such as the stratified attention layers in Hi-C’s decoder), indicating that multiple attention and feedforward blocks are applied sequentially to the same input, which causes cumulative activation retention.
The three signals are designed to be orthogonal in their detection logic rather than in the architectural features they target. ATTENTION_DECODER detects the presence of attention as a mechanism, and HEAVY_DECOMPOSITION detects the structural repetition of decoder blocks regardless of their internal operation. A model with both properties will activate both chunked attention and hidden-dimension sharding, and the functional dependencies ensure that the sharded attention and the chunked attention are composed correctly. The EPM estimator (Equation (2)) then determines how many such layers the available GPU memory can accommodate before it is exhausted.
The threshold for HIGH_INPUT_DIMENSIONALITY is domain-specific and is assessed relative to the next-largest modality. In the HiRES dataset, the Hi-C feature count exceeds the RNA features by over one order of magnitude, which serves as the practical detection criterion.
The key insight is that these signals are structurally encoded in the model’s data configuration; they can be detected by reading data files and model metadata without running the model or measuring the actual GPU memory. Because these signals derive from preprocessing parameters (bin size, decomposition depth, and decoder architecture) rather than from the biological content of the data, the diagnosis does not depend on which tissue or species the data originate from. This makes the profiling phase zero-cost and immediately informative before the training begins.
The profiler also estimates the peak GPU memory using a linear budget model. Let $S$ denote the number of decomposition units (strata) in the Hi-C decoder. Let $M_{\text{base}}$ denote the memory consumption of the encoder and graph embedding stages, and let $m_{\text{unit}}$ denote the marginal memory cost per decomposition unit, which is dominated by the per-unit attention block activations, decoder intermediate tensors, and softmax normalization. The estimated peak memory for $S$ decomposition units is
$$\widehat{M}_{\text{peak}}(S) = M_{\text{base}} + S \cdot m_{\text{unit}} \tag{2}$$
Both $M_{\text{base}}$ and $m_{\text{unit}}$ are estimated analytically from tensor shapes without executing a forward pass: $M_{\text{base}}$ is computed from the encoder output dimensions and graph embedding parameters, whereas $m_{\text{unit}}$ is derived from the intermediate tensor dimensions of the attention block per stratum.
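The linear budget model can be illustrated with a minimal Python sketch. The class name `LinearMemoryBudget`, its field names, and the numeric values are hypothetical and chosen only for illustration; they are not the profiler's actual interface or measured values.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LinearMemoryBudget:
    """Linear peak-memory model: a base cost plus a fixed cost per stratum."""

    base_gb: float      # encoder + graph-embedding stages (hypothetical value)
    per_unit_gb: float  # marginal cost of one decomposition unit (hypothetical)

    def peak(self, num_units: int) -> float:
        """Estimated peak memory in GB for `num_units` decomposition units."""
        return self.base_gb + num_units * self.per_unit_gb

    def max_units(self, capacity_gb: float) -> int:
        """Largest number of units whose estimated peak fits in `capacity_gb`."""
        if self.per_unit_gb <= 0:
            raise ValueError("per_unit_gb must be positive")
        headroom = capacity_gb - self.base_gb
        return max(0, int(headroom // self.per_unit_gb))


# Illustrative numbers only; they are not measurements from the experiments.
budget = LinearMemoryBudget(base_gb=4.0, per_unit_gb=1.0)
for units in (10, 32):
    fits = budget.peak(units) <= 24.0
    print(f"{units} units -> {budget.peak(units):.1f} GB, fits in 24 GB: {fits}")
```

Under these made-up coefficients a 10-stratum configuration fits on a 24 GB device and a 32-stratum configuration does not, which mirrors the qualitative lightweight versus heavyweight split described in the text.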
3.3. Auto-Diagnose: Bottleneck-to-Strategy Mapping
The second phase maps the extracted bottleneck signals to a concrete parallel strategy prescription via a fixed signal-to-strategy mapping (Figure 3). The decision procedure is intentionally straightforward; its role is to provide a reproducible, auditable mapping from structural signals to strategy activation, not to contribute algorithmic complexity. The three signal nodes in Figure 3 (HIGH_INPUT, ATTN_DECODER, HEAVY_DECOMP) abbreviate the bottleneck signals defined in Table 3, and each independently activates its corresponding strategy from Table 4.
Signal-to-strategy mapping. If only one GPU is available, the framework prescribes single-GPU execution without parallelism. If the 3D spatial modality is absent (RNA-only or RNA + ATAC configurations), standard data parallelism (Section 3.4.1) is sufficient because the bottleneck signals that require deeper parallelization are exclusively associated with the 3D spatial modality. When signals are present, the mapping evaluates each detected signal independently as follows:
If HIGH_INPUT_DIMENSIONALITY is raised, feature-parallel reconstruction and the restructured decode path are activated: the former shards the reconstruction output across GPUs, and the latter restructures the decoder-side matrix computation using the same feature partition.
If ATTENTION_DECODER is raised, chunked attention is activated, replacing the full attention computation with sequential chunked evaluation.
If HEAVY_DECOMPOSITION is raised and the estimated peak memory exceeds the available per-device GPU memory, hidden-dimension sharding is activated, partitioning the hidden dimension across GPUs.
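The rules above amount to a small deterministic function. The sketch below is one possible reading of them; the function name `diagnose`, its parameters, and the strategy labels are hypothetical stand-ins for the five strategies of Section 3.4, not the framework's actual identifiers.

```python
from typing import Iterable, List

HIGH_INPUT_DIMENSIONALITY = "HIGH_INPUT_DIMENSIONALITY"
ATTENTION_DECODER = "ATTENTION_DECODER"
HEAVY_DECOMPOSITION = "HEAVY_DECOMPOSITION"


def diagnose(signals: Iterable[str], num_gpus: int, has_spatial_modality: bool,
             estimated_peak_gb: float, capacity_gb: float) -> List[str]:
    """Return the activated strategies (hypothetical labels) in a fixed order."""
    if num_gpus <= 1:
        return []  # single-GPU execution, no parallelism
    active = ["data_parallel"]  # always enabled with multiple GPUs
    if not has_spatial_modality:
        return active  # RNA-only or RNA + ATAC: data parallelism suffices
    raised = set(signals)
    if HIGH_INPUT_DIMENSIONALITY in raised:
        active += ["feature_parallel", "restructured_decode"]
    if ATTENTION_DECODER in raised:
        active.append("chunked_attention")
    if HEAVY_DECOMPOSITION in raised and estimated_peak_gb > capacity_gb:
        active.append("hidden_dim_sharding")
    return active


all_signals = [HIGH_INPUT_DIMENSIONALITY, ATTENTION_DECODER, HEAVY_DECOMPOSITION]
print(diagnose(all_signals, num_gpus=8, has_spatial_modality=True,
               estimated_peak_gb=36.0, capacity_gb=24.0))
```

Because the function has no state and no search, identical inputs always yield the same prescription, which is the determinism property the text relies on.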
These three signals characterize the structural source of modality asymmetry: high input dimensionality creates the encoder-side weight dominance, an attention decoder creates large quadratic or windowed temporary tensors, and deep decomposition causes cumulative activation retention. Together they describe the structural pattern that makes one modality dominate GPU memory.
While the current instantiation of chunked attention and the tensor-parallel attention in the hidden-dimension strategy target the local (sliding-window) attention pattern of HiGLUE, the ATTENTION_DECODER signal detects the architectural presence of attention in the decoder rather than its specific type. Adapting the framework to full self-attention decoders would require integrating alternative parallelization strategies (e.g., Flash Attention) into the strategy library, a direction we defer to future work. The profiling and diagnosis layers remain applicable because their detection logic depends on whether attention exists, not on the particular attention variant employed.
The data-parallel strategy serves as the underlying multi-process infrastructure and is always enabled when multiple GPUs are available. This design yields a compositional strategy prescription: each signal independently activates its corresponding strategy, and the functional dependencies (Section 3.4.6) ensure the correct composition, regardless of which subset is activated.
Unlike existing automatic parallelism frameworks that search over a continuous strategy space, the proposed method is fully determined by the detected bottleneck signals. The mapping from signal to strategy is fixed and documented (Table 4), making the prescription reproducible and auditable. The strategies are composed according to functional dependencies: the feature-parallel strategy must be activated before the restructured decode (because the latter needs to know which features belong to the local rank); the chunked attention is an independent component; and the hidden-dimension sharding depends on the chunked attention routine as a building block within its sharded attention block.
3.4. Strategy Library: Five Mathematically Equivalent Parallel Strategies
The strategy library consists of five parallel strategies, each targeting a specific bottleneck signal and provably equivalent to single-GPU computation. All five strategies were implemented as non-intrusive extensions to the original training workflow, preserving the original model code base unmodified and allowing the single-GPU baseline to remain available for comparison.
3.4.1. Lightweight Data Parallelism
This strategy implements conventional multi-process data parallelism. Each GPU holds a complete model replica and processes a distinct subset of the training cells. After each backward pass, the gradients are manually synchronized across all ranks using an all-reduce sum, followed by division by the world size. The full guidance graph (stored as edge index, edge weight, and edge sign tensors) is replicated on every GPU and processed identically by the graph encoder forward pass on each rank. Although this is computationally redundant, the graph encoder is lightweight relative to the Hi-C decoder, and the redundancy avoids additional communication overhead for partial graph synchronization. Training metric logging, learning rate scheduling, and early stopping are disabled to ensure that all ranks follow an identical control flow.
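Let $g_i = \nabla_{\theta}\,\mathcal{L}(\theta; \mathcal{D}_i)$ denote the gradient of parameter $\theta$ computed on rank $i$’s local data shard $\mathcal{D}_i$. The synchronized gradient is as follows:
$$\bar{g} = \frac{1}{N} \sum_{i=1}^{N} g_i$$
The parameter update follows standard SGD with learning rate $\eta$:
$$\theta \leftarrow \theta - \eta\, \bar{g}$$
Because the loss decomposes additively over independent data samples (for equal-sized shards and a mean-reduced loss):
$$\mathcal{L}(\theta; \mathcal{D}) = \frac{1}{N} \sum_{i=1}^{N} \mathcal{L}(\theta; \mathcal{D}_i)$$
with $\mathcal{D} = \bigcup_{i=1}^{N} \mathcal{D}_i$, $\bar{g}$ equals the gradient that would be obtained from the full batch, ensuring equivalence to single-GPU training with $N$ times the batch size.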
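A small numerical check makes the equivalence concrete. The sketch below is not the training code: it uses a one-parameter least-squares model, plain Python lists, and an ordinary sum in place of the all-reduce, and it assumes equal-sized shards with a mean-reduced loss. All names are illustrative.

```python
import math


def mean_grad(w, xs, ys):
    """Gradient of the mean squared error of y = w * x with respect to w."""
    return sum(2.0 * x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(values):
    """Stand-in for an all-reduce sum followed by division by the world size."""
    return sum(values) / len(values)


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 3.9, 6.1, 8.2, 9.8, 12.1, 13.9, 16.0]
w = 0.5
world_size = 4
shard = len(xs) // world_size  # equal-sized shards (assumption)

# Each "rank" computes the gradient on its own shard only.
local_grads = [
    mean_grad(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
    for r in range(world_size)
]
synced = all_reduce_mean(local_grads)
full_batch = mean_grad(w, xs, ys)
print(math.isclose(synced, full_batch, rel_tol=1e-12))
```

If the shards had different sizes, the plain average would weight samples unevenly, which is why the equal-shard assumption is stated explicitly.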
3.4.2. Feature-Parallel Hi-C Reconstruction
Target signal. HIGH_INPUT_DIMENSIONALITY.
Unlike conventional data parallelism, where different GPUs process different cell batches, this strategy employs same-batch multi-rank execution: all ranks receive the identical cell batch B and graph batch. The Hi-C feature dimension is evenly partitioned into N contiguous shards, with rank i computing the reconstruction loss only for its local feature shard.
Algorithm 1 summarizes the per-cell local reconstruction loss computation under feature-parallel sharding.
Algorithm 1 Local Hi-C Reconstruction Loss for a Single Cell
Input: cell batch B, node embeddings , Hi-C raw data , , decoder parameters
Output: local reconstruction loss
1: rank i feature partition bounds in
2: for stratum to do
3:   Compute mask:
4:   if then
5:     continue
6:   end if
7:   Extract local features:
8:   if then
9:     ▹ no attention at
10:   else if activated and activated then
11:     ▹ Algorithm 2
12:   else
13:
14:   end if
15:   Select: , ▹ F selected rows
16:   Decode: ▹ see for detail
17:   Compute local loss:
18: end for
19: return mean of all
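The Hi-C reconstruction loss decomposes additively over the features as follows:
$$\mathcal{L}_{\text{HiC}} = \sum_{i=1}^{N} \mathcal{L}_{\text{HiC}}^{(i)},$$
where $\mathcal{L}_{\text{HiC}}^{(i)}$ is the loss over the feature shard owned by rank $i$. Because gradients are accumulated across all ranks via all-reduce, the total gradient is
$$\nabla_{\theta}\, \mathcal{L}_{\text{HiC}} = \sum_{i=1}^{N} \nabla_{\theta}\, \mathcal{L}_{\text{HiC}}^{(i)},$$
which is exactly the gradient that would be obtained from the full-feature reconstruction on a single GPU.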
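The feature-wise decomposition can be checked on a toy problem. In the sketch below the reconstruction loss is a plain squared error with a single shared parameter, which is a deliberate simplification of the Hi-C likelihood; `shard_bounds` and `local_loss_and_grad` are hypothetical helper names, and the sum over ranks stands in for the all-reduce.

```python
import math


def shard_bounds(num_features, world_size, rank):
    """Contiguous [start, stop) feature range owned by `rank` (even split)."""
    if num_features % world_size:
        raise ValueError("world_size must divide num_features in this sketch")
    size = num_features // world_size
    return rank * size, (rank + 1) * size


def local_loss_and_grad(w, inputs, targets, start, stop):
    """Squared-error loss and its gradient w.r.t. w over features [start, stop)."""
    loss = sum((w * inputs[f] - targets[f]) ** 2 for f in range(start, stop))
    grad = sum(2.0 * inputs[f] * (w * inputs[f] - targets[f]) for f in range(start, stop))
    return loss, grad


inputs = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]
targets = [1.0, 2.1, 2.9, 4.2, 5.0, 5.9, 7.1, 8.0]
w, world_size = 1.5, 4

# Every rank sees the same batch but only its own feature shard.
per_rank = [local_loss_and_grad(w, inputs, targets, *shard_bounds(len(inputs), world_size, r))
            for r in range(world_size)]
total_loss = sum(loss for loss, _ in per_rank)   # all-reduce sum of losses
total_grad = sum(grad for _, grad in per_rank)   # all-reduce sum of gradients
full_loss, full_grad = local_loss_and_grad(w, inputs, targets, 0, len(inputs))
print(math.isclose(total_grad, full_grad, rel_tol=1e-12))
```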
3.4.3. Chunked Local Attention
Target signal. ATTENTION_DECODER.
The Hi-C decoder uses local self-attention layers, where each graph node attends to nodes within a spatial window of size $w$. The standard implementation processes all query positions simultaneously, materializing a temporary tensor that covers every query position at once. Chunked local attention decomposes this into sequential chunks of $C$ query positions:
$$\mathrm{Attn}(Q, K, V) = \bigoplus_{c} \mathrm{Attn}\big(Q_{[cC,\,(c+1)C)}, K, V\big)$$
where ⊕ denotes the sequential concatenation. For each chunk, the attention operation internally selects keys and values within the window region, so the peak size of the temporary tensor scales with the chunk size $C$ instead of the total number of query positions. The chunk size $C$ is configurable.
Because attention at position j depends only on keys and values within the window around j, and softmax normalization is performed independently per query position, the chunked evaluation yields results identical to the full computation. There is no cross-chunk dependency in the computation of the above equation. This design is distinct from Flash Attention, which tiles general dense attention operations with online softmax rescaling. Chunked local attention exploits the structural locality of the window-limited attention of the decoder: each chunk’s queries interact only with their own window neighborhood; therefore, no inter-chunk softmax rescaling is required. Chunking is a direct consequence of the local attention pattern identified by the ATTENTION_DECODER signal, rather than a general memory-saving technique applicable to arbitrary attention patterns.
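The absence of cross-chunk dependencies can be verified with a small pure-Python sketch. It assumes a symmetric window of `window` positions on each side of the query and scaled dot-product scores; these choices, and all function names, are illustrative assumptions and may differ from the decoder's actual attention layer. The sketch demonstrates numerical equivalence only, not the memory saving.

```python
import math


def _softmax(scores):
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def _attend(queries, first_pos, keys, values, window, key_offset):
    """Window attention for queries at global positions first_pos, first_pos+1, ...

    `keys`/`values` hold the rows for global positions key_offset, key_offset+1, ...
    """
    scale = math.sqrt(len(queries[0]))
    out = []
    for j, q in enumerate(queries, start=first_pos):
        lo = max(j - window, key_offset) - key_offset
        hi = min(j + window, key_offset + len(keys) - 1) - key_offset
        idx = range(lo, hi + 1)
        weights = _softmax([sum(a * b for a, b in zip(q, keys[t])) / scale for t in idx])
        out.append([sum(wt * values[t][c] for wt, t in zip(weights, idx))
                    for c in range(len(values[0]))])
    return out


def full_local_attention(q, k, v, window):
    return _attend(q, 0, k, v, window, 0)


def chunked_local_attention(q, k, v, window, chunk):
    n, out = len(q), []
    for start in range(0, n, chunk):
        stop = min(start + chunk, n)
        k_lo = max(start - window, 0)         # first key visible to this chunk
        k_hi = min(stop - 1 + window, n - 1)  # last key visible to this chunk
        out.extend(_attend(q[start:stop], start, k[k_lo:k_hi + 1],
                           v[k_lo:k_hi + 1], window, k_lo))
    return out


n, d = 10, 4
q = [[math.sin(i + c) for c in range(d)] for i in range(n)]
k = [[math.cos(i * c + 1) for c in range(d)] for i in range(n)]
v = [[float(i + c) for c in range(d)] for i in range(n)]
reference = full_local_attention(q, k, v, window=2)
chunked = chunked_local_attention(q, k, v, window=2, chunk=3)
```

Each chunk only ever touches the keys and values inside the union of its queries' windows, so no softmax statistics need to be carried from one chunk to the next.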
3.4.4. Restructured Decode Path
Target signal. HIGH_INPUT_DIMENSIONALITY (intermediate decode matrix).
The original Hi-C decoder computes a full intermediate matrix, written here as $QK^{\top}$, before selecting the subset of features needed for the current stratum via column indexing. When the number of key rows is large, as in chromosome-level processing, this intermediate matrix dominates the memory. The restructured decode path applies the matrix algebra identity:
$$\big(QK^{\top}\big)_{:,\mathcal{F}} = Q\,\big(K_{\mathcal{F},:}\big)^{\top}$$
where $\mathcal{F}$ is the feature index set for the current stratum, with $|\mathcal{F}| = F$. The left-hand side materializes one column for every key row; the right-hand side selects the $F$ key rows first, producing a result with only $F$ columns.
The same transformation is applied to the fully decoded path required for the per-stratum softmax normalization. Since $F$ is much smaller than the total number of key rows per stratum, the intermediate shrinks from one column per key row to $F$ columns.
This is a direct application of the matrix algebra identity:
$$(AB)_{:,I} = A\,B_{:,I}$$
where $I$ is a set of column indices.
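The identity is easy to confirm on small integer matrices. In the sketch below, `queries` and `keys` are generic stand-ins for the two operands of the decode product, and both function names are hypothetical; the point is only that selecting rows of the second operand before multiplying gives the same numbers as multiplying first and indexing columns afterwards.

```python
def matmul_transposed(a, b):
    """Return a @ b.T for matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row_a, row_b)) for row_b in b] for row_a in a]


def decode_then_select(queries, keys, feature_idx):
    """Materialize every column, then keep only the requested ones."""
    full = matmul_transposed(queries, keys)        # one column per key row
    return [[row[j] for j in feature_idx] for row in full]


def select_then_decode(queries, keys, feature_idx):
    """Pick the key rows first, so only len(feature_idx) columns are computed."""
    selected = [keys[j] for j in feature_idx]
    return matmul_transposed(queries, selected)


queries = [[1, 2], [3, 4], [5, 6]]
keys = [[1, 0], [0, 1], [2, 2], [-1, 3], [4, -2]]
feature_idx = [4, 1, 2]

print(decode_then_select(queries, keys, feature_idx))
print(select_then_decode(queries, keys, feature_idx))
```

The first function allocates a 3 x 5 intermediate, the second only a 3 x 3 result; with thousands of key rows and a small per-stratum feature set the gap grows accordingly.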
3.4.5. Hidden-Dimension Tensor Parallel Decoder
Target signal. HEAVY_DECOMPOSITION (multi-layer decomposed decoder).
This strategy is the most important one in the library. It partitions the hidden dimension $D$ of the Hi-C decoder evenly across $N$ ranks and transforms every hidden-dimension-dependent operator into a sharded yet mathematically equivalent form. Let the hidden dimension be partitioned as $D = \sum_{i=1}^{N} D_i$ with $D_i = D/N$, with rank $i$ owning channels $[(i-1)D/N,\; iD/N)$. A vector $x \in \mathbb{R}^{D}$ is decomposed as
$$x = \big[x^{(1)};\, x^{(2)};\, \dots;\, x^{(N)}\big], \qquad x^{(i)} \in \mathbb{R}^{D/N}.$$
The core algebraic principle that makes this strategy possible is the additive decomposability of the dot product over the hidden dimension:
$$q^{\top} k = \sum_{i=1}^{N} \big(q^{(i)}\big)^{\top} k^{(i)} \tag{13}$$
This identity holds because the query and key vectors are concatenated along the hidden dimension across ranks, and the full dot product is exactly the sum of the partial dot products over the concatenated sub-vectors. Every operator transformation in Table 5 is derived from this decomposition combined with the standard properties of all-reduce and layer normalization. We now detail each of these transformations.
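LayerNorm sharding. Rank $i$ computes the local partial sums over its hidden shard $\mathcal{C}_i$ (the set of channels it owns) as follows:
$$s_i^{(1)} = \sum_{c \in \mathcal{C}_i} x_c, \qquad s_i^{(2)} = \sum_{c \in \mathcal{C}_i} x_c^{2} \tag{14}$$
After two all-reduce operations, each rank reconstructs the global statistics as follows:
$$\mu = \frac{1}{D} \sum_{i=1}^{N} s_i^{(1)}, \qquad \sigma^{2} = \frac{1}{D} \sum_{i=1}^{N} s_i^{(2)} - \mu^{2} \tag{15}$$
The sharded output
$$y^{(i)} = \gamma^{(i)} \odot \frac{x^{(i)} - \mu}{\sqrt{\sigma^{2} + \epsilon}} + \beta^{(i)}$$
is identical to the corresponding slice of the unsharded LayerNorm output.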
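The following sketch checks this claim numerically. It assumes that the two all-reduced quantities are the sum and the sum of squares over the hidden dimension, so that the variance is recovered as the mean of squares minus the squared mean; this is one common way to realize the description above and may differ in detail from the actual implementation. Ordinary Python sums stand in for the all-reduce calls, and all names are illustrative.

```python
import math


def layer_norm(x, gamma, beta, eps=1e-5):
    """Reference LayerNorm over the full hidden vector."""
    d = len(x)
    mu = sum(x) / d
    var = sum((v - mu) ** 2 for v in x) / d
    return [g * (v - mu) / math.sqrt(var + eps) + b for v, g, b in zip(x, gamma, beta)]


def sharded_layer_norm(x, gamma, beta, world_size, eps=1e-5):
    """Per-rank LayerNorm outputs; plain sums stand in for the two all-reduces."""
    d = len(x)
    if d % world_size:
        raise ValueError("hidden dimension must divide evenly across ranks")
    width = d // world_size
    shards = [slice(r * width, (r + 1) * width) for r in range(world_size)]
    sum_x = sum(sum(x[s]) for s in shards)                   # all-reduce 1
    sum_x2 = sum(sum(v * v for v in x[s]) for s in shards)   # all-reduce 2
    mu = sum_x / d
    var = sum_x2 / d - mu * mu
    inv_std = 1.0 / math.sqrt(var + eps)
    # Each rank normalizes only its own channels with its own gamma/beta slice.
    return [[g * (v - mu) * inv_std + b for v, g, b in zip(x[s], gamma[s], beta[s])]
            for s in shards]


x = [0.5, -1.25, 2.0, 3.5, -0.75, 1.0, 0.25, -2.0]
gamma = [1.0, 0.5, 2.0, 1.5, 1.0, 0.25, 0.75, 1.25]
beta = [0.0, 0.1, -0.2, 0.3, 0.0, 0.5, -0.5, 0.2]

unsharded = layer_norm(x, gamma, beta)
per_rank = sharded_layer_norm(x, gamma, beta, world_size=4)
stitched = [v for shard in per_rank for v in shard]
```

In floating point the two variance formulas can differ in the last few bits, so the comparison should use a tolerance instead of exact equality.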
Linear layer sharding. For a weight matrix , rank i holds only the columns corresponding to its hidden shard. The computation is , then , and each rank retains .
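Sharded decode. The decomposition of Equation (13) is directly applied. Each rank computes its partial contribution from its own hidden shard as follows:
$$P_i = Q^{(i)} \big(K^{(i)}\big)^{\top}$$
and a single all-reduce sum recovers the full result, $\sum_{i=1}^{N} P_i = QK^{\top}$. This decomposes the decode matrix multiplication over hidden dimension $D$ into $N$ parallel operations, each over hidden dimension $D/N$.
Sharded local attention. Rank $i$ computes local scores:
$$A_i = \frac{Q^{(i)} \big(K^{(i)}\big)^{\top}}{\sqrt{D}} \tag{18}$$
Note that the normalization is by the global dimension $D$ rather than the local dimension $D/N$. After all-reduce, $A = \sum_{i=1}^{N} A_i$, then each rank computes $\mathrm{softmax}(A)\, V^{(i)}$, which is equivalent to the corresponding slice of the unsharded computation.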
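The role of the global scaling factor can be seen on a single query-key pair. The sketch below is illustrative only: `sharded_score` is a hypothetical helper, the vectors are arbitrary, and a plain sum replaces the all-reduce. It compares scaling each partial score by the square root of the global dimension against (incorrectly) scaling by the square root of the local shard width.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def sharded_score(q, k, world_size, scale_dim):
    """Sum of per-rank partial scores, each divided by sqrt(scale_dim)."""
    width = len(q) // world_size
    partials = [
        dot(q[r * width:(r + 1) * width], k[r * width:(r + 1) * width]) / math.sqrt(scale_dim)
        for r in range(world_size)
    ]
    return sum(partials)  # stands in for the all-reduce sum


q = [0.5, -1.0, 2.0, 0.25, 1.5, -0.5, 0.75, 1.0]
k = [1.0, 0.5, -0.25, 2.0, 0.5, 1.0, -1.0, 0.25]
D, N = len(q), 4

unsharded = dot(q, k) / math.sqrt(D)
global_scaled = sharded_score(q, k, N, scale_dim=D)      # matches the unsharded score
local_scaled = sharded_score(q, k, N, scale_dim=D // N)  # off by a factor of sqrt(N)
print(unsharded, global_scaled, local_scaled)
```

Scaling by the local width inflates every score by the square root of the number of ranks, which would sharpen the softmax and break equivalence with the single-GPU result.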
Algorithm 2 integrates all five sharded operations into a complete per-stratum decoder attention block, illustrating the end-to-end tensor-parallel computation flow on each rank.
| Algorithm 2 Sharded Decoder Attention Block (per rank i, per stratum ) |
Input: , , decoder parameters for rank i
Output: , updated
- 1:
▹ Extract hidden shard - 2:
- 3:
- 4:
- 5:
▹ each - 6:
- 7:
fortodo - 8:
, - 9:
▹ per-chunk computation, window-masked as in - 10:
end for - 11:
- 12:
- 13:
- 14:
- 15:
- 16:
- 17:
return ,
|
All five operations in Table 5 satisfy the same identity: each rank’s sharded output exactly equals the corresponding slice of the unsharded single-GPU computation. The proof proceeds in three steps.
Dot product additivity. For any query vector and key vector, the decomposition in Equation (13) follows from the definition of the dot product as a sum over the hidden dimension. This identity is the algebraic foundation for sharded decoding, sharded local attention scores, and sharded GEGLU FFN partial computations.
LayerNorm invariance under sharding. The global mean and variance are reconstructed from the per-rank partial sums via Equations (14) and (15). Since the normalization is element-wise once the mean and variance are determined, applying it to the local shard with the local scale and shift parameters produces results identical to applying the full LayerNorm and then slicing.
Softmax equivalence under global normalization. The sharded attention computation is normalized by the global dimension D as shown in Equation (18). After all-reduce, the sum of the per-rank partial scores is exactly the unsharded pre-softmax attention score matrix. Applying softmax to the correctly reconstructed global scores and multiplying by the local value shard yields an attention output identical to the corresponding slice of the unsharded result.
These three properties, all relying on the basic algebraic fact that summation over a partitioned index is order-independent, collectively guarantee the full mathematical equivalence of the hidden-dimension tensor-parallel strategy to single-GPU computation.
The ATTENTION_DECODER signal is broadly defined to detect the presence of any attention mechanism in the decoder. Chunked attention and the attention-specific portion of the hidden-dimension strategy are instantiated here for the local (sliding-window) attention pattern found in HiGLUE’s Hi-C decoder. For models employing full self-attention, the same signal would activate alternative strategies (such as Flash Attention or sequence-parallel attention). The remaining strategies (data parallelism, feature-parallel reconstruction, and the restructured decode path) and the non-attention components of the hidden-dimension strategy (LayerNorm, Linear, and Decode sharding) are independent of the attention variant. They apply directly to any multi-modal 3D biological model with a large 3D spatial modality. New strategies for different attention patterns can be registered through the framework’s extensibility interface (Design Principle 5, Section 3.1) without modifying the profiling or diagnosis logic.
3.4.6. Strategy Composition
The five strategies interact through well-defined functional dependencies that determine how they are composed in the unified per-stratum computation (Algorithm 1). These dependencies reflect which strategy’s output serves as another’s input, not a sequential execution pipeline. Feature-parallel reconstruction determines the local feature partition, and the restructured decode path applies key-first decoding using this partition; hence, the former must be evaluated before the latter. Chunked attention is a reusable operator that does not depend on the outputs of either; it is invoked within the attention block regardless of which features are being processed. Hidden-dimension tensor parallelism internally invokes both chunked attention and key-first decoding in its sharded attention block: the chunked attention mechanism is called with cross-rank score reduction enabled, and the key-first decoding approach is adopted in the sharded decode step. Data parallelism operates as the underlying multi-process infrastructure throughout.
3.5. Auto-Execute: Strategy Composition and Execution
The third phase executes the prescribed strategy configuration. Data parallelism serves as the underlying multi-process infrastructure and is always enabled when multiple GPUs are available.
When no bottleneck signals are detected, only data parallelism is activated, and multi-process data-parallel training is launched. The trainer inherits the original single-GPU training procedure and augments the per-iteration step with manual gradient all-reduce synchronization.
When one or more bottleneck signals are detected, the framework activates the corresponding subset of strategies as prescribed by the diagnosis phase (Section 3.3). All activated strategies are integrated into the training loop and composed according to their functional dependencies (Section 3.4.6): feature-parallel reconstruction determines the local feature partition and the restructured decode path applies key-first decoding using this partition, both inside the feature-parallel loss function; chunked attention and hidden-dimension sharding operate through the tensor-parallel attention primitives; and data parallelism provides the multi-process infrastructure throughout.
4. Results
This section empirically validates the framework through four experiments: lightweight configurations verify that no unnecessary parallelism is imposed (Section 4.1), a leave-one-out ablation quantifies each strategy’s marginal memory contribution (Section 4.2), per-stratum memory profiling confirms the EPM budget model (Section 4.3), and cross-configuration evaluation tests generalization and GPU scaling (Section 4.4). HiRES [8] is among the few publicly available datasets providing paired single-cell Hi-C and RNA-seq at genome-wide scale with published preprocessing pipelines, which is exactly the setting that AMP targets: a multi-modal 3D spatial modality paired with an orthogonal modality.
We evaluated AMP on the HiRES mouse brain dataset [8] under four configurations combining two bin sizes (100 kb, ∼27 K bins; 500 kb, ∼5.4 K bins) and two stratum counts (10 and 32). All experiments were conducted on servers equipped with NVIDIA RTX 3090 GPUs (24 GB each) using NCCL-based distributed communication. Hyperparameters (learning rate, latent dimension 64, hidden dimension 128) were held constant across all runs. A batch size of 128 was used for the 100 kb configurations and 256 for the 500 kb configurations, with chunk size 32 for local attention on 100 kb and 128 on 500 kb. Each configuration ran one pretrain stage and one fine-tune stage, both consisting of five iterations in benchmark mode. Lightweight configurations (10 strata) serve as the baseline where single-GPU training is feasible; heavyweight configurations (32 strata) exhaust single-GPU memory and are addressed by the prescribed five-strategy configuration, of which chunked attention and hidden-dimension tensor parallelism are the decisive enablers (Section 4.2).
All timing values represent per-epoch averages obtained over 10 training iterations (five pretrain followed by five fine-tune) per configuration. Peak GPU memory values are reported as mean ± standard deviation over the same 10 iterations, measured via torch.cuda.max_memory_allocated() at the end of each iteration. Memory standard deviations were below 0.7 GB in all tested configurations.
4.1. Multi-GPU Speedup on Lightweight Configurations
We first evaluated whether AMP provides training acceleration on configurations that already fit on a single GPU. Both lightweight configurations (10 strata) were trained with data parallelism only, using four GPUs.
Table 6 shows that four-GPU data parallelism yields approximately a 1.5× speedup on the 500 kb configuration (1.59 to 1.09 s/epoch). The 100 kb configuration showed a larger per-epoch time under DDP (4.67 s vs. 2.34 s on a single GPU), reflecting load imbalance caused by the highly variable number of genomic contacts per cell at finer bin resolution.
At finer bin resolution, the per-cell graph complexity varies substantially, so work is distributed unevenly across GPUs and the slowest rank sets the pace of each step. This imbalance motivates the strategy-based parallelism of AMP, since pure data parallelism alone is insufficient when the data distribution is inherently skewed. Nevertheless, these results confirm two important baselines: first, AMP’s data-parallel backbone operates correctly; second, the auto-diagnosis layer correctly limits parallelism to data-parallel mode for configurations that fit comfortably on a single GPU, imposing no unnecessary parallelism.
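The speedup figures quoted for Table 6 follow from the per-epoch times by simple division. The short sketch below reproduces that arithmetic; the helper names are illustrative, and parallel efficiency (speedup divided by GPU count) is an additional derived quantity that the text does not report.

```python
def speedup(t_single, t_multi):
    """Ratio of single-GPU to multi-GPU time per epoch (>1 means faster)."""
    return t_single / t_multi


def parallel_efficiency(t_single, t_multi, n_gpus):
    """Speedup normalised by the number of GPUs (1.0 would be ideal scaling)."""
    return speedup(t_single, t_multi) / n_gpus


# Per-epoch times in seconds, as quoted in the text for four GPUs.
s_500kb = speedup(1.59, 1.09)  # about 1.46, reported as roughly 1.5x
s_100kb = speedup(2.34, 4.67)  # about 0.50, i.e. a slowdown
```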
4.2. Leave-One-Out Ablation: Quantifying Per-Strategy Memory Impact
To quantify the marginal contribution of each parallel strategy to memory reduction, we conducted a leave-one-out ablation on the 500 kb–32 strata configuration using eight GPUs with a batch size of one. Starting from all five strategies, we disabled one strategy at a time and recorded the training outcome and peak per-GPU memory.
Table 7 reports these results.
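The leave-one-out procedure can be sketched as a small harness. The strategy labels below are descriptive placeholders rather than the paper's symbols, and `train_fn` is a hypothetical callable that trains with the given strategies enabled and returns whatever outcome is being recorded (for example, a status and a peak memory value).

```python
# Descriptive placeholder labels; the paper's own strategy symbols are not used here.
STRATEGIES = (
    "data_parallel",
    "feature_parallel_encoder",
    "chunked_local_attention",
    "hidden_dim_tensor_parallel",
    "local_key_optimization",
)


def leave_one_out(strategies):
    """Yield (disabled, enabled) pairs, disabling exactly one strategy each time."""
    for disabled in strategies:
        enabled = tuple(s for s in strategies if s != disabled)
        yield disabled, enabled


def run_ablation(strategies, train_fn):
    """Run the full configuration plus one run per disabled strategy."""
    results = {"none": train_fn(tuple(strategies))}  # reference run, nothing disabled
    for disabled, enabled in leave_one_out(strategies):
        results[disabled] = train_fn(enabled)
    return results
```

Keying the results by the disabled strategy makes each entry directly comparable with the reference run, which is how the marginal contribution of a strategy is read off.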
The results revealed a clear hierarchy of memory impact among the strategies.
Chunked local attention is the first line of defense. Disabling it while keeping the remaining strategies active results in an immediate out-of-memory failure at stratum 1. The unchunked local attention attempts to allocate a full attention matrix for the guidance graph nodes, requesting approximately 48 GiB in a single allocation, about twice the capacity of the GPU. Without chunking, the decoder cannot complete even the first attention-enabled stratum, regardless of the other optimizations.
Hidden-dimension tensor parallelism is the decisive enabler. Disabling it causes per-stratum memory growth to accelerate from approximately 0.56 GB to 1.38 GB per stratum, more than doubling the decoder’s activation footprint. Training fails at stratum 16 after memory exceeds 22.33 GB, well before reaching the final stratum. With it active, the full five-strategy configuration completes all 32 strata within 17.87 GB, over 25% below the 24 GB limit. This confirms hidden-dimension tensor parallelism as the strategy that converts an untrainable configuration into a trainable one.
Local key optimization provides significant but non-critical savings. Without it, the decoder reverts to gathering full key and value tensors across all hidden-dimension ranks before each attention operation. Training still completes, but the peak memory increases by 3.24 GB (from 17.87 to 21.11 GB), confirming that local key optimization eliminates a substantial intermediate memory cost while not being individually indispensable for this configuration.
The feature-parallel encoder primarily optimizes computation rather than memory. Disabling it leaves the peak memory unchanged at 17.87 GB. The output of the encoder projection layer is small, and its weight tensor (∼77 MB) is negligible compared to the decoder’s multi-GB activations. Instead, this strategy reduces the per-GPU FLOPs by distributing the large feature-dimension matrix multiplication across ranks.
Taken together, this leave-one-out analysis establishes a clear ranking of memory impact. Chunked local attention and hidden-dimension tensor parallelism are both individually necessary for training to succeed in this configuration; local key optimization provides measurable relief; and the feature-parallel encoder optimizes throughput rather than memory footprint.
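To illustrate why chunking can remove the large attention allocation without changing the result, the sketch below compares plain scaled dot-product attention with a query-chunked variant. It is not the paper's decoder: it assumes that chunking is applied over query rows, and it uses pure Python lists, so the memory saving is conceptual. In a tensor library the unchunked version would materialise a score matrix with one row per query, whereas the chunked version materialises only `chunk` rows at a time.

```python
import math


def softmax(row):
    m = max(row)  # subtract the maximum for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attend(queries, keys, values):
    """Scaled dot-product attention over all queries at once."""
    scale = 1.0 / math.sqrt(len(keys[0]))
    out = []
    for q in queries:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
        weights = softmax(scores)
        out.append([sum(w * v[j] for w, v in zip(weights, values))
                    for j in range(len(values[0]))])
    return out


def chunked_attend(queries, keys, values, chunk):
    """Same computation, but only `chunk` query rows are processed per call."""
    out = []
    for start in range(0, len(queries), chunk):
        out.extend(attend(queries[start:start + chunk], keys, values))
    return out
```

Because each output row depends only on its own query, splitting the queries into chunks leaves every row unchanged, which is the sense in which such a strategy is mathematically equivalent to the unchunked computation.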
4.3. Per-Stratum Memory Profile
To understand why the 32-strata configuration exhausts GPU memory and how hidden-dimension tensor parallelism resolves this, we profiled per-GPU memory consumption after each decoder stratum with all five strategies on the 100 kb–32 strata configuration (four GPUs). Memory snapshots were collected at key points across the training pipeline using built-in profiling hooks.
Figure 4 shows the near-linear growth of per-GPU memory with stratum index, approximately 0.45 GB per stratum once the attention-enabled decoder layers begin. With all five strategies activated, hidden-dimension tensor parallelism partitions the hidden dimension D across GPUs, keeping the per-stratum memory increment low enough to remain below 24 GB across all 32 strata. The observed linear accumulation is consistent with the analytical EPM budget model formulated in Section 3.2: each additional stratum adds a predictable increment to the per-GPU memory footprint, and hidden-dimension sharding compresses this per-stratum cost by approximately a factor of N (the number of GPUs). Extrapolating the same per-stratum increment without hidden-dimension sharding indicates that the 24 GB limit would be exceeded at approximately stratum 28, which aligns with single-GPU profiling, where memory exceeded 22 GB at strata 18–19 under the original unoptimized decoder.
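The linear accumulation lends itself to a quick back-of-the-envelope check. The sketch below is not the EPM budget model of Section 3.2, which is not reproduced here; it is a plain linear extrapolation with a hypothetical fixed cost `base_gb` (set to zero by default, an assumption), evaluated with the per-stratum slopes quoted in Section 4.2.

```python
import math


def peak_memory_gb(n_strata, per_stratum_gb, base_gb=0.0):
    """Projected per-GPU memory after n_strata decoder strata."""
    return base_gb + n_strata * per_stratum_gb


def first_oom_stratum(per_stratum_gb, capacity_gb, base_gb=0.0):
    """Smallest stratum count whose projected memory exceeds capacity_gb."""
    headroom = capacity_gb - base_gb
    if headroom < 0:
        return 0  # the fixed cost alone already exceeds the capacity
    return math.floor(headroom / per_stratum_gb) + 1


with_sharding = peak_memory_gb(32, 0.56)        # about 17.9 GB for 32 strata
without_sharding = first_oom_stratum(1.38, 24)  # 18 under the zero-base assumption
```

With a zero base, the projection for 32 strata at 0.56 GB per stratum lands close to the reported 17.87 GB. The projected failure point at 1.38 GB per stratum is somewhat later than the observed failure at stratum 16, which is plausible for a sketch that ignores fixed costs and transient allocations.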
4.4. Generalization and Scalability Across Configurations
To assess whether the five-strategy configuration generalizes beyond a single configuration and GPU count, we evaluated AMP across both bin sizes on four and eight GPUs.
Table 8 confirms that AMP generalizes across bin sizes. Both 32-strata configurations exhaust single-GPU memory but complete training under all five strategies on four GPUs. The 500 kb configuration peaks at 17.89 ± 0.15 GB, and the 100 kb configuration peaks at 19.51 ± 0.67 GB, both safely under the 24 GB limit. The absolute memory values differ by only about 1.6 GB even though the 100 kb configuration has five times as many input bins (∼27 K vs. ∼5.4 K), demonstrating that hidden-dimension tensor parallelism effectively decouples per-device memory from the input dimensionality, a direct consequence of sharding along the hidden dimension rather than the feature dimension.
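The effect of sharding along the hidden dimension can be illustrated with a two-layer projection. The sketch below is a generic column/row split in the style of common tensor-parallel implementations, not the paper's decoder: the first weight matrix is split by columns and the second by rows, each simulated rank holds a hidden activation of width D/N, and summing the partial outputs plays the role of the all-reduce. All names are hypothetical, and the ranks are simulated sequentially in pure Python.

```python
def matmul(a, b):
    """Dense matrix product for lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def relu(m):
    return [[max(0.0, v) for v in row] for row in m]


def mlp(x, w1, w2):
    """Unsharded reference: hidden activation has the full width D."""
    return matmul(relu(matmul(x, w1)), w2)


def sharded_mlp(x, w1, w2, n_ranks):
    """Split the hidden dimension across n_ranks and sum the partial outputs."""
    hidden = len(w2)
    if hidden % n_ranks:
        raise ValueError("hidden dimension must be divisible by n_ranks")
    shard = hidden // n_ranks
    total = None
    for r in range(n_ranks):
        lo, hi = r * shard, (r + 1) * shard
        w1_r = [row[lo:hi] for row in w1]  # column slice of the first weight
        w2_r = w2[lo:hi]                   # matching row slice of the second weight
        partial = matmul(relu(matmul(x, w1_r)), w2_r)  # hidden width is D / n_ranks
        if total is None:
            total = partial
        else:  # stands in for the all-reduce (sum) across ranks
            total = [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(total, partial)]
    return total
```

The width of the per-rank hidden activation depends on D and the number of ranks but not on the number of input features, which is consistent with the small memory difference observed between the two bin sizes.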
Table 9 reports the throughput scaling from four to eight GPUs. The 100 kb configuration achieves a 3.8× speedup when the number of GPUs is doubled, while the 500 kb configuration gains a 2.0× improvement. The larger relative gain on 100 kb is expected because its substantially larger bin count (∼27 K vs. ∼5.4 K bins) makes the attention and decode stages the dominant runtime components, where the strategies that shard the hidden dimension parallelize most efficiently as the shards become finer. Peak per-GPU memory does not decrease proportionally with the GPU count beyond four GPUs: at eight GPUs, the residual memory is dominated by the non-sharded components (RNA encoder, guidance graph embedding, and loss computation) rather than the decoder’s hidden-dimension activations. This memory plateau confirms that hidden-dimension tensor parallelism has effectively saturated the hidden-dimension memory bottleneck on the decoder side; further memory reduction would require sharding the encoder and graph embedding stages, which are not currently targeted by the strategy library.
The speedup gap between the two configurations is explained by the communication overhead of the distributed deployment. Each attention stratum incurs the all-reduce operations listed in Table 5 (up to 10 per stratum across 31 attention-enabled strata). The 500 kb configuration (∼5.4 K bins) has a lighter compute load, resulting in a higher communication-to-computation ratio and a modest 2.0× speedup. In contrast, the 100 kb configuration (∼27 K bins) carries a substantially heavier attention and decode workload that amortizes the same communication cost, achieving a 3.8× speedup.
Finally, to verify that standard data parallelism alone is insufficient for heavyweight configurations, we evaluated both 32-strata configurations under pure data parallelism on four GPUs. In both cases, each GPU holds a full model replica, and the structural memory bottleneck, identified as predictable from the data configuration in Section 3.2, is duplicated rather than distributed across devices. Training fails with OOM before the first epoch is completed. This negative result confirms the analysis in Section 2.2: generic parallelism strategies that ignore modality-specific structural asymmetry cannot resolve the predictable memory bottleneck of multi-modal 3D biological models. Only the diagnosis-driven five-strategy configuration, with hidden-dimension tensor parallelism as its decisive component, converts out-of-memory failures into successful training runs.
4.5. Summary of Experimental Validation
Taken together, the experimental results validate the three core claims of this work. First, the memory bottleneck of multi-modal 3D biological models is structurally predictable from data configuration alone: the EPM budget model accurately captures the linear accumulation of per-stratum activations, and the OOM threshold can be identified before training begins. Second, bottleneck-specific diagnosis can replace cost-model search: three orthogonal signals mapped to a prescribed set of five strategies suffice to resolve the bottleneck in all tested configurations without any trial-and-error strategy exploration. Third, hidden-dimension tensor parallelism is the decisive enabler: the leave-one-out ablation on the 500 kb–32 strata configuration shows that disabling it more than doubles the per-stratum memory growth and causes OOM at stratum 16, whereas the full five-strategy configuration completes all 32 strata within 17.87 GB. The generalization of these results across two bin sizes and two GPU counts further demonstrates that the AMP framework is not configuration-specific but rather addresses a structural property of this model class.
5. Conclusions
In this work, we addressed the structural memory bottleneck that prevents training heavyweight multi-modal 3D biological models on single GPUs. We showed that in models of this class, the location and severity of the bottleneck are deterministically predictable from the data configuration, and that this predictability enables a diagnostic approach: three bottleneck signals (HIGH_INPUT_DIMENSIONALITY, ATTENTION_DECODER, and HEAVY_DECOMPOSITION) are extracted via zero-cost profiling and mapped to a prescribed set of parallel strategies. The resulting AMP framework deploys five mathematically equivalent strategies, with hidden-dimension tensor parallelism serving as the decisive enabler.
Experimental validation on the HiRES mouse brain dataset demonstrated that the five strategies together enable the training of previously untrainable 32-strata configurations at both 100 kb and 500 kb bin sizes. A leave-one-out ablation on the 500 kb configuration confirmed that chunked attention and hidden-dimension tensor parallelism are each individually necessary for training to succeed, while local key optimization provides significant supplementary memory savings and feature-parallel encoding optimizes computational throughput. Per-stratum memory profiling confirmed the linear accumulation predicted by the EPM budget model and showed that hidden-dimension tensor parallelism compressed the per-stratum activation footprint sufficiently to keep memory within hardware limits. Pure data parallelism on both 32-strata configurations failed with OOM on four GPUs, confirming that generic parallelism without modality-aware diagnosis cannot resolve structural memory bottlenecks.
Several limitations define the current scope of AMP. First, bottleneck signal definitions and strategy implementations were instantiated for Hi-C-based 3D genomic models with local attention decoders. Although the ATTENTION_DECODER signal is broadly defined to cover any attention mechanism, concrete strategies for full self-attention decoders (e.g., Flash Attention substitution) remain to be integrated and tested. Second, the auto-diagnosis module currently operates on a fixed three-signal vocabulary; extending the framework to additional modality types (e.g., spatial transcriptomics and imaging-based 3D data) will require registering new bottleneck signals and corresponding strategies through the extensibility interface.
Beyond the immediate generalization to other multi-modal 3D biological models, the core insight that domain-specific architectural asymmetries can replace generic cost-model searching has implications for the broader parallel computing community. Multi-modal 3D deep learning architectures in which one spatial modality dominates computation face the same class of predictable memory bottlenecks. Applying the AMP diagnostic framework to 3D computer vision models, such as multi-modal LiDAR-camera detectors in autonomous driving, represents a natural next step. Future work should also explore the integration of activation checkpointing as a complementary memory-saving mechanism, incorporate a communication-aware scheduling layer that selects among parallelism strategies based on cluster topology, and provide an automated GPU-count recommendation that balances memory savings against communication cost. Extending the strategy library to support heterogeneous GPU clusters represents another practical direction.