From b49e7cb8cfbccaed9d62580014ed8e4c1c90710a Mon Sep 17 00:00:00 2001 From: Orbax Authors Date: Thu, 2 Jul 2026 08:18:55 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 941710277 --- export/orbax/export/modules/tensorflow_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/export/orbax/export/modules/tensorflow_module.py b/export/orbax/export/modules/tensorflow_module.py index 99564565e..d1e59ad19 100644 --- a/export/orbax/export/modules/tensorflow_module.py +++ b/export/orbax/export/modules/tensorflow_module.py @@ -280,7 +280,7 @@ def jax_params_to_tf_variables( ) -> PyTree: """Converts `params` to tf.Variables in the same pytree structure.""" mesh = dtensor_utils.get_current_mesh() - default_cpu_device = tf.config.list_logical_devices('CPU')[0] + default_cpu_device = tf.config.list_logical_devices('CPU')[0].name if mesh is not None: if pspecs is None: raise ValueError(