Fast Differentiable Clipping-Aware Normalization and Rescaling
Rescaling a vector δ⃗ ∈ ℝ^n to a desired length is a common operation in many areas such as data science and machine learning. When the rescaled perturbation ηδ⃗ (with scalar rescaling factor η) is added to a starting point x⃗ ∈ D (where D is the data domain, e.g. D = [0, 1]^n), the resulting vector v⃗ = x⃗ + ηδ⃗ will in general not lie in D. To enforce that the perturbed vector v⃗ is in D, the values of v⃗ can be clipped to D. However, this subsequent element-wise clipping to the data domain reduces the effective perturbation size and thus interferes with the rescaling of δ⃗. The optimal rescaling η that yields a perturbation with the desired norm after clipping can be iteratively approximated using a binary search, but such an iterative approach is slow and non-differentiable. Here we show that the optimal rescaling can be found analytically using a fast and differentiable algorithm. Our algorithm works for any p-norm and can be used to train neural networks on inputs with normalized perturbations. We provide native implementations for PyTorch, TensorFlow, JAX, and NumPy based on EagerPy.
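As a minimal illustration of the iterative baseline the abstract refers to, the NumPy sketch below binary-searches the rescaling factor η so that the clipped perturbation reaches a target p-norm. The function names (effective_norm, rescale_binary_search), the bracket-growing strategy, and the iteration count are our own assumptions for illustration; this is the slow, non-differentiable baseline, not the paper's analytical algorithm, whose reference implementations are the EagerPy-based ones mentioned above.

```python
import numpy as np

def effective_norm(x, delta, eta, p=2):
    """p-norm of the *effective* perturbation after clipping x + eta*delta
    to the data domain D = [0, 1]^n (hypothetical helper, not from the paper)."""
    v = np.clip(x + eta * delta, 0.0, 1.0)
    return np.linalg.norm((v - x).ravel(), ord=p)

def rescale_binary_search(x, delta, target_norm, p=2, steps=50):
    """Iterative baseline: bisect eta so the clipped perturbation has the
    desired p-norm. Valid because the effective norm is non-decreasing in
    eta (each coordinate saturates monotonically at the domain bounds)."""
    lo, hi = 0.0, 1.0
    # Grow the upper bracket until it overshoots the target norm. If every
    # coordinate saturates first, the target is unreachable and we stop.
    while effective_norm(x, delta, hi, p) < target_norm:
        hi *= 2.0
        if hi > 1e12:
            break
    for _ in range(steps):
        mid = 0.5 * (lo + hi)
        if effective_norm(x, delta, mid, p) < target_norm:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)

# Example: rescale a random perturbation of an image-like input to an
# effective 2-norm of 1.0 after clipping.
rng = np.random.default_rng(0)
x = rng.uniform(0.0, 1.0, size=(3, 32, 32))
delta = rng.normal(size=x.shape)
eta = rescale_binary_search(x, delta, target_norm=1.0, p=2)
print(eta, effective_norm(x, delta, eta, p=2))
```

Each bisection step requires a full clip-and-norm pass over the input, which is what makes this baseline slow compared to the analytical solution, and the discrete search breaks the gradient flow from η back to δ⃗.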