Support for optax lbfgs and related optimizers with NNX #4144

Open
jlperla opened this issue Aug 26, 2024 · 1 comment
Labels
Priority: P1 - soon (Response within 5 business days. Resolution within 30 days. Assignee required)

Comments

jlperla commented Aug 26, 2024

I am trying to use L-BFGS and related optimizers with NNX + optax, but I am running into trouble. It might be that optax uses a slightly different optimization interface in those cases (https://optax.readthedocs.io/en/latest/api/optimizers.html#lbfgs), with a helper called optax.value_and_grad_from_state and additional keyword arguments to optimizer.update?

In particular, note that the sample code for these optimizers in https://optax.readthedocs.io/en/latest/api/optimizers.html#lbfgs looks like

opt_state = solver.init(params)
value_and_grad = optax.value_and_grad_from_state(f)
for _ in range(5):
  value, grad = value_and_grad(params, state=opt_state)
  updates, opt_state = solver.update(
     grad, opt_state, params, value=value, grad=grad, value_fn=f
  )
  params = optax.apply_updates(params, updates)

So maybe we need the ability to pass a value_fn argument on? Any easy fixes or things I am missing?
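
For reference, here is that loop filled in as a self-contained toy example (the quadratic objective and the parameter values are my own, purely illustrative; the optax calls follow the docs snippet above):

import jax.numpy as jnp
import optax

# Toy objective so the snippet runs on its own: f(params) = ||params - 1||^2
def f(params):
    return jnp.sum((params - 1.0) ** 2)

params = jnp.zeros(3)
solver = optax.lbfgs()
opt_state = solver.init(params)

# Reuses the value/gradient cached in opt_state by the linesearch,
# avoiding a redundant evaluation of f per iteration.
value_and_grad = optax.value_and_grad_from_state(f)

for _ in range(5):
    value, grad = value_and_grad(params, state=opt_state)
    updates, opt_state = solver.update(
        grad, opt_state, params, value=value, grad=grad, value_fn=f
    )
    params = optax.apply_updates(params, updates)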

Problem you have encountered:

Take the following implementation of linear least squares (LLS) with NNX + optax:

import jax
import jax.numpy as jnp
from jax import random
import optax
import equinox as eqx
import jax_dataloader as jdl
from jax_dataloader.loaders import DataLoaderJAX
from flax import nnx

N = 64  # samples
M = 2
sigma = 0.001
rngs = nnx.Rngs(42)
theta = random.normal(rngs(), (M,))
X = random.normal(rngs(), (N, M))
Y = X @ theta + sigma * random.normal(rngs(), (N,))  # Adding noise

def residual(model, x, y):
    y_hat = model(x)
    return (y_hat - y) ** 2

def residuals_loss(model, X, Y):
    return jnp.mean(jax.vmap(residual, in_axes=(None, 0, 0))(model, X, Y))

model = nnx.Linear(M, 1, use_bias=False, rngs=rngs)

lr = 0.001
optimizer = nnx.Optimizer(model,
                          optax.sgd(lr) #optax.lbfgs()
                          )

@nnx.jit
def train_step(model, optimizer, X, Y):
    grad_fn = nnx.value_and_grad(residuals_loss, has_aux=False)
    loss, grads = grad_fn(model, X, Y)
    optimizer.update(grads)
    return loss

num_epochs = 500
batch_size = 64
dataset = jdl.ArrayDataset(X, Y)
train_loader = DataLoaderJAX(dataset, batch_size=batch_size, shuffle=True)
for epoch in range(num_epochs):
    for X_batch, Y_batch in train_loader:
        loss = train_step(model, optimizer, X_batch, Y_batch)

    if epoch % 100 == 0:
        print(
            f"Epoch {epoch},||theta - theta_hat|| = {jnp.linalg.norm(theta - jnp.squeeze(model.kernel.value))}"
        )

print(f"||theta - theta_hat|| = {jnp.linalg.norm(theta - jnp.squeeze(model.kernel.value))}")

But if I change the optimizer from optax.sgd(lr) to optax.lbfgs(), I would expect NNX to still work; instead it fails with the error below.

Logs, error messages, etc:

The error it gives is

  File "/Users/jlperla/Documents/GitHub/ECON622_instructor/lectures/examples/linear_regression_jax_nnx.py", line 52, in <module>
    loss = train_step(model, optimizer, X_batch, Y_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jlperla/anaconda3/envs/econ622/lib/python3.11/site-packages/flax/nnx/nnx/graph.py", line 1043, in update_context_manager_wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/Users/jlperla/anaconda3/envs/econ622/lib/python3.11/site-packages/flax/nnx/nnx/transforms/transforms.py", line 359, in jit_wrapper
    out, output_state, output_graphdef = jitted_fn(
                                         ^^^^^^^^^^
  File "/Users/jlperla/anaconda3/envs/econ622/lib/python3.11/site-packages/flax/nnx/nnx/transforms/transforms.py", line 158, in jit_fn
    out = f(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^
  File "/Users/jlperla/Documents/GitHub/ECON622_instructor/lectures/examples/linear_regression_jax_nnx.py", line 43, in train_step
    optimizer.update(grads)
  File "/Users/jlperla/anaconda3/envs/econ622/lib/python3.11/site-packages/flax/nnx/nnx/training/optimizer.py", line 201, in update
    updates, new_opt_state = self.tx.update(grads, self.opt_state, state)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jlperla/anaconda3/envs/econ622/lib/python3.11/site-packages/optax/transforms/_combining.py", line 73, in update_fn
    updates, new_s = fn(updates, s, params, **extra_args)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: scale_by_zoom_linesearch.<locals>.update_fn() missing 3 required keyword-only arguments: 'value', 'grad', and 'value_fn'

Steps to reproduce:

Run the example above with the optimizer swapped out, i.e.

import jax
import jax.numpy as jnp
from jax import random
import optax
import equinox as eqx
import jax_dataloader as jdl
from jax_dataloader.loaders import DataLoaderJAX
from flax import nnx

N = 64  # samples
M = 2
sigma = 0.001
rngs = nnx.Rngs(42)
theta = random.normal(rngs(), (M,))
X = random.normal(rngs(), (N, M))
Y = X @ theta + sigma * random.normal(rngs(), (N,))  # Adding noise

def residual(model, x, y):
    y_hat = model(x)
    return (y_hat - y) ** 2

def residuals_loss(model, X, Y):
    return jnp.mean(jax.vmap(residual, in_axes=(None, 0, 0))(model, X, Y))

model = nnx.Linear(M, 1, use_bias=False, rngs=rngs)

lr = 0.001
optimizer = nnx.Optimizer(model,
                          optax.lbfgs()
                          )

@nnx.jit
def train_step(model, optimizer, X, Y):
    grad_fn = nnx.value_and_grad(residuals_loss, has_aux=False)
    loss, grads = grad_fn(model, X, Y)
    optimizer.update(grads)
    return loss

num_epochs = 500
batch_size = 64
dataset = jdl.ArrayDataset(X, Y)
train_loader = DataLoaderJAX(dataset, batch_size=batch_size, shuffle=True)
for epoch in range(num_epochs):
    for X_batch, Y_batch in train_loader:
        loss = train_step(model, optimizer, X_batch, Y_batch)

    if epoch % 100 == 0:
        print(
            f"Epoch {epoch},||theta - theta_hat|| = {jnp.linalg.norm(theta - jnp.squeeze(model.kernel.value))}"
        )

print(f"||theta - theta_hat|| = {jnp.linalg.norm(theta - jnp.squeeze(model.kernel.value))}")
cgarciae added the Priority: P1 - soon label Aug 26, 2024

jlperla commented Aug 27, 2024

An update: I have a partial hack to https://github.com/google/flax/blob/main/flax/nnx/nnx/training/optimizer.py

If I replace the .update with

def update(self, grads, value=None, value_fn=None):
    gdef, state = nnx.split(self.model, self.wrt)

    # The linesearch calls value_fn on the parameter pytree, so wrap the
    # model-based loss with a merge back into a model.
    def value_fn_wrapped(state):
        model = nnx.merge(gdef, state)
        return value_fn(model)

    updates, new_opt_state = self.tx.update(
        grads, self.opt_state, state, grad=grads, value=value, value_fn=value_fn_wrapped
    )

    new_params = optax.apply_updates(state, updates)
    assert isinstance(new_params, nnx.State)

    self.step.value += 1
    nnx.update(self.model, new_params)
    self.opt_state = new_opt_state
Then it seems to work. The key is that the value_fn required by the linesearch is called with the parameter state (the nnx.State pytree) rather than the model, so you need to split and merge to wrap the model-based loss. I don't know if this is high performance or not. It certainly isn't using the value and grad caching from optax.value_and_grad_from_state (see the sketch after the full example below).

The full version of this (where I just call my standalone update function with the optimizer) is:

# Takes the baseline version, uses vmap in the loss, and swaps in optax.lbfgs() with a custom update
import jax
import jax.numpy as jnp
from jax import random
import optax
import jax_dataloader as jdl
from jax_dataloader.loaders import DataLoaderJAX
from flax import nnx

N = 500  # samples
M = 2
sigma = 0.001
rngs = nnx.Rngs(42)
theta = random.normal(rngs(), (M,))
X = random.normal(rngs(), (N, M))
Y = X @ theta + sigma * random.normal(rngs(), (N,))  # Adding noise

def residual(model, x, y):
    y_hat = model(x)
    return (y_hat - y) ** 2

def residuals_loss(model, X, Y):
    return jnp.mean(jax.vmap(residual, in_axes=(None, 0, 0))(model, X, Y))

model = nnx.Linear(M, 1, use_bias=False, rngs=rngs)

# From https://github.com/google/flax/blob/main/flax/nnx/nnx/training/optimizer.py
def update(self, grads, value=None, value_fn=None):
    gdef, state = nnx.split(self.model, self.wrt)

    def value_fn_wrapped(state):
        model = nnx.merge(gdef, state)
        return value_fn(model)

    updates, new_opt_state = self.tx.update(
        grads, self.opt_state, state, grad=grads, value=value, value_fn=value_fn_wrapped
    )

    new_params = optax.apply_updates(state, updates)
    assert isinstance(new_params, nnx.State)

    self.step.value += 1
    nnx.update(self.model, new_params)
    self.opt_state = new_opt_state




lr = 0.001
optimizer = nnx.Optimizer(model,
                          optax.lbfgs(),
                          # optax.sgd(lr),
                          )

@nnx.jit
def train_step(model, optimizer, X, Y):
    def loss_fn(model):
        return residuals_loss(model, X, Y)
    loss, grads = nnx.value_and_grad(loss_fn, has_aux=False)(model)
    # optimizer.update(grads)  # the built-in update does not forward value/value_fn
    update(optimizer, grads, value=loss, value_fn=loss_fn)
    return loss

num_epochs = 20
batch_size = 1024
dataset = jdl.ArrayDataset(X, Y)
train_loader = DataLoaderJAX(dataset, batch_size=batch_size, shuffle=True)
for epoch in range(num_epochs):
    for X_batch, Y_batch in train_loader:
        loss = train_step(model, optimizer, X_batch, Y_batch)

    if epoch % 2 == 0:
        print(
            f"Epoch {epoch},||theta - theta_hat|| = {jnp.linalg.norm(theta - jnp.squeeze(model.kernel.value))}"
        )

print(f"||theta - theta_hat|| = {jnp.linalg.norm(theta - jnp.squeeze(model.kernel.value))}")

This uses full-batch training, which is appropriate for L-BFGS unless the learning rate is decreased.
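
For completeness, here is a rough sketch of how the value/grad caching from optax.value_and_grad_from_state could be used with NNX by bypassing nnx.Optimizer entirely and driving optax on the split-out nnx.State pytree. The name state_loss_fn is mine, and this assumes optax can treat the nnx.State as an ordinary pytree (which the hack above already relies on):

# Split the model once into static graph definition and trainable parameters.
graphdef, params = nnx.split(model, nnx.Param)

def state_loss_fn(params):
    # Rebuild a model from the parameter state and evaluate the full-batch loss.
    m = nnx.merge(graphdef, params)
    return residuals_loss(m, X, Y)

solver = optax.lbfgs()
opt_state = solver.init(params)
# Reuses the value/gradient cached in opt_state by the linesearch.
value_and_grad = optax.value_and_grad_from_state(state_loss_fn)

for _ in range(20):
    value, grad = value_and_grad(params, state=opt_state)
    updates, opt_state = solver.update(
        grad, opt_state, params, value=value, grad=grad, value_fn=state_loss_fn
    )
    params = optax.apply_updates(params, updates)

# Write the optimized parameters back into the nnx model.
nnx.update(model, params)

This avoids recomputing the loss and gradient at the accepted linesearch point, but it gives up nnx.Optimizer's bookkeeping (e.g. the step counter), so it is more of a workaround than a fix for the Optimizer.update signature.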
