diff --git a/tensorflow_probability/python/bijectors/sigmoid.py b/tensorflow_probability/python/bijectors/sigmoid.py
index 4036e9f2e9..55731260de 100644
--- a/tensorflow_probability/python/bijectors/sigmoid.py
+++ b/tensorflow_probability/python/bijectors/sigmoid.py
@@ -28,14 +28,18 @@
 
 JAX_MODE = False  # Overwritten by rewrite script.
 
-# TODO(b/155501444): Remove when tf.math.sigmoid and tf.nn.softplus are fixed.
+# tf.nn.softplus is now numerically stable and no longer needs a custom
+# gradient; the previous custom_gradient wrapper leaked memory (b/155501444).
+_stable_grad_softplus = tf.nn.softplus
+
+# tf.math.sigmoid can still underflow to 0 for large negative inputs, so the
+# _stable_sigmoid wrapper below is kept for the non-JAX backend.
 if JAX_MODE:
   _stable_sigmoid = tf.math.sigmoid
-  _stable_grad_softplus = tf.nn.softplus
 else:
 
   def _stable_sigmoid(x):
-    """A (more) numerically stable sigmoid than `tf.math.sigmoid`."""
+    """A numerically stable sigmoid; avoids underflow for very negative x."""
     x = tf.convert_to_tensor(x)
     if x.dtype == tf.float64:
       cutoff = -20
@@ -43,22 +47,6 @@
       cutoff = -9
     return tf.where(x < cutoff, tf.exp(x), tf.math.sigmoid(x))
 
-  @tf.custom_gradient
-  def _stable_grad_softplus(x):
-    """A (more) numerically stable softplus than `tf.nn.softplus`."""
-    x = tf.convert_to_tensor(x)
-    if x.dtype == tf.float64:
-      cutoff = -20
-    else:
-      cutoff = -9
-
-    y = tf.where(x < cutoff, tf.math.log1p(tf.exp(x)), tf.nn.softplus(x))
-
-    def grad_fn(dy):
-      return dy * tf.where(x < cutoff, tf.exp(x), tf.nn.sigmoid(x))
-
-    return y, grad_fn
-
 
 class Sigmoid(
     bijector.CoordinatewiseBijectorMixin,
diff --git a/tensorflow_probability/python/bijectors/softplus.py b/tensorflow_probability/python/bijectors/softplus.py
index d5f218d05d..2b092bfa1a 100644
--- a/tensorflow_probability/python/bijectors/softplus.py
+++ b/tensorflow_probability/python/bijectors/softplus.py
@@ -32,26 +32,10 @@
 
 JAX_MODE = False  # Overwritten by rewrite script.
 
-# TODO(b/155501444): Remove this when tf.nn.softplus is fixed.
-if JAX_MODE:
-  _stable_grad_softplus = tf.nn.softplus
-else:
-
-  @tf.custom_gradient
-  def _stable_grad_softplus(x):
-    """A (more) numerically stable softplus than `tf.nn.softplus`."""
-    x = tf.convert_to_tensor(x)
-    if x.dtype == tf.float64:
-      cutoff = -20
-    else:
-      cutoff = -9
-
-    y = tf.where(x < cutoff, tf.math.log1p(tf.exp(x)), tf.nn.softplus(x))
-
-    def grad_fn(dy):
-      return dy * tf.where(x < cutoff, tf.exp(x), tf.nn.sigmoid(x))
-
-    return y, grad_fn
+# tf.nn.softplus has been numerically stable (implemented via log1p and
+# sigmoid since roughly 2019-2020) and no longer needs a custom gradient.
+# The previous custom_gradient wrapper leaked memory. See b/155501444.
+_stable_grad_softplus = tf.nn.softplus
 
 
 class Softplus(
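The cutoff trick that `_stable_sigmoid` keeps using is worth spelling out: for x below the cutoff, sigmoid(x) agrees with exp(x) to within float rounding, and exp(x) stays representable (down into the subnormal range) long after the naive 1/(1 + exp(-x)) has rounded to 0. A minimal standalone sketch, not part of the patch; the `demo` values are illustrative:

import tensorflow as tf

def stable_sigmoid(x):
  # Same cutoffs as sigmoid.py: below them, sigmoid(x) ~= exp(x).
  x = tf.convert_to_tensor(x)
  cutoff = -20 if x.dtype == tf.float64 else -9
  return tf.where(x < cutoff, tf.exp(x), tf.math.sigmoid(x))

demo = tf.constant([-100.0, -9.5, 0.0], dtype=tf.float32)
print(tf.math.sigmoid(demo).numpy())  # the first entry can round to 0.0
print(stable_sigmoid(demo).numpy())   # first entry ~= exp(-100) ~ 3.7e-44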
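And why the softplus wrapper is removable at all: d/dx softplus(x) = sigmoid(x), so once the gradient registered for tf.nn.softplus is itself stable, a custom_gradient adds nothing. A quick eager-mode check (an illustrative snippet, also not part of the patch):

import tensorflow as tf

xs = tf.constant([-30.0, -9.5, 0.0, 5.0])
with tf.GradientTape() as tape:
  tape.watch(xs)  # xs is a constant, so it must be watched explicitly
  y = tf.nn.softplus(xs)
# The built-in gradient should match sigmoid(xs) elementwise.
print(tape.gradient(y, xs).numpy())
print(tf.math.sigmoid(xs).numpy())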