TensorFlow: How to replace or modify a gradient?

Asked on December 19, 2018 in Tensorflow.


  • 1 Answer(s)

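    A simple way to do this is with tf.RegisterGradient: register a custom gradient function under a unique name, then use gradient_override_map to substitute it for an op's default gradient. Note that tf.py_func, tf.get_default_graph and tf.Session belong to the TensorFlow 1.x graph API.

    The code below uses this approach to clip the backpropagated gradient, and the example further down applies it around matmul: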

    import tensorflow as tf
    import numpy as np
     
    # from https://gist.github.com/harpone/3453185b41d8d985356cbe5e57d67342
    def py_func(func, inp, Tout, stateful=True, name=None, grad=None):
     
        # Need to generate a unique name to avoid duplicates:
        rnd_name = 'PyFuncGrad' + str(np.random.randint(0, int(1e8)))
     
        tf.RegisterGradient(rnd_name)(grad)
        g = tf.get_default_graph()
        with g.gradient_override_map({"PyFunc": rnd_name}):
            return tf.py_func(func, inp, Tout, stateful=stateful, name=name)
     
    def clip_grad(x, clip_value, name=None):
        """
        Identity in the forward pass; scales the backpropagated gradient
        so that its L2 norm is no more than `clip_value`.
        """
        with tf.name_scope(name, "ClipGrad", [x]) as name:
            return py_func(lambda x : x,
                            [x],
                            [tf.float32],
                            name=name,
                            grad=lambda op, g : tf.clip_by_norm(g, clip_value))[0]
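
    To make the clipping rule concrete, here is a minimal NumPy sketch of the rescaling that tf.clip_by_norm applies to the incoming gradient g. The helper name clip_by_norm_np is hypothetical, and the sketch assumes a single L2 norm taken over the whole tensor and a positive clip_value.

```python
import numpy as np

def clip_by_norm_np(g, clip_value):
    """Rescale g so its overall L2 norm is at most clip_value (hypothetical helper)."""
    g = np.asarray(g, dtype=np.float64)
    norm = np.sqrt(np.sum(g * g))
    # The scale is exactly 1.0 when the norm is already within the limit,
    # so small gradients pass through unchanged.
    scale = clip_value / max(norm, clip_value)
    return g * scale

g = np.array([[3., 7.], [3., 7.]])
clipped = clip_by_norm_np(g, 1.0)
print(np.linalg.norm(clipped))
print(clipped)
```

    In clip_grad above, the forward pass is the identity, so only the backward pass is affected by this rescaling.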
    

    Example usage:

    with tf.Session() as sess:
        x = tf.constant([[1., 2.], [3., 4.]])
        y = tf.constant([[1., 2.], [3., 4.]])
     
        print('without clipping')
        z = tf.matmul(x, y)
        print(tf.gradients(tf.reduce_sum(z), x)[0].eval())
     
        print('with clipping')
        z = tf.matmul(clip_grad(x, 1.0), clip_grad(y, 0.5))
        print(tf.gradients(tf.reduce_sum(z), x)[0].eval())
     
        print('with clipping between matmuls')
        z = tf.matmul(clip_grad(tf.matmul(x, y), 1.0), y)
        print(tf.gradients(tf.reduce_sum(z), x)[0].eval())
    

    Result:

    without clipping
    [[ 3.  7.]
     [ 3.  7.]]
    with clipping
    [[ 0.278543   0.6499337]
     [ 0.278543   0.6499337]]
    with clipping between matmuls
    [[ 1.57841039  3.43536377]
     [ 1.57841039  3.43536377]]
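
    These numbers can be checked by hand. For loss = reduce_sum(matmul(a, b)), the gradient with respect to a is ones @ b.T, and clip_grad rescales whatever gradient flows back through it. The NumPy sketch below reproduces the three cases under that assumption; clip is a hypothetical stand-in for tf.clip_by_norm.

```python
import numpy as np

def clip(g, clip_value):
    """Hypothetical stand-in for tf.clip_by_norm over the whole tensor."""
    return g * (clip_value / max(np.linalg.norm(g), clip_value))

x = np.array([[1., 2.], [3., 4.]])
y = np.array([[1., 2.], [3., 4.]])
ones = np.ones((2, 2))  # gradient of reduce_sum(z) with respect to z

# Case 1: z = x @ y, so dL/dx = ones @ y.T
no_clip = ones @ y.T

# Case 2: the gradient reaching x passes through clip_grad(x, 1.0).
# clip_grad(y, 0.5) only affects the gradient with respect to y.
clipped_input = clip(ones @ y.T, 1.0)

# Case 3: z = clip_grad(x @ y, 1.0) @ y. The gradient is clipped between
# the two matmuls, then propagated back through the inner matmul.
between = clip(ones @ y.T, 1.0) @ y.T

print(no_clip)
print(clipped_input)
print(between)
```

    In the third case the final gradient has a norm larger than 1.0, because the clipping happens before the last multiplication by y.T rather than at x itself.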
    
    Answered on December 19, 2018.