9. Assigning operations to devices
def parallel(devices, fn, *args, **kwargs):
  ...
  for i, device in enumerate(devices):
    with tf.device(device):
      with tf.variable_scope("parallel_%d" % i):
        # Each positional/keyword argument is a list with one entry per
        # device; pick out the i-th slice and call fn on this device.
        my_args = [x[i] for x in args]
        my_kwargs = {k: v[i] for k, v in six.iteritems(kwargs)}
        ret.append(fn(*my_args, **my_kwargs))
Internally, `with tf.device(device)` pins each iteration's operations to the corresponding device, and `tf.variable_scope("parallel_%d" % i)` gives each replica its own variable scope.
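To make the dispatch logic concrete, here is a minimal TensorFlow-free sketch of the same pattern: every argument is a per-device list, and `fn` is applied to the i-th slice of each argument. The device placement itself is omitted (it is what `tf.device` would add); the function name and signature mirror the excerpt above but this is an illustrative reimplementation, not the library's actual code.

```python
def parallel(devices, fn, *args, **kwargs):
    # Sketch: iterate over devices, slicing out the i-th entry of every
    # positional and keyword argument, and collect one result per device.
    # With real TensorFlow, the fn call would sit under tf.device(device).
    ret = []
    for i, device in enumerate(devices):
        my_args = [x[i] for x in args]
        my_kwargs = {k: v[i] for k, v in kwargs.items()}
        ret.append(fn(*my_args, **my_kwargs))
    return ret

# Two "devices", each adding its own pair of inputs.
print(parallel(["/gpu:0", "/gpu:1"], lambda a, b: a + b, [1, 2], [10, 20]))
# → [11, 22]
```

The key design point is that callers pass *lists* of per-device tensors rather than single tensors, so the same `fn` can be replicated across devices without any change to its body.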