GPU-accelerated training using PyTorch Metal backend
Replaces DGX Spark: Fine-tune with PyTorch
pytorch · fine-tuning
Basic idea
PyTorch's MPS (Metal Performance Shaders) backend lets PyTorch dispatch tensor operations to Apple Silicon's GPU instead of the CPU. When you call `.to(torch.device("mps"))` on a tensor, PyTorch routes all computations for that tensor through Metal compute shaders, the same GPU API that powers macOS graphics. This is the standard path for HuggingFace Trainer-based fine-tuning workflows that don't have direct MLX equivalents, such as PEFT/LoRA training with custom datasets and callbacks.
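A minimal sketch of what that device routing looks like in practice. It selects MPS when available and falls back to CPU elsewhere, so the same code runs on any machine:

```python
import torch

# Select the MPS device when available, falling back to CPU elsewhere
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

x = torch.randn(4, 4, device=device)  # tensor allocated directly on the chosen device
y = (x @ x).relu()                    # matmul + relu dispatched via Metal shaders on MPS
print(y.device)                       # "mps:0" on Apple Silicon, "cpu" elsewhere
```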
MPS is functional but slower than MLX for most workflows. Use it when you need PyTorch-specific libraries (PEFT, DeepSpeed, custom training loops) that haven't been ported to MLX.
What you'll accomplish
A working PyTorch MPS fine-tuning environment on your Mac's GPU. You will LoRA fine-tune a 7B parameter causal language model using HuggingFace PEFT and Trainer, with checkpoints saved to disk and a verified training run showing decreasing loss.
What to know before starting
Device placement: PyTorch tensors must be on the same device to operate on each other. If the model is on MPS and the labels are on CPU, you get a runtime error; every tensor in a batch must be explicitly moved.
Gradient accumulation: Simulates a larger batch size by running `N` forward passes and summing gradients before each optimizer step. `per_device_train_batch_size=1` with `gradient_accumulation_steps=8` gives an effective batch size of 8 without needing 8x the memory.
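HuggingFace Trainer handles this internally when `gradient_accumulation_steps` is set, but the mechanism reduces to a loop like the following hand-rolled sketch (the tiny model and random data are illustrative):

```python
import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
accum_steps = 8
num_updates = 0

data = [(torch.randn(1, 4), torch.randn(1, 1)) for _ in range(16)]

for step, (x, y) in enumerate(data):
    loss = torch.nn.functional.mse_loss(model(x), y)
    # Scale each micro-batch loss so the accumulated gradient is the
    # average over the window, matching a true batch of size 8.
    (loss / accum_steps).backward()
    if (step + 1) % accum_steps == 0:
        opt.step()       # one optimizer step per 8 forward/backward passes
        opt.zero_grad()
        num_updates += 1
```

Memory peaks at the size of a single micro-batch, which is the whole point on a unified-memory Mac.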
PEFT/LoRA: Instead of updating all model weights, LoRA injects small trainable rank-decomposition matrices into specific layers. Only ~0.1% of parameters are trained, keeping memory and compute manageable.
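The rank-decomposition idea can be made concrete with a hand-rolled sketch (PEFT's actual implementation differs in detail; this just shows the math). The frozen base weight `W` is augmented with a trainable low-rank product `B·A`:

```python
import torch

class LoRALinear(torch.nn.Module):
    """Frozen linear layer plus a trainable low-rank update: W x + scale * (B A) x."""
    def __init__(self, base: torch.nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # base weights stay frozen
        # A is small random, B is zero, so the layer initially behaves like the base
        self.A = torch.nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.B = torch.nn.Parameter(torch.zeros(base.out_features, r))
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + (x @ self.A.T @ self.B.T) * self.scale

layer = LoRALinear(torch.nn.Linear(512, 512), r=8)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
total = sum(p.numel() for p in layer.parameters())
print(f"{trainable}/{total} parameters trainable")
```

For a full 7B model the trainable fraction is far smaller than in this single-layer toy, since LoRA matrices are injected into only a handful of projection layers.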
MPS operation gaps: Metal lacks some CUDA primitives. When PyTorch encounters an unimplemented op on MPS, it either raises `NotImplementedError` or silently falls back to CPU, depending on whether `PYTORCH_ENABLE_MPS_FALLBACK=1` is set in your environment.
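The fallback is controlled by an environment variable that must be set before `torch` is imported:

```python
import os

# Must be set before importing torch; with it, ops Metal lacks run on the
# CPU (with a UserWarning) instead of raising NotImplementedError.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # noqa: E402 - deliberately imported after setting the variable
```

The CPU fallback keeps training alive at the cost of a device round-trip per unsupported op, so a run that hits the fallback in a hot loop can be dramatically slower.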
float32 requirement: MPS cannot reliably execute the fp16 backward pass; gradients overflow or go NaN. You must use `torch_dtype=torch.float32` and `fp16=False` in TrainingArguments. This roughly doubles memory usage compared to CUDA fp16 training.
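An illustrative TrainingArguments configuration combining the points above (the output directory, learning rate, and epoch count are placeholder values, not recommendations):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",      # checkpoints land here
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,   # effective batch size of 8
    fp16=False,                      # required on MPS: fp16 backward is unreliable
    bf16=False,
    learning_rate=2e-4,              # placeholder; tune for your model/dataset
    num_train_epochs=1,
    logging_steps=10,
)
```

The matching model load uses `torch_dtype=torch.float32` in `from_pretrained(...)` so weights, activations, and gradients all stay in fp32.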
Prerequisites
• macOS 12.3 or later (MPS was introduced in 12.3)
• Apple Silicon Mac (M1, M2, or M3 family)
• Python 3.9 or later
• Xcode Command Line Tools installed (`xcode-select --install`)
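A quick sanity check that the prerequisites line up before starting a run. `is_built()` confirms your PyTorch wheel was compiled with MPS support; `is_available()` confirms this machine can actually use it:

```python
import platform
import torch

print("macOS version:", platform.mac_ver()[0] or "(not macOS)")
print("MPS built into this PyTorch:", torch.backends.mps.is_built())
print("MPS available on this machine:", torch.backends.mps.is_available())
```

If `is_built()` is True but `is_available()` is False, you are usually on an Intel Mac or a macOS version older than 12.3.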