I am implementing a fairly complex network structure where I need to loop over a (possibly variable) time dimension, compute values at each step, collect all of those values, and do this for multiple layers. The values at each layer can have different shapes. I want to compile this into a tf.function to speed things up. Below is a minimal example of the problematic code (TF 2.15, Python 3.11).
This code works:
@tf.function
def test(inputs):
    one_array = tf.TensorArray(tf.float32, size=tf.shape(inputs)[0])
    for step in tf.range(tf.shape(inputs)[0]):
        one_array = one_array.write(step, inputs[step])
    return one_array.stack()

dummy_input = tf.random.normal([5])
test(dummy_input)
However, this does not:
@tf.function
def test2(inputs):
    n_layers = 2  # the exact number doesn't matter
    arrays = [tf.TensorArray(tf.float32, size=tf.shape(inputs)[0]) for _ in range(n_layers)]
    for step in tf.range(tf.shape(inputs)[0]):
        for ind in range(n_layers):
            arrays[ind] = arrays[ind].write(step, inputs[step])
    for ind in range(n_layers):
        arrays[ind] = arrays[ind].stack()
    return arrays  # returning arrays directly, without the stack() above, also crashes

test2(dummy_input)
raises:
---------------------------------------------------------------------------
InaccessibleTensorError Traceback (most recent call last)
Cell In[99], line 12
9 arrays[ind] = arrays[ind].stack()
10 return arrays
---> 12 test2(dummy_input)
File /project/jens/tf311/lib/python3.11/site-packages/tensorflow/python/util/traceback_utils.py:153, in filter_traceback.<locals>.error_handler(*args, **kwargs)
151 except Exception as e:
152 filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153 raise e.with_traceback(filtered_tb) from None
154 finally:
155 del filtered_tb
File /project/jens/tf311/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py:52, in py_func_from_autograph.<locals>.autograph_handler(*args, **kwargs)
50 except Exception as e: # pylint:disable=broad-except
51 if hasattr(e, "ag_error_metadata"):
---> 52 raise e.ag_error_metadata.to_exception(e)
53 else:
54 raise
InaccessibleTensorError: in user code:
File "/tmp/ipykernel_96618/173506222.py", line 9, in test2 *
arrays[ind] = arrays[ind].stack()
File "/project/jens/tf311/lib/python3.11/site-packages/tensorflow/core/function/capture/capture_container.py", line 144, in capture_by_value
graph._validate_in_scope(tensor) # pylint: disable=protected-access
InaccessibleTensorError: <tf.Tensor 'while/TensorArrayV2Write/TensorListSetItem:0' shape=() dtype=variant> is out of scope and cannot be used here. Use return values, explicit Python locals or TensorFlow collections to access it.
Please see for more information.
<tf.Tensor 'while/TensorArrayV2Write/TensorListSetItem:0' shape=() dtype=variant> was defined here:
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/project/jens/tf311/lib/python3.11/site-packages/ipykernel_launcher.py", line 17, in <module>
File "/project/jens/tf311/lib/python3.11/site-packages/traitlets/config/application.py", line 1075, in launch_instance
File "/project/jens/tf311/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 739, in start
File "/project/jens/tf311/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 205, in start
File "/usr/lib/python3.11/asyncio/base_events.py", line 607, in run_forever
File "/usr/lib/python3.11/asyncio/base_events.py", line 1922, in _run_once
File "/usr/lib/python3.11/asyncio/events.py", line 80, in _run
File "/project/jens/tf311/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 542, in dispatch_queue
File "/project/jens/tf311/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 531, in process_one
File "/project/jens/tf311/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
File "/project/jens/tf311/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 359, in execute_request
File "/project/jens/tf311/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 775, in execute_request
File "/project/jens/tf311/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 446, in do_execute
File "/project/jens/tf311/lib/python3.11/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
File "/project/jens/tf311/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3051, in run_cell
File "/project/jens/tf311/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3106, in _run_cell
File "/project/jens/tf311/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
File "/project/jens/tf311/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3311, in run_cell_async
File "/project/jens/tf311/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3493, in run_ast_nodes
File "/project/jens/tf311/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3553, in run_code
File "/tmp/ipykernel_96618/173506222.py", line 12, in <module>
File "/tmp/ipykernel_96618/173506222.py", line 5, in test2
File "/tmp/ipykernel_96618/173506222.py", line 6, in test2
File "/tmp/ipykernel_96618/173506222.py", line 7, in test2
The tensor <tf.Tensor 'while/TensorArrayV2Write/TensorListSetItem:0' shape=() dtype=variant> cannot be accessed from FuncGraph(name=test2, id=139903595515616), because it was defined in FuncGraph(name=while_body_262957, id=139903595521664), which is out of scope.
I am not sure what to do here. Using range instead of tf.range works if the input length is known in advance, but it compiles the "unrolled" loop into the graph, which takes very long for longer sequences. I need TensorArray because that is the dynamic data structure TensorFlow offers inside such loops. But I also need to collect the TensorArrays in another data structure; the only alternative I see is to define one Python variable per array for each layer, which would require changing the code every time I change the number of layers -- not really an option.
I may be able to stack everything into one big TensorArray and then take it apart into the per-layer components afterwards, but first I wanted to check whether I am missing something here.
Note: I did read through the link given in the stack trace; it wasn't really helpful.