# Select a tf.distribute strategy for the available accelerator and bind it to
# `strategy`: TPU first, then GPU, otherwise TensorFlow's default strategy.
#
# Two TPU signals are checked because neither one alone is reliable:
#   * COLAB_TPU_ADDR is set by Colab's remote-TPU runtime; there the TPU is
#     reached over gRPC and is typically not listed as a physical device
#     until the cluster connection has been made.
#   * A non-empty physical TPU list means the TPU is attached to this host,
#     in which case COLAB_TPU_ADDR is usually absent and must not be read
#     unconditionally (doing so raises KeyError).
colab_tpu_addr = os.environ.get('COLAB_TPU_ADDR')
if colab_tpu_addr or tf.config.list_physical_devices('TPU'):
    # Remote Colab TPU -> explicit gRPC address; otherwise 'local' tells the
    # resolver the TPU is directly attached to this machine.
    tpu_target = ('grpc://' + colab_tpu_addr) if colab_tpu_addr else 'local'
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_target)
    # The TPU system must be connected and initialized before a TPUStrategy
    # can be created on it.
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    print("All devices: ", tf.config.list_logical_devices('TPU'))
    strategy = tf.distribute.TPUStrategy(resolver)
elif tf.config.list_physical_devices('GPU'):
    # Synchronous data parallelism across the GPUs visible on this host.
    strategy = tf.distribute.MirroredStrategy()
else:  # Use the Default Strategy
    strategy = tf.distribute.get_strategy()