arccos函数在定义域边界上梯度为无穷,因此会迅速导致梯度爆炸。
在一些ML任务中,我们可能会在计算loss时试图使用反三角函数。例如在位姿估计时对旋转矩阵设置loss时可能会用到轴角式(i.e.Axis-angle)的转角,用旋转矩阵计算此转角时会用到arccos。
arccos定义域限制在[-1,1]中。为了排除异常值以及防止舍入误差等导致正确结果越出定义域,则通常会把输入clip到[-1,1]。然而arccos在两个端点上的梯度为无穷值,这将导致梯度爆炸(权值和输出很快变为nan)。
因此**在需要反向传播梯度时,千万不要直接把arccos的输入直接clip到[-1,1]上。**解决方法主要有三种(参考https://github.com/pytorch/pytorch/issues/8069):
- 改为clip到[-1+eps,1-eps]上,eps是一较小正数。注意到arccos在此处的梯度很大,截断带来的误差也会变大,因此eps要在防止梯度过大和防止误差过大之间进行取舍。
- 在[-1,-1+eps]和[1-eps,1]中用线性插值近似arccos
- 除非万不得已,否则尽量使用数学性质更好的损失函数。在位姿估计中使用lp范数的效果通常比用转角差的效果好得多。
(BTW,在梯度爆炸时可以借助torch自有的探查功能进行debug:https://pytorch.org/docs/stable/autograd.html#torch.autograd.detect_anomaly)