DynamicScale behaves unexpectedly when computing per-sample gradients with vmap #4114

Open

hlzl opened this issue Aug 2, 2024 · 0 comments
hlzl commented Aug 2, 2024

When running under jax.vmap, e.g. to compute per-sample gradients, the fin_steps and scale attributes of the DynamicScale returned by value_and_grad become arrays, which leads to an error in the next training step if not handled manually. The TypeError that is thrown does not directly hint at the actual problem: a non-scalar scale attribute.

System information

  • jax==0.4.28 and flax==0.8.5

Problem you have encountered:

Because self.scale becomes an array in the output of the first vmap call, the loss_wrapper inside DynamicScale starts to return an array instead of a scalar, which breaks jax.grad on the next call.
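
This is consistent with how jax.vmap treats outputs in general: every leaf of a vmapped function's output pytree gets a batch dimension, and since the returned DynamicScale is itself a pytree, its scale and fin_steps leaves come back batched. A minimal illustration of that mechanism (not from the issue, just for context):

import jax
import jax.numpy as jnp

def f(x):
    # The constant second output plays the role of DynamicScale.scale here.
    return x * 2.0, jnp.float32(2.0 ** 15)

ys, scales = jax.vmap(f)(jnp.ones(32))
print(scales.shape)  # (32,) -- the per-call "scalar" comes back as a batched array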

What you expected to happen:

The scale and fin_steps attributes should either be averaged or enforced to be scalars, so that they do not cause this TypeError.
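
One way to surface this earlier, sketched purely as a hypothetical (this helper does not exist in flax), would be a scalar check on the attributes:

import jax.numpy as jnp

def check_scalar_scale(ds):
    # Hypothetical guard: fail fast if a vmapped DynamicScale leaks into the next step.
    if jnp.ndim(ds.scale) != 0 or jnp.ndim(ds.fin_steps) != 0:
        raise ValueError(
            "DynamicScale.scale and .fin_steps must be scalars; got shapes "
            f"{jnp.shape(ds.scale)} and {jnp.shape(ds.fin_steps)}"
        )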

Logs, error messages, etc:

File ~/miniforge3/envs/test/lib/python3.12/site-packages/flax/training/dynamic_scale.py:132, in DynamicScale.value_and_grad.<locals>.grad_fn_wrapper(*args)
    131 def grad_fn_wrapper(*args):
--> 132   aux, grad = grad_fn(*args)
    133   aux = (aux[0] / self.scale, aux[1]) if has_aux else aux / self.scale
    135   grad = jax.tree_util.tree_map(
    136     lambda g: jnp.asarray(g, jnp.float32) / self.scale, grad
    137   )
TypeError: Gradient only defined for scalar-output functions. Output had shape: (32,).

Steps to reproduce:

from typing import Sequence

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

from flax.training import dynamic_scale


class MLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            x = nn.relu(nn.Dense(feat)(x))
        x = nn.Dense(self.features[-1])(x)
        return x


def cross_entropy_loss(params, model, image, label):
    """Loss function used for training."""
    logits = model.apply({"params": params}, image)
    loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits, label))
    return loss, logits

model = MLP([12, 8, 4])
input = jnp.ones((32, 10))
labels = jnp.ones((32,), dtype=int)
variables = model.init(jax.random.key(0), input)
output = model.apply(variables, input)

ds = dynamic_scale.DynamicScale()

# 1st batch: succeeds, but the returned ds now carries array-valued scale and fin_steps
ds, is_fin, (loss, logits), per_sample_grads = jax.vmap(
    ds.value_and_grad(cross_entropy_loss, has_aux=True),
    in_axes=(None, None, 0, 0),
)(variables["params"], model, input, labels)

# 2nd batch: fails with the TypeError above, since ds.scale now has shape (32,)
ds, is_fin, (loss, logits), per_sample_grads = jax.vmap(
    ds.value_and_grad(cross_entropy_loss, has_aux=True),
    in_axes=(None, None, 0, 0),
)(variables["params"], model, input, labels)

This can be fixed manually with ds = ds.replace(fin_steps=ds.fin_steps.mean(), scale=ds.scale.mean()) after each step (see the sketch below), but IMO it should be handled automatically / enforced within DynamicScale.
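
For completeness, a sketch of that workaround folded into a training step (the train_step wrapper is just illustrative; it reuses cross_entropy_loss from above):

def train_step(ds, params, model, images, labels):
    ds, is_fin, (loss, logits), per_sample_grads = jax.vmap(
        ds.value_and_grad(cross_entropy_loss, has_aux=True),
        in_axes=(None, None, 0, 0),
    )(params, model, images, labels)
    # Collapse the batched leaves back to scalars before the next step
    # (fin_steps may also need a cast back to int, depending on usage).
    ds = ds.replace(fin_steps=ds.fin_steps.mean(), scale=ds.scale.mean())
    return ds, is_fin, (loss, logits), per_sample_grads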
