Experimental-pytree flag causes crash #4142

Open
NeilGirdhar opened this issue Aug 24, 2024 · 2 comments

@NeilGirdhar (Contributor) commented Aug 24, 2024

```python
from dataclasses import dataclass, field
from typing import Any, dataclass_transform, override

import jax.numpy as jnp
from flax import nnx
from jax import Array, vmap


@dataclass_transform(field_specifiers=(field,))
class DataClassModule(nnx.Module):
    @override
    def __init_subclass__(cls, **kwargs: Any) -> None:
        # Opt subclasses into the experimental pytree registration, then make them dataclasses.
        super().__init_subclass__(**kwargs, experimental_pytree=True)
        dataclass()(cls)


class SomeModule(nnx.Module):
    def __init__(self, epsilon: Array):
        super().__init__()
        self.epsilon = epsilon  # a raw JAX array, not wrapped in a Variable


class SomeDataclassModule(DataClassModule):
    def __init__(self) -> None:
        super().__init__()
        self.sm = SomeModule(jnp.zeros(1))


def f(m: SomeDataclassModule, x: Array) -> None:
    pass


module = SomeDataclassModule()
z = jnp.zeros(10)
vmap(f, in_axes=(None, 0))(module, z)  # raises ValueError (traceback below)
```

gives

```
Traceback (most recent call last):
  File "/home/neil/src/cmm/a.py", line 35, in <module>
    vmap(f, in_axes=(None, 0))(module, z)
  File "/home/neil/src/cmm/.venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/cmm/.venv/lib/python3.12/site-packages/jax/_src/api.py", line 1221, in vmap_f
    in_axes_flat = flatten_axes("vmap in_axes", in_tree, (in_axes, 0), kws=True)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/cmm/.venv/lib/python3.12/site-packages/jax/_src/api_util.py", line 400, in flatten_axes
    dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/cmm/.venv/lib/python3.12/site-packages/jax/_src/tree_util.py", line 84, in tree_unflatten
    return treedef.unflatten(leaves)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/cmm/.venv/lib/python3.12/site-packages/flax/nnx/nnx/module.py", line 444, in _module_unflatten
    return graph.merge(graphdef, State(zip(paths, variables)))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/cmm/.venv/lib/python3.12/site-packages/flax/nnx/nnx/graph.py", line 1306, in merge
    node, _ = unflatten(graphdef, state)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/cmm/.venv/lib/python3.12/site-packages/flax/nnx/nnx/graph.py", line 471, in unflatten
    node = _graph_unflatten(
           ^^^^^^^^^^^^^^^^^
  File "/home/neil/src/cmm/.venv/lib/python3.12/site-packages/flax/nnx/nnx/graph.py", line 628, in _graph_unflatten
    children = _get_children()
               ^^^^^^^^^^^^^^^
  File "/home/neil/src/cmm/.venv/lib/python3.12/site-packages/flax/nnx/nnx/graph.py", line 569, in _get_children
    children[key] = _graph_unflatten(
                    ^^^^^^^^^^^^^^^^^
  File "/home/neil/src/cmm/.venv/lib/python3.12/site-packages/flax/nnx/nnx/graph.py", line 628, in _graph_unflatten
    children = _get_children()
               ^^^^^^^^^^^^^^^
  File "/home/neil/src/cmm/.venv/lib/python3.12/site-packages/flax/nnx/nnx/graph.py", line 575, in _get_children
    raise ValueError(f'Expected a leaf for {key!r}, but got {value!r}')
ValueError: Expected a leaf for 'epsilon', but got <object object at 0x7ccaa815f050>
```

Assigning a float to epsilon (instead of a JAX array) makes the problem disappear.

(Tested on both main and the latest release.)

@cgarciae (Collaborator)

This is an interesting counterexample for Module's pytree definition. The issue is that we decide what is static based on the type of each attribute's value.
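The type-based decision described above can be illustrated with a toy model in plain Python. It uses no JAX or Flax, and every name below (FakeArray, is_leaf_by_type, TreeDef, flatten, unflatten_type_checked, unflatten_structural) is hypothetical. The traceback shows that vmap resolves in_axes by calling `tree_unflatten(treedef, [object()] * treedef.num_leaves)`, so the module is rebuilt from placeholder objects. If unflattening works out which attributes are leaves from the *type* of each value, the placeholders fail that check. If it instead trusts the leaf positions recorded at flatten time, they pass. This is a sketch of the idea under those assumptions, not Flax's actual graph code.

```python
from dataclasses import dataclass
from typing import Any


class FakeArray:
    """Hypothetical stand-in for a JAX array, so the sketch needs no JAX."""

    def __init__(self, value: float) -> None:
        self.value = value


def is_leaf_by_type(value: Any) -> bool:
    # Type-based rule (toy): only array-like values count as state leaves.
    return isinstance(value, FakeArray)


@dataclass(frozen=True)
class TreeDef:
    leaf_keys: tuple[str, ...]  # attributes that were leaves at flatten time
    static: tuple[tuple[str, Any], ...]  # all other attributes, kept in the treedef


def flatten(attrs: dict[str, Any]) -> tuple[list[Any], TreeDef]:
    leaf_keys = tuple(k for k, v in attrs.items() if is_leaf_by_type(v))
    static = tuple((k, v) for k, v in attrs.items() if k not in leaf_keys)
    return [attrs[k] for k in leaf_keys], TreeDef(leaf_keys, static)


def unflatten_type_checked(treedef: TreeDef, leaves: list[Any]) -> dict[str, Any]:
    # Re-validates every leaf by type, like the check that raised in the traceback.
    out = dict(treedef.static)
    for key, value in zip(treedef.leaf_keys, leaves):
        if not is_leaf_by_type(value):
            raise ValueError(f"Expected a leaf for {key!r}, but got {value!r}")
        out[key] = value
    return out


def unflatten_structural(treedef: TreeDef, leaves: list[Any]) -> dict[str, Any]:
    # Trusts the treedef: which attributes are leaves was fixed at flatten time.
    out = dict(treedef.static)
    out.update(zip(treedef.leaf_keys, leaves))
    return out


# An array-valued attribute becomes a leaf, so vmap-style placeholders reach the check.
leaves, treedef = flatten({"epsilon": FakeArray(0.0)})
placeholders = [object()] * len(leaves)
try:
    unflatten_type_checked(treedef, placeholders)
except ValueError as err:
    print(err)  # same message shape as the traceback, ending in an object address
print(unflatten_structural(treedef, placeholders)["epsilon"] is placeholders[0])  # True

# A float-valued attribute is static, so there are no leaves to replace.
float_leaves, float_treedef = flatten({"epsilon": 0.0})
print(unflatten_type_checked(float_treedef, [object()] * len(float_leaves)))  # {'epsilon': 0.0}
```

The float case matches the reporter's observation. With a float, epsilon never becomes a leaf, so no placeholder ever reaches the type check.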

@NeilGirdhar (Contributor, Author)

Okay, I understand. Feel free to close if you like.

copybara-service bot pushed a commit that referenced this issue Sep 5, 2024
NumPy and JAX Arrays are no longer considered state leaves. This makes the structure of the State completely determined by Variables, which, apart from being more predictable, produces a structure that is invariant to changes in leaf type, changes that led to issues such as #4142.

```python
import jax.numpy as jnp
from flax import nnx

class Foo(nnx.Module):
  def __init__(self):
    self.a = jnp.array(1)  # no longer allowed, instead...
    self.b = nnx.Param(jnp.array(1))  # just use Variables
```

PiperOrigin-RevId: 670949705
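The commit's approach, in which only Variables define state, can be sketched with the same kind of toy model. The sketch is plain Python, and Variable, flatten and unflatten here are hypothetical stand-ins, not nnx.Variable or Flax's real functions. Leaf positions come from the wrapper type rather than from the type of the wrapped value. As a result, the treedef is the same whether a Variable holds a float or an array, and placeholder leaves unflatten without complaint. The commit message says bare array attributes are no longer allowed. The toy does not model that error and simply keeps non-Variable attributes in the treedef.

```python
from typing import Any


class Variable:
    """Hypothetical stand-in for a Variable wrapper such as nnx.Param."""

    def __init__(self, value: Any) -> None:
        self.value = value


def flatten(attrs: dict[str, Any]) -> tuple[list[Any], tuple]:
    # Only Variable-wrapped attributes are leaves; the wrapped value's type is irrelevant.
    var_keys = tuple(k for k, v in attrs.items() if isinstance(v, Variable))
    static = tuple((k, v) for k, v in attrs.items() if k not in var_keys)
    return [attrs[k].value for k in var_keys], (var_keys, static)


def unflatten(treedef: tuple, leaves: list[Any]) -> dict[str, Any]:
    var_keys, static = treedef
    out = dict(static)
    for key, leaf in zip(var_keys, leaves):
        out[key] = Variable(leaf)  # no type check on the leaf itself
    return out


# Same structure whether the Variable holds a float or a list (standing in for an array).
_, def_float = flatten({"epsilon": Variable(0.0), "name": "sm"})
_, def_array = flatten({"epsilon": Variable([0.0]), "name": "sm"})
print(def_float == def_array)  # True

# vmap-style placeholders unflatten cleanly.
rebuilt = unflatten(def_float, [object()])
print(type(rebuilt["epsilon"]).__name__, rebuilt["name"])  # Variable sm
```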
copybara-service bot pushed a commit that referenced this issue Sep 5, 2024
NumPy and JAX Arrays are no longer considered state leaves. This makes the structure of the State completely determined by Variables, which, apart from being more predictable, produces a structure that is invariant to changes in leaf type, changes that led to issues such as #4142.

```python
class Foo(nnx.Module):
  def __init__(self):
    self.a = jnp.array(1) # no longer allowed, instead...
    self.b = nnx.Param(jnp.array(1)) # just use Variables
```

Also migrates all remaining tests from pytest to absl to ensure they are tested correctly internally.

PiperOrigin-RevId: 670949705
copybara-service bot pushed a commit that referenced this issue Sep 5, 2024
NumPy and JAX Arrays are no longer considered state leaves. This makes the structure of the State completely determined by Variables, which, apart from being more predictable, produces a structure that is invariant to changes in leaf type, changes that led to issues such as #4142.

```python
class Foo(nnx.Module):
  def __init__(self):
    self.a = jnp.array(1) # no longer allowed, instead...
    self.b = nnx.Param(jnp.array(1)) # just use Variables
```

Also migrates all remaining tests from pytest to absl to ensure they are tested correctly internally.

PiperOrigin-RevId: 671372717