diff --git a/export/orbax/export/modules/tensorflow_module.py b/export/orbax/export/modules/tensorflow_module.py index 99564565e..d1e59ad19 100644 --- a/export/orbax/export/modules/tensorflow_module.py +++ b/export/orbax/export/modules/tensorflow_module.py @@ -280,7 +280,7 @@ def jax_params_to_tf_variables( ) -> PyTree: """Converts `params` to tf.Variables in the same pytree structure.""" mesh = dtensor_utils.get_current_mesh() - default_cpu_device = tf.config.list_logical_devices('CPU')[0] + default_cpu_device = tf.config.list_logical_devices('CPU')[0].name if mesh is not None: if pspecs is None: raise ValueError(