Jigsaw: Training Multi-Billion-Parameter AI Weather Models With Optimized Model Parallelism

Update: 2025-10-24

Description

Authors: Deifilia Kieckhefen, Markus Götz, Lars H. Heyen, Achim Streit, and Charlotte Debus (Karlsruhe Institute of Technology, Helmholtz AI)

• The paper introduces WeatherMixer (WM), a multi-layer perceptron (MLP)-based architecture for atmospheric forecasting that serves as a competitive alternative to Transformer-based models. WM's workload scales linearly with input size. This avoids the quadratic computational complexity of the self-attention mechanism in Transformers, which becomes a scaling bottleneck for gigabyte-sized atmospheric data.

• The paper proposes a novel parallelization scheme, Jigsaw parallelism, which combines domain parallelism and tensor parallelism to train multi-billion-parameter models efficiently. Jigsaw is optimized for large input data: it fully shards the data, model parameters, and optimizer states across devices, eliminating memory redundancy.

• Jigsaw mitigates hardware bottlenecks, particularly the I/O-bandwidth limitations frequently encountered when training large scientific AI models. Because each device loads only its own partition of the input (domain parallelism), the scheme achieves superscalar weak scaling on I/O-bandwidth-limited systems.

• The method scales well on high-performance computing systems and exceeds state-of-the-art strong-scaling performance on computation–communication-limited systems. Training was scaled to 256 GPUs, reaching peak performances of 9 and 11 PFLOPs.

• Beyond hardware efficiency, Jigsaw improves predictive performance. By partitioning the model across more GPUs (model parallelism) instead of relying solely on data parallelism, it enforces smaller global batch sizes. Empirically, this mitigates the large-batch effects observed in AI weather models and leads to lower loss values.
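
The linear-versus-quadratic scaling argument above can be illustrated with a rough cost model. The sketch below is a generic illustration and not WeatherMixer's actual layer design, which may mix information differently. It counts multiply-accumulates for the two token-interaction products of one self-attention layer (about 2·n²·d) and for a two-layer MLP applied independently to every token (about 2·n·d·h). The function names and the sizes `d_model` and `d_hidden` are hypothetical.

```python
def attention_score_flops(n_tokens: int, d_model: int) -> int:
    """Rough multiply-accumulate count for one attention layer's Q @ K^T and
    scores @ V products (projections and softmax ignored): 2 * n^2 * d."""
    return 2 * n_tokens * n_tokens * d_model


def token_mlp_flops(n_tokens: int, d_model: int, d_hidden: int) -> int:
    """Rough multiply-accumulate count for a two-layer MLP (d -> h -> d)
    applied independently to every token: 2 * n * d * h."""
    return 2 * n_tokens * d_model * d_hidden


# Hypothetical sizes, chosen only to show the trend.
d_model, d_hidden = 1024, 4096
for n_tokens in (10_000, 20_000, 40_000):
    ratio = attention_score_flops(n_tokens, d_model) / token_mlp_flops(n_tokens, d_model, d_hidden)
    print(f"{n_tokens:>6} tokens: attention / MLP cost = {ratio:.2f}")
```

Doubling the number of tokens doubles the per-token MLP cost but quadruples the attention cost, so the ratio grows as n / h as the input grid becomes finer.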
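
The partitioned data loading behind the weak-scaling result can be sketched as follows. Each rank in a 2D process grid reads only its own latitude-longitude tile of a sample, so the bytes read per rank shrink as the grid grows. This is an assumption-laden illustration, not the paper's loader. It assumes samples are stored as `(channels, lat, lon)` NumPy `.npy` files and uses `np.load(..., mmap_mode="r")` so that only the sliced tile is copied into memory. The helpers `tile_bounds` and `load_tile` are hypothetical.

```python
import os
import tempfile

import numpy as np


def tile_bounds(length: int, parts: int, index: int) -> tuple[int, int]:
    """Start/stop of block `index` when `length` is split into `parts` near-equal blocks."""
    base, rem = divmod(length, parts)
    start = index * base + min(index, rem)
    stop = start + base + (1 if index < rem else 0)
    return start, stop


def load_tile(path: str, rank: int, grid: tuple[int, int]) -> np.ndarray:
    """Read only the (lat, lon) tile owned by `rank` from a (channels, lat, lon) .npy file."""
    py, px = grid                      # process grid: py rows (lat) x px columns (lon)
    ry, rx = divmod(rank, px)          # this rank's row and column in the grid
    field = np.load(path, mmap_mode="r")  # memory-mapped; data is read lazily on access
    _, n_lat, n_lon = field.shape
    y0, y1 = tile_bounds(n_lat, py, ry)
    x0, x1 = tile_bounds(n_lon, px, rx)
    tile = np.array(field[:, y0:y1, x0:x1])  # copies just this rank's tile
    del field                                 # release the memory map
    return tile


# Usage: a small synthetic sample split across a hypothetical 2 x 2 process grid.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "sample.npy")
    full = np.arange(3 * 8 * 12, dtype=np.float32).reshape(3, 8, 12)
    np.save(path, full)
    tiles = [load_tile(path, rank, (2, 2)) for rank in range(4)]

# Reassemble the tiles to check that together they cover the whole field exactly once.
top = np.concatenate(tiles[0:2], axis=2)
bottom = np.concatenate(tiles[2:4], axis=2)
reassembled = np.concatenate([top, bottom], axis=1)
```

In a real distributed job, `rank` would come from the communication backend, and each rank would call `load_tile` for its own index only.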
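
The link between model parallelism, memory, and global batch size described in the last points reduces to simple arithmetic. With a fixed GPU count, each model replica spanning more GPUs means fewer replicas, and therefore a smaller global batch for a fixed per-replica batch. Each GPU also holds a smaller share of the model state. The sketch below makes several simplifying assumptions. Parameters, gradients, and Adam states are split evenly within one model-parallel group at 16 bytes per parameter (fp32 weights and gradients plus two fp32 Adam moments), and activation memory is ignored. The per-replica batch of 1 and the 4-billion-parameter size are hypothetical. Only the 256-GPU count comes from the summary above.

```python
def data_parallel_degree(n_gpus: int, model_parallel_degree: int) -> int:
    """Number of model replicas when each replica spans `model_parallel_degree` GPUs."""
    if n_gpus % model_parallel_degree != 0:
        raise ValueError("n_gpus must be divisible by model_parallel_degree")
    return n_gpus // model_parallel_degree


def global_batch_size(n_gpus: int, model_parallel_degree: int, per_replica_batch: int) -> int:
    """Samples per optimizer step: per-replica batch times the number of replicas."""
    return per_replica_batch * data_parallel_degree(n_gpus, model_parallel_degree)


def state_bytes_per_gpu(n_params: int, model_parallel_degree: int, bytes_per_param: int = 16) -> int:
    """Weights + gradients + optimizer state per GPU, assuming an even split across one
    model-parallel group. 16 B/param assumes fp32 weights (4), fp32 gradients (4),
    and Adam's two fp32 moment estimates (8)."""
    return n_params * bytes_per_param // model_parallel_degree


n_gpus = 256
per_replica_batch, n_params = 1, 4_000_000_000  # hypothetical values
for mp in (1, 4, 16):
    gbs = global_batch_size(n_gpus, mp, per_replica_batch)
    mem_gb = state_bytes_per_gpu(n_params, mp) / 1e9
    print(f"model-parallel degree {mp:>2}: global batch {gbs:>3}, ~{mem_gb:.0f} GB of state per GPU")
```

Under these assumptions, moving from pure data parallelism to a model-parallel degree of 16 cuts the global batch from 256 to 16 samples. Per-GPU model state drops by the same factor.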

Amirpasha