[mlgo] Fix flaky test

The source of the flakyness is internal uninitialized buffers due to
a dangling variable in the model.
This commit is contained in:
Mircea Trofin 2022-08-26 21:29:25 -07:00
parent 53f1cc85e3
commit 3546b5c520
1 changed files with 1 additions and 7 deletions

View File

@ -102,14 +102,8 @@ def get_output_spec_path(path):
def build_mock_model(path, signature):
"""Build and save the mock model with the given signature"""
module = tf.Module()
# We have to set this useless variable in order for the TF C API to correctly
# intake it
module.var = tf.Variable(0.)
def action(*inputs):
s = tf.reduce_sum([tf.cast(x, tf.float32) for x in tf.nest.flatten(inputs)])
return {signature['output']: tf.cast(tf.divide((s + module.var), tf.abs(s + module.var)), tf.int64)}
return {signature['output']: tf.constant(value=1, dtype=tf.int64)}
module.action = tf.function()(action)
action = {'action': module.action.get_concrete_function(signature['inputs'])}