When running `jax.vmap`, e.g. to compute per-sample gradients, the `fin_steps` and `scale` attributes of `DynamicScale` can become arrays, which causes an error in the next training step if not handled manually. The resulting `TypeError` gives no hint at the actual problem: a non-scalar `scale` attribute.
System information
`jax==0.4.28` and `flax==0.8.5`
Problem you have encountered:
Because `self.scale` becomes an array in the output of the first `vmap` call, the `loss_wrapper` inside `DynamicScale` also starts returning an array instead of a scalar.
What you expected to happen:
The `scale` and `fin_steps` attributes should either be averaged or be enforced to remain scalars, so the `TypeError` is never raised.
Logs, error messages, etc:
```
File ~/miniforge3/envs/test/lib/python3.12/site-packages/flax/training/dynamic_scale.py:132, in DynamicScale.value_and_grad.<locals>.grad_fn_wrapper(*args)
    131 def grad_fn_wrapper(*args):
--> 132   aux, grad = grad_fn(*args)
    133   aux = (aux[0] / self.scale, aux[1]) if has_aux else aux / self.scale
    135   grad = jax.tree_util.tree_map(
    136     lambda g: jnp.asarray(g, jnp.float32) / self.scale, grad
    137   )
```
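Steps to reproduce:
A minimal sketch — the model, data, and loss below are made up for illustration; the only thing that matters is the `vmap` over `DynamicScale.value_and_grad`:

```python
import jax
import jax.numpy as jnp
from flax.training import dynamic_scale as dynamic_scale_lib

# Hypothetical toy setup: a linear model with per-sample squared error.
params = {'w': jnp.ones((3,))}
xs = jnp.ones((8, 3))  # batch of 8 samples
ys = jnp.zeros((8,))

ds = dynamic_scale_lib.DynamicScale()

def per_sample_grad(x, y):
  def loss_fn(p):
    return (jnp.dot(x, p['w']) - y) ** 2
  # Returns a DynamicScaleResult: (new DynamicScale, finite flag, loss, grad).
  return ds.value_and_grad(loss_fn)(params)

new_ds, is_fin, loss, grads = jax.vmap(per_sample_grad)(xs, ys)

# vmap batches the pytree leaves of the returned DynamicScale:
print(new_ds.scale.shape)      # (8,) -- an array, not a scalar
print(new_ds.fin_steps.shape)  # (8,)

# Reusing `new_ds` in the next step raises the TypeError above, because
# loss_wrapper now scales the loss by a non-scalar self.scale.
```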
Can be fixed manually with `ds = ds.replace(fin_steps=ds.fin_steps.mean(), scale=ds.scale.mean())` after each step.
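For context, a sketch of where that workaround sits in a training step (names as in the repro above; `.mean()` mirrors the manual fix, though something like `.min()` on `scale` would be a more conservative choice if some samples overflowed):

```python
new_ds, is_fin, loss, grads = jax.vmap(per_sample_grad)(xs, ys)

# Collapse the batched DynamicScale back to scalars before the next step.
ds = new_ds.replace(
    fin_steps=new_ds.fin_steps.mean(),
    scale=new_ds.scale.mean(),
)
```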
IMO this should be handled automatically / enforced within `DynamicScale`.