I'm using Keras with tf-2.2 as the backend, and it shows this error:
Traceback (most recent call last):
  File "run.py", line 97, in <module>
    task_entry_function()
  File "/data-crystina/src/capreolus-unpublished/capreolus/task/rerank.py", line 47, in train
    return self.rerank_run(best_search_run, self.get_results_path())
  File "/data-crystina/src/capreolus-unpublished/capreolus/task/rerank.py", line 85, in rerank_run
    self.benchmark.relevance_level,
  File "/data-crystina/src/capreolus-unpublished/capreolus/trainer/__init__.py", line 578, in train
    use_multiprocessing=True,
  File "/data-crystina/anaconda3/envs/maxp/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 66, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/data-crystina/anaconda3/envs/maxp/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 855, in fit
    callbacks.on_train_batch_end(step, logs)
  File "/data-crystina/anaconda3/envs/maxp/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 389, in on_train_batch_end
    logs = self._process_logs(logs)
  File "/data-crystina/anaconda3/envs/maxp/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 265, in _process_logs
    return tf_utils.to_numpy_or_python_type(logs)
  File "/data-crystina/anaconda3/envs/maxp/lib/python3.7/site-packages/tensorflow/python/keras/utils/tf_utils.py", line 523, in to_numpy_or_python_type
    return nest.map_structure(_to_single_numpy_or_python_type, tensors)
  File "/data-crystina/anaconda3/envs/maxp/lib/python3.7/site-packages/tensorflow/python/util/nest.py", line 617, in map_structure
    structure[0], [func(*x) for x in entries],
  File "/data-crystina/anaconda3/envs/maxp/lib/python3.7/site-packages/tensorflow/python/util/nest.py", line 617, in <listcomp>
    structure[0], [func(*x) for x in entries],
  File "/data-crystina/anaconda3/envs/maxp/lib/python3.7/site-packages/tensorflow/python/keras/utils/tf_utils.py", line 519, in _to_single_numpy_or_python_type
    x = t.numpy()
  File "/data-crystina/anaconda3/envs/maxp/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 961, in numpy
    maybe_arr = self._numpy()  # pylint: disable=protected-access
  File "/data-crystina/anaconda3/envs/maxp/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 929, in _numpy
    six.raise_from(core._status_to_exception(e.code, e.message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: {{function_node __inference_train_function_100056}} PartialTensorShape: Incompatible ranks during merge: 2 vs. 1
  [[{{node map_6/TensorArrayV2Stack/TensorListStack}}]]
  [[MultiDeviceIteratorGetNextFromShard]]
  [[RemoteCall]]
  [[IteratorGetNextAsOptional]]
2020-07-03 07:19:03.088112: W tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc:76] Unable to destroy remote tensor handles. If you are running a tf.function, it usually indicates some op in the graph gets an error: {{function_node __inference_train_function_100056}} PartialTensorShape: Incompatible ranks during merge: 2 vs. 1
  [[{{node map_6/TensorArrayV2Stack/TensorListStack}}]]
  [[MultiDeviceIteratorGetNextFromShard]]
  [[RemoteCall]]
  [[IteratorGetNextAsOptional]]
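For context on the message itself: the node name map_6/TensorArrayV2Stack/TensorListStack suggests the failure happens while stacking a list of per-element results into one tensor (the pattern tf.map_fn uses internally), and the IteratorGetNext frames suggest it sits inside the input pipeline. Both are inferences from the node names, not something the traceback states outright. Stacking requires every element shape to merge into a single shape, so one element of rank 2 and another of rank 1 cannot be stacked. The sketch below is a simplified pure-Python model of that merge rule, written for illustration (merge_partial_shapes and stacked_shape are hypothetical helpers, not TensorFlow's implementation):

```python
from typing import Optional, Sequence, Tuple

# Simplified model: None for the whole shape means "unknown rank";
# None for a single dimension means "unknown size".
Shape = Optional[Tuple[Optional[int], ...]]


def merge_partial_shapes(a: Shape, b: Shape) -> Shape:
    """Merge two partially known shapes, failing if they cannot describe the same tensor."""
    if a is None:  # unknown rank merges with anything
        return b
    if b is None:
        return a
    if len(a) != len(b):  # known ranks must match exactly
        raise ValueError(f"Incompatible ranks during merge: {len(a)} vs. {len(b)}")
    merged = []
    for i, (da, db) in enumerate(zip(a, b)):
        if da is None:
            merged.append(db)
        elif db is None or da == db:
            merged.append(da)
        else:
            raise ValueError(f"Incompatible sizes in dimension {i}: {da} vs. {db}")
    return tuple(merged)


def stacked_shape(element_shapes: Sequence[Shape]) -> Shape:
    """Shape of stacking a list of tensors: merge all element shapes, then prepend the count."""
    merged: Shape = None
    for s in element_shapes:
        merged = merge_partial_shapes(merged, s)
    if merged is None:
        return None
    return (len(element_shapes),) + merged


print(stacked_shape([(None, 128), (3, 128)]))  # (2, 3, 128)
try:
    stacked_shape([(4, 128), (128,)])  # one rank-2 element, one rank-1 element
except ValueError as err:
    print(err)  # Incompatible ranks during merge: 2 vs. 1
```

If this reading is right, the thing to look for is a map function in the pipeline that returns, for example, a matrix for some examples and a vector for others.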
Sorry that I couldn't find a small code snippet to reproduce this. However, I went into ..python3.7/site-packages/tensorflow/python/keras/callbacks.py and looked inside this function:
def on_train_batch_end(self, batch, logs=None):
    """Calls the `on_train_batch_end` methods of its callbacks.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Metric results for this batch.
    """
    if self._should_call_train_batch_hooks:
        # print("<<<<", logs.keys())
        # print(">>>", type(list(logs.values())[0]))
        logs = self._process_logs(logs)
        self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
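I printed out logs and found that it is a dict with a single key, loss, whose value has type <class 'tensorflow.python.framework.ops.EagerTensor'>. However, logs["loss"] itself cannot be printed because it raises the same error, and so do logs["loss"].shape and dir(logs["loss"]). I couldn't find a similar case online. Has anyone run into this before?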
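The observation that printing logs["loss"] and reading its .shape both fail with the training error is consistent with how eager tensors produced by remote or asynchronous execution behave: the tensor is only a handle, and an error in the op that produces it is reported when something first needs the value or its metadata (here, the callback's t.numpy() call). The remote_tensor_handle_data.cc warning in the log points to remote execution. The toy class below is a simplified pure-Python analogy of that behavior, not TensorFlow's implementation; DeferredResult and failing_train_step are hypothetical names:

```python
from typing import Any, Callable, Optional, Tuple


class DeferredResult:
    """Toy analogy (hypothetical, not a TensorFlow class) of an async/remote tensor handle.

    The value is produced elsewhere; if producing it failed, the stored error is
    re-raised on every later attempt to read the value, its shape, or its repr.
    """

    def __init__(self, compute: Callable[[], Any]):
        self._compute = compute
        self._ready = False
        self._value: Any = None
        self._error: Optional[BaseException] = None

    def _wait(self) -> Any:
        # "Block" until the producing op has finished (here: run it once, lazily).
        if not self._ready:
            try:
                self._value = self._compute()
            except Exception as exc:  # keep the failure attached to the handle
                self._error = exc
            self._ready = True
        if self._error is not None:
            raise self._error
        return self._value

    def numpy(self) -> Any:
        return self._wait()

    @property
    def shape(self) -> Tuple[int, ...]:
        # Even metadata needs the finished result, so it fails the same way.
        return (len(self._wait()),)

    def __repr__(self) -> str:
        return f"DeferredResult({self._wait()!r})"


def failing_train_step():
    # Stand-in for the training step whose input pipeline failed.
    raise ValueError("Incompatible ranks during merge: 2 vs. 1")


loss = DeferredResult(failing_train_step)
for name, access in [("print", lambda: print(loss)),
                     ("shape", lambda: loss.shape),
                     ("numpy", loss.numpy)]:
    try:
        access()
    except ValueError as err:
        print(f"{name}: {err}")
```

Under this reading, the EagerTensor in logs is not the source of the problem; the callback is just the first place that forces the value, so the error message would look the same no matter how logs["loss"] is inspected.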