355 lines
13 KiB
Python
355 lines
13 KiB
Python
# Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
|
# provided that the following conditions are met:
|
|
# * Redistributions of source code must retain the above copyright notice, this list of
|
|
# conditions and the following disclaimer.
|
|
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
|
# conditions and the following disclaimer in the documentation and/or other materials
|
|
# provided with the distribution.
|
|
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
|
# to endorse or promote products derived from this software without specific prior written
|
|
# permission.
|
|
#
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
|
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
|
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
|
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
|
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
|
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
# @file bench_tensorflow.py
|
|
# @author Thomas Müller, NVIDIA
|
|
# @brief Generates performance data for comparison with our fully fused network.
|
|
|
|
import argparse
|
|
import commentjson as json
|
|
import glob
|
|
import math
|
|
import numpy as np
|
|
import os
|
|
import sys
|
|
import tensorflow.compat.v1 as tf
|
|
import tensorflow_probability as tfp
|
|
import time
|
|
|
|
# Make the shared `scripts` directory (three levels above this file) importable.
# The location is derived from the *absolute* path of this file: with a relative
# `__file__` (which some interpreters/launch modes may provide) the chained
# `dirname` calls would collapse to an empty string and point at the wrong place.
_BENCHMARK_DIR = os.path.dirname(os.path.abspath(__file__))
SCRIPTS_DIR = os.path.join(os.path.dirname(os.path.dirname(_BENCHMARK_DIR)), "scripts")
sys.path.insert(0, SCRIPTS_DIR)

# Must come after the sys.path tweak above; `common` is resolved through it.
from common import read_image, write_image, ROOT_DIR  # noqa: E402

# Directory holding the JSON configs read by the main script, and the
# sub-directory that `Image` searches for image files.
DATA_DIR = os.path.join(ROOT_DIR, "data")
IMAGES_DIR = os.path.join(DATA_DIR, "images")
|
|
|
|
class Function:
    """Base class for a target function together with its descriptive metadata.

    The constructor only records its arguments on the instance; evaluating the
    function is left to subclasses, which must override ``__call__``.
    """

    def __init__(self, domain, n_channels, n_dims, wraparound_dims, n_conditionals, n_raw_conditionals):
        """Store every argument verbatim as an attribute of the same name."""
        # Description of the input side of the function.
        self.domain = domain
        self.n_dims = n_dims
        self.wraparound_dims = wraparound_dims

        # Description of the output side of the function.
        self.n_channels = n_channels

        # Conditional inputs.
        self.n_conditionals = n_conditionals
        self.n_raw_conditionals = n_raw_conditionals

    def __call__(self, xs):
        """Evaluate the function at ``xs``.

        Raises:
            NotImplementedError: always; subclasses provide the implementation.
        """
        raise NotImplementedError
|
|
|
|
class Image(Function):
    """2D target function that looks up pixel values of an image file.

    Inputs are coordinates in the unit square: the first coordinate is scaled
    by the image width and selects the column, the second is scaled by the
    image height and selects the row. Lookups are nearest-pixel (truncation)
    and out-of-range coordinates are clamped to the image border.
    """

    def __init__(self, filename):
        """Load ``<filename>.*`` and keep it as a NumPy array and a TF constant.

        Args:
            filename: Image name without extension. It is joined onto
                ``IMAGES_DIR`` (an absolute path is therefore used as-is).

        Raises:
            ValueError: if no file matching ``<filename>.*`` exists.
        """
        self.filename = filename

        # glob returns matches in arbitrary order; sort so that the file that
        # gets picked is reproducible when several extensions exist.
        paths = sorted(glob.glob(os.path.join(IMAGES_DIR, self.filename + ".*")))
        if not paths:
            raise ValueError(f"Invalid image name '{filename}'")

        self.data = read_image(paths[0])

        # Keep at most the first three channels (drops e.g. an alpha channel).
        if self.data.shape[-1] > 3:
            self.data = self.data[:, :, 0:3]
        self.data_tf = tf.constant(self.data, dtype=tf.float32)

        super().__init__('unit', self.data.shape[-1], 2, {}, 0, 0)

    def __call__(self, xs):
        """Look up pixels with NumPy.

        Args:
            xs: Array indexed as ``xs[:, 0]`` (horizontal) and ``xs[:, 1]``
                (vertical), with coordinates nominally in [0, 1].

        Returns:
            The pixel values ``self.data[row, column]`` for every input row.
        """
        height, width = self.data.shape[:2]

        # Use a *signed* integer type: converting a negative coordinate to an
        # unsigned type wraps it to a huge value, which the clip below would
        # then clamp to the far border instead of to 0. This also matches
        # `eval_tf`, which casts to a signed type.
        indices = (xs * np.array([width, height])).astype(np.int64)
        columns = np.clip(indices[:, 0], 0, width - 1)
        rows = np.clip(indices[:, 1], 0, height - 1)
        return self.data[rows, columns]

    def eval_tf(self, xs):
        """TensorFlow counterpart of ``__call__`` operating on a tensor ``xs``."""
        shape = self.data_tf.shape
        indices = tf.cast(xs * tf.constant([shape[1], shape[0]], dtype=tf.float32), tf.int32)
        # gather_nd expects (row, column) order, hence the swap.
        indices_clipped = tf.stack([
            tf.clip_by_value(indices[:, 1], 0, shape[0]-1),
            tf.clip_by_value(indices[:, 0], 0, shape[1]-1),
        ], axis=-1)
        return tf.gather_nd(self.data_tf, indices_clipped)
|
|
|
|
class OneBlob:
    """One-blob style encoding of scalar inputs into ``n_bins`` soft bins.

    Every input value is turned into the per-bin mass of a blob of standard
    deviation ``radius = 0.5 / n_bins`` centred on the value, computed as
    differences of an approximate Gaussian CDF at the bin boundaries. With
    ``n_levels > 1`` the encoding is repeated on ``n_bins**level``-times finer,
    periodic copies of the input, each attenuated by ``1 / n_bins**level``.
    """

    def __init__(self, n_bins, n_levels):
        self.n_bins, self.n_levels = n_bins, n_levels
        self.radius = 0.5 / n_bins

    @staticmethod
    def _gaussian_cdf_approx(x, radius):
        """tanh-based approximation of the Gaussian CDF (the one actually used)."""
        return 0.5 * (1 + tf.tanh(1.12 * x / (math.sqrt(2.) * radius)))

    @staticmethod
    def _gaussian_cdf(x, radius):
        """erf-based Gaussian CDF; kept for reference, not used by the encoding."""
        return 0.5 * (1 + tf.erf(x / (math.sqrt(2.) * radius)))

    def __call__(self, inputs, wraparound, name, dtype=None):
        """Encode ``inputs`` and flatten the result to two dimensions.

        Args:
            inputs: Tensor whose last axis holds the values to encode.
            wraparound: If truthy, level 0 also receives the mass that spills
                over the [0, 1] borders from the other side (levels > 0 always do).
            name: Name scope under which the ops are created.
            dtype: Accepted but currently unused.

        Returns:
            ``inputs`` unchanged when its last axis has size 0, otherwise a
            tensor reshaped to ``[-1, n_bins * n_levels * dims]``.
        """
        n_input_dims = inputs.shape[-1]

        with tf.name_scope(name):
            # When there are no input dims, there is nothing to encode.
            # This special case is needed because tf.reshape does strange
            # things when 0-dims are involved.
            if n_input_dims == 0:
                return inputs

            # Bin boundaries, shaped to broadcast against `inputs[..., newaxis]`.
            bin_edges = tf.linspace(0., 1., self.n_bins + 1)
            bin_edges = tf.reshape(bin_edges, [1] * len(inputs.shape) + [-1])

            per_level = []
            for level in range(self.n_levels):
                with tf.name_scope(f"level{level}"):
                    scale = self.n_bins**level

                    # We use the absolute value here just in case the inputs are erroneously negative.
                    # Even a negative epsilon would totally wreck the following code.
                    if level > 0:
                        positions = tf.abs(inputs * scale) % 1
                    else:
                        positions = tf.abs(inputs)

                    offsets = bin_edges - positions[..., tf.newaxis]
                    cdf_center = self._gaussian_cdf_approx(offsets, self.radius)
                    encoded = cdf_center[..., 1:] - cdf_center[..., :-1]

                    # In the outermost level we don't want to carry over...
                    # otherwise we introduce ambiguities.
                    if level > 0 or wraparound:
                        cdf_right = self._gaussian_cdf_approx(offsets + 1., self.radius)
                        cdf_left = self._gaussian_cdf_approx(offsets - 1., self.radius)
                        encoded = (
                            encoded
                            + cdf_right[..., 1:] - cdf_right[..., :-1]
                            + cdf_left[..., 1:] - cdf_left[..., :-1]
                        )

                    per_level.append(encoded / scale)

            stacked = tf.concat(per_level, axis=-1)
            return tf.reshape(stacked, [-1, self.n_bins * self.n_levels * n_input_dims])
|
|
|
|
def get_args():
    """Parse the command line and return the resulting ``argparse`` namespace.

    Recognised options:
        -c / --config  JSON config filename (default ``config_oneblob.json``).
        -i / --image   Image to match (default ``albert``).
    """
    option_table = (
        (("-c", "--config"), dict(default="config_oneblob.json", type=str, help="JSON config filename")),
        (("-i", "--image"), dict(default="albert", type=str, help="Image to match")),
    )

    parser = argparse.ArgumentParser(description="Image benchmark using TensorFlow.")
    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)

    return parser.parse_args()
|
|
|
|
def linear_layer(inputs, units, dtype, name, use_biases=True):
    """Apply a dense layer ``inputs @ weights (+ biases)``.

    Variables live in the variable scope ``name`` and are shared between calls
    that use the same name (``reuse=tf.AUTO_REUSE``).

    Args:
        inputs: 2d Tensor, shape=(batch, in_units).
        units: Integer, dimensionality of the output space.
        dtype: Type in which the matrix product and bias addition are carried
            out; inputs, weights and biases are cast to it.
        name: Variable scope of the layer.
        use_biases: Whether a zero-initialised bias vector is added.

    Returns:
        The layer output cast to ``tf.float32``.
    """
    assert len(inputs.shape) == 2

    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        weights = tf.get_variable(
            "weights",
            (inputs.shape[1], units),
            initializer=tf.glorot_uniform_initializer(),
        )
        # Created before the matmul so that variable creation order is unchanged.
        biases = (
            tf.get_variable("biases", (units,), initializer=tf.constant_initializer())
            if use_biases
            else None
        )

        product = tf.matmul(tf.cast(inputs, dtype), tf.cast(weights, dtype))
        if biases is not None:
            product = product + tf.cast(biases, dtype)

    return tf.cast(product, tf.float32)
|
|
|
|
def activation(tensor, kind):
    """Apply the activation function called ``kind`` to ``tensor``.

    Args:
        tensor: Value to transform.
        kind: Case-insensitive name: ``"relu"``, ``"relu6"``, ``"elu"``,
            ``"selu"``, ``"leaky_relu"`` (the ``tf.nn`` function of the same
            name is used) or ``"none"`` (returns ``tensor`` unchanged).

    Returns:
        The activated tensor.

    Raises:
        ValueError: if ``kind`` is not one of the supported names. (An
            ``assert`` would be stripped under ``python -O`` and make the
            function silently return ``None``.)
    """
    tf_nn_activations = ("relu", "relu6", "elu", "selu", "leaky_relu")

    kind = kind.lower()
    if kind == "none":
        return tensor
    if kind in tf_nn_activations:
        # Every supported name is also the name of the tf.nn function.
        return getattr(tf.nn, kind)(tensor)

    supported = ", ".join(tf_nn_activations + ("none",))
    raise ValueError(f"Unknown activation '{kind}'; expected one of: {supported}")
|
|
|
|
def compute_gradients(loss, variables, loss_scale):
    """Compute loss-scaled gradients and sanitise non-finite ones.

    The loss is multiplied by ``loss_scale`` before differentiation and each
    gradient is divided by it afterwards. A gradient tensor containing any
    non-finite element is replaced *as a whole* by zeros.

    Args:
        loss: Scalar loss tensor.
        variables: Variables to differentiate with respect to.
        loss_scale: Factor applied to the loss during differentiation.

    Returns:
        A pair ``(gradients, all_finite)``. ``gradients`` has one entry per
        variable, ``None`` where ``tf.gradients`` found no dependency.
        ``all_finite`` is a scalar boolean tensor that is true iff every
        existing gradient was finite (true when there are no gradients at all).
    """
    with tf.name_scope("gradient_computation"):
        scaled_gradients = tf.gradients(loss * loss_scale, variables)

        gradients = []
        finites = []
        for grad, var in zip(scaled_gradients, variables):
            if grad is None:
                # No dependency on this variable: nothing to unscale or check.
                gradients.append(None)
                continue

            grad = grad / loss_scale
            finite = tf.reduce_all(tf.is_finite(grad))
            gradients.append(tf.where(finite, grad, tf.zeros_like(var)))
            finites.append(finite)

        # tf.reduce_all needs boolean input; an empty Python list would not be
        # converted to a boolean tensor, so handle "no gradients" explicitly.
        all_finite = tf.reduce_all(finites) if finites else tf.constant(True)

    return gradients, all_finite
|
|
|
|
def get_train_op(config, variables, gradients, optimizer, clip_norm=0):
    """Build the op that applies ``gradients`` to ``variables``.

    Args:
        config: Accepted but not used by this function.
        variables: Variables to update, aligned with ``gradients``.
        gradients: Gradients (entries may be ``None``).
        optimizer: Object providing ``apply_gradients``.
        clip_norm: If positive, gradients are clipped to this global norm first.

    Returns:
        A pair ``(train_op, gradients_norm)``. ``train_op`` is a no-op when
        there is no gradient to apply; ``gradients_norm`` is the global norm
        of the gradients (as reported by the clipping op when clipping is on).
    """
    if clip_norm > 0:
        gradients, gradients_norm = tf.clip_by_global_norm(gradients, clip_norm=clip_norm)
    else:
        gradients_norm = tf.global_norm(gradients)

    # Nothing to apply when the list is empty or consists solely of None.
    if not gradients or all(grad is None for grad in gradients):
        train_op = tf.no_op(name="apply_gradients")
    else:
        train_op = optimizer.apply_gradients(zip(gradients, variables), name="apply_gradients")

    return train_op, gradients_norm
|
|
|
|
def make_graph():
    """Build one training step of the benchmark network.

    Relies on the module-level objects created by the main script:
    ``batch_size_tensor``, ``target_fun``, ``encoding``, ``config`` and
    ``optimizer``.

    The step draws a batch of uniformly random inputs, evaluates the target
    function on them, runs the encoded inputs through a bias-free MLP whose
    matrix products are done in ``tf.float16``, and minimises a relative L2
    loss (squared error divided by the squared, gradient-stopped prediction
    plus 0.01).

    Returns:
        ``(train_op, loss, input_tensor, output_tensor)``.
    """
    network_config = config["network"]

    # Random training inputs and the matching reference values.
    input_tensor = tfp.distributions.Uniform().sample((batch_size_tensor, target_fun.n_dims))
    target_tensor = target_fun.eval_tf(input_tensor)

    # Encoding followed by the hidden layers.
    hidden = encoding(input_tensor, False, "encoding")
    for layer_index in range(network_config["n_hidden_layers"]):
        hidden = linear_layer(hidden, network_config["n_neurons"], tf.float16, f"fc{layer_index}", False)
        hidden = activation(hidden, network_config["activation"])

    # Output layer.
    output_tensor = linear_layer(hidden, target_fun.n_channels, tf.float16, "fc_out", False)
    output_tensor = activation(output_tensor, network_config["output_activation"])

    # Relative L2 loss.
    squared_error = (target_tensor - output_tensor)**2
    relative_l2_error = squared_error / (tf.stop_gradient(output_tensor)**2 + 0.01)
    loss = tf.math.reduce_mean(relative_l2_error)

    # Gradients (computed with loss scaling) and the optimizer update.
    loss_scale = 128
    trainable = tf.trainable_variables()
    gradients, _ = compute_gradients(loss, trainable, loss_scale)
    train_op, _ = get_train_op(config, trainable, gradients, optimizer)

    return train_op, loss, input_tensor, output_tensor
|
|
|
|
# Benchmark driver. For a range of batch sizes it measures training and
# inference throughput of the network built by `make_graph`, writes a rendered
# image per batch size and finally dumps all results to a JSON file.
#
# NOTE: `config`, `target_fun`, `encoding`, `batch_size_tensor` and `optimizer`
# are intentionally module-level names: `make_graph` reads them as globals.
if __name__ == "__main__":
    tf.disable_eager_execution()
    args = get_args()

    # Initialize non-TF stuff
    with open(os.path.join(DATA_DIR, args.config)) as config_file:
        config = json.load(config_file)

    target_fun = Image(os.path.join(IMAGES_DIR, args.image))
    encoding = OneBlob(config["encoding"]["n_bins"], 1)

    # Initialize TF graph
    batch_size_tensor = tf.placeholder(tf.int32, shape=[])
    optimizer_config = config["optimizer"]
    optimizer = tf.train.AdamOptimizer(
        optimizer_config["learning_rate"],
        optimizer_config["beta1"],
        optimizer_config["beta2"],
        optimizer_config["epsilon"],
    )
    train_op, loss, input_tensor, output_tensor = make_graph()

    # Variables for saving/displaying image results
    resolution = 1024
    img_shape = (resolution, resolution, target_fun.n_channels)

    # Pixel-centre coordinates of a resolution x resolution grid.
    half_dx = 0.5 / resolution
    xs = np.linspace(half_dx, 1-half_dx, resolution)
    xv, yv = np.meshgrid(xs, xs)

    xy = np.stack((xv.flatten(), yv.flatten())).transpose()
    gt = np.reshape(target_fun(xy), img_shape)
    write_image("reference.jpg", gt)

    # Enable XLA compiler (important for good TensorFlow performance)
    session_config = tf.ConfigProto()
    session_config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    # Timestamp of the previous measurement; shared by the training and
    # inference loops, which is why the first sample of each is discarded.
    timer = time.perf_counter()

    # Run the network
    with tf.Session(config=session_config) as sess:
        bench_result = { "tensorflow": [] }

        for batch_size in [2**14, 2**15, 2**16, 2**17, 2**18, 2**19, 2**20, 2**21]:
            # Reset for every batch size: both are doubled for the inference phase below.
            N_ITERS = 1000
            PRINT_INTERVAL = 100

            output_dummy_variable = tf.Variable(tf.zeros(shape=[batch_size, target_fun.n_channels], dtype=tf.float32), trainable=False)
            # Also re-initialises the network weights and optimizer state, so
            # every batch size starts from scratch.
            sess.run(tf.global_variables_initializer())

            # Training: a single `sess.run` executes PRINT_INTERVAL training
            # steps inside a tf.while_loop. The `sequencer` loop variable only
            # exists to chain consecutive steps via control dependencies.
            def cond(it, _sequencer, _loss):
                return tf.less(it, PRINT_INTERVAL)

            def body(it, sequencer, _):
                with tf.control_dependencies([sequencer]):
                    local_train_op, local_loss, _, _ = make_graph()
                with tf.control_dependencies([local_train_op]):
                    next_sequencer = tf.ones([])
                return it+1, next_sequencer, local_loss

            train_op, _, loss = tf.while_loop(cond, body, [0, 1., 0.], parallel_iterations=1)

            throughputs = []
            # `i` advances in steps of PRINT_INTERVAL, so every pass is a measured one.
            for i in range(0, N_ITERS, PRINT_INTERVAL):
                _, loss_val = sess.run([train_op, loss], feed_dict={ batch_size_tensor: batch_size })
                old_time = timer
                timer = time.perf_counter()
                elapsed_time = timer - old_time
                throughput = PRINT_INTERVAL * batch_size / elapsed_time
                throughputs.append(throughput)
                print(f"Iteration#{i}: loss={loss_val} time={int(elapsed_time*1000000)}[µs] thp={throughput}/s")

            # Render the network output on the pixel grid and save it.
            img = np.reshape(sess.run(output_tensor, feed_dict={ input_tensor: xy, batch_size_tensor: xy.shape[0] }), img_shape)
            filename = f"{batch_size}-after-{N_ITERS}-iters-tensorflow.jpg"
            print(f"Saving {filename}")
            write_image(filename, img)

            # Mean throughput (discounting the first one due to XLA compilation)
            mean_training_throughput = np.mean(throughputs[1:])

            print(f"Finished training benchmark. Mean throughput is {mean_training_throughput}/s. Waiting 10s for GPU to cool down.")
            time.sleep(10)

            # Inference: assigning the output to a variable forces the forward
            # pass to run without transferring the result back to Python.
            _, _, _, tmp_out = make_graph()
            inference_op = output_dummy_variable.assign(tmp_out)

            N_ITERS *= 2
            PRINT_INTERVAL *= 2

            throughputs = []
            for i in range(N_ITERS):
                sess.run(inference_op, feed_dict={ batch_size_tensor: batch_size })
                if i % PRINT_INTERVAL == 0:
                    old_time = timer
                    timer = time.perf_counter()
                    elapsed_time = timer - old_time
                    throughput = PRINT_INTERVAL * batch_size / elapsed_time
                    throughputs.append(throughput)
                    print(f"Iteration#{i}: time={int(elapsed_time*1000000)}[µs] thp={throughput}/s")

            # The first sample spans the cool-down sleep and a single
            # iteration only, so it is excluded from the mean.
            mean_inference_throughput = np.mean(throughputs[1:])

            print(f"Finished inference benchmark. Mean throughput is {mean_inference_throughput}/s. Waiting 10s for GPU to cool down.")
            time.sleep(10)

            bench_result["tensorflow"].append(
                {
                    "batch_size" : batch_size,
                    "training_throughput" : mean_training_throughput,
                    "inference_throughput" : mean_inference_throughput,
                }
            )

        with open("bench_result_tensorflow.json", "w") as f:
            json.dump(bench_result, f)
|