I am experimenting with TF2.x to get the best possible TPU speed. To that end, I have modified a few examples and come up with my own implementation. However, as the code stands, I cannot work out how to implement gradient descent in it.

This example looks like it should help. However, after quite some trial and error, I could not get it to work because of the decorator @tf.function(jit_compile=True).
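
A minimal sketch of what a gradient descent step under this decorator could look like, assuming a toy linear model. The names w, b, and train_step are hypothetical and do not come from the original implementation. The sketch assumes the variables are created outside the compiled function, the gradients are taken with tf.GradientTape, and the update is applied with assign_sub. It is only an illustration of the pattern and has not been tuned for TPU.

```python
import tensorflow as tf

# Hypothetical toy model: y ≈ x @ w + b.
# Variables are created once, outside the compiled function.
w = tf.Variable([0.0, 0.0])
b = tf.Variable(0.0)


@tf.function(jit_compile=True)
def train_step(x, y, lr):
    """One plain gradient descent step on mean squared error."""
    with tf.GradientTape() as tape:
        pred = tf.linalg.matvec(x, w) + b
        loss = tf.reduce_mean(tf.square(pred - y))
    grad_w, grad_b = tape.gradient(loss, [w, b])
    # In-place updates on the variables.
    w.assign_sub(lr * grad_w)
    b.assign_sub(lr * grad_b)
    return loss


if __name__ == "__main__":
    x = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    y = tf.constant([1.0, 2.0, 3.0])
    lr = tf.constant(0.01)  # passed as a tensor to avoid retracing per value
    for step in range(5):
        print(step, float(train_step(x, y, lr)))
```

The learning rate is passed as a tensor rather than a Python float so that changing its value does not trigger a new trace and compilation.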

Can someone provide some insight? It would be greatly appreciated.

RanWang
