diff --git a/examples/burgers1d.ipynb b/examples/burgers1d.ipynb index ff2cea1..e05abd4 100644 --- a/examples/burgers1d.ipynb +++ b/examples/burgers1d.ipynb @@ -35,8 +35,7 @@ "from lasdi.latent_space import Autoencoder, initial_condition_latent\n", "from lasdi.postprocess import compute_errors\n", "\n", - "date = '08_28_2024_20_46'\n", - "bglasdi_results = np.load('results/bglasdi_' + date + '.npy', allow_pickle = True).item()" + "filename = 'lasdi_09_11_2024_20_20.npy'" ] }, { @@ -55,17 +54,16 @@ "outputs": [], "source": [ "import yaml\n", - "from lasdi.workflow import initialize_physics, initialize_latent_space, ld_dict\n", + "from lasdi.workflow import initialize_trainer\n", "from lasdi.param import ParameterSpace\n", "\n", "cfg_file = 'burgers1d.yml'\n", "with open(cfg_file, 'r') as f:\n", " config = yaml.safe_load(f)\n", "\n", - "param_space = ParameterSpace(config)\n", - "physics = initialize_physics(param_space, config)\n", - "autoencoder = initialize_latent_space(physics, config)\n", - "sindy = ld_dict[config['latent_dynamics']['type']](autoencoder.n_z, physics.nt, config['latent_dynamics'])" + "restart_file = np.load(filename, allow_pickle=True).item()\n", + "\n", + "trainer, param_space, physics, autoencoder, sindy = initialize_trainer(config, restart_file)" ] }, { @@ -75,16 +73,15 @@ "metadata": {}, "outputs": [], "source": [ - "autoencoder_param = bglasdi_results['autoencoder_param']\n", + "from lasdi.gp import fit_gps\n", + "\n", + "autoencoder_param = restart_file['latent_space']['autoencoder_param']\n", "\n", - "X_train = bglasdi_results['final_X_train']\n", - "coefs = bglasdi_results['coefs']\n", - "gp_dictionnary = bglasdi_results['gp_dictionnary']\n", - "fd_type = bglasdi_results['latent_dynamics']['fd_type']\n", + "X_train = restart_file['trainer']['X_train']\n", "\n", - "paramspace_dict = bglasdi_results['parameters']\n", - "param_train = paramspace_dict['final_param_train']\n", - "param_grid = paramspace_dict['param_grid']\n", + "paramspace_dict = restart_file['parameters']\n", + "param_train = paramspace_dict['train_space']\n", + "param_grid = paramspace_dict['test_space']\n", "test_meshgrid = paramspace_dict['test_meshgrid']\n", "test_grid_sizes = paramspace_dict['test_grid_sizes']\n", "\n", @@ -93,7 +90,7 @@ "n_a_grid, n_w_grid = test_grid_sizes\n", "a_grid, w_grid = test_meshgrid\n", "\n", - "physics_dict = bglasdi_results['physics']\n", + "physics_dict = restart_file['physics']\n", "t_grid = physics_dict['t_grid']\n", "x_grid = physics_dict['x_grid']\n", "t_mesh, x_mesh = np.meshgrid(t_grid, x_grid)\n", @@ -116,6 +113,9 @@ "n_z = autoencoder_param['fc' + str(n_hidden + 1) + '_e.weight'].shape[0]\n", "\n", "autoencoder.load_state_dict(autoencoder_param)\n", + "Z = autoencoder.encoder(X_train)\n", + "coefs = sindy.calibrate(Z, physics.dt, compute_loss=False, numpy=True)\n", + "gp_dictionnary = fit_gps(param_space.train_space, coefs)\n", "\n", "n_coef = sindy.ncoefs\n", "\n", diff --git a/examples/burgers1d.yml b/examples/burgers1d.yml index d905c62..d83c0b8 100644 --- a/examples/burgers1d.yml +++ b/examples/burgers1d.yml @@ -12,6 +12,10 @@ lasdi: path_checkpoint: checkpoint path_results: results +workflow: + use_restart: true + restart_file: restarts/burgers1d.restart.npy + parameter_space: parameters: - name: a diff --git a/src/lasdi/gplasdi.py b/src/lasdi/gplasdi.py index 15c1cf4..a19176c 100644 --- a/src/lasdi/gplasdi.py +++ b/src/lasdi/gplasdi.py @@ -73,6 +73,21 @@ def get_fom_max_std(autoencoder, Zis): return m_index +# move optimizer parameters to device +def 
optimizer_to(optim, device):
+    for param in optim.state.values():
+        # Not sure there are any global tensors in the state dict
+        if isinstance(param, torch.Tensor):
+            param.data = param.data.to(device)
+            if param._grad is not None:
+                param._grad.data = param._grad.data.to(device)
+        elif isinstance(param, dict):
+            for subparam in param.values():
+                if isinstance(subparam, torch.Tensor):
+                    subparam.data = subparam.data.to(device)
+                    if subparam._grad is not None:
+                        subparam._grad.data = subparam._grad.data.to(device)
+
 class BayesianGLaSDI:
 
     X_train = None
@@ -161,7 +176,8 @@ def train(self):
             self.optimizer.step()
 
             if loss.item() < self.best_loss:
-                torch.save(autoencoder_device.state_dict(), self.path_checkpoint + '/' + 'checkpoint.pt')
+                # Module.cpu() moves the model in place; saving the state dict from
+                # the CPU keeps the checkpoint device-agnostic, so move the model
+                # back to the training device afterwards.
+                torch.save(autoencoder_device.cpu().state_dict(), self.path_checkpoint + '/' + 'checkpoint.pt')
+                autoencoder_device = self.autoencoder.to(device)
                 self.best_coefs = coefs
                 self.best_loss = loss.item()
@@ -202,26 +218,10 @@ def train(self):
         self.timer.start("finalize")
 
         if (self.best_coefs.shape[0] == ps.n_train):
+            state_dict = torch.load(self.path_checkpoint + '/' + 'checkpoint.pt')
+            self.autoencoder.load_state_dict(state_dict)
             coefs = self.best_coefs
-            gp_dictionnary = fit_gps(ps.train_space, coefs)
-
-            bglasdi_results = {'autoencoder_param': self.autoencoder.state_dict(), 'final_X_train': self.X_train,
-                               'coefs': coefs, 'gp_dictionnary': gp_dictionnary, 'lr': self.lr, 'n_iter': self.n_iter,
-                               'n_greedy': self.n_greedy, 'sindy_weight': self.sindy_weight, 'coef_weight': self.coef_weight,
-                               'n_samples' : self.n_samples,
-                               }
-            bglasdi_results['physics'] = self.physics.export()
-            bglasdi_results['parameters'] = self.param_space.export()
-            # TODO(kevin): restart capability for timer.
-            bglasdi_results['timer'] = self.timer.export()
-            bglasdi_results['latent_dynamics'] = self.latent_dynamics.export()
-
-            date = time.localtime()
-            date_str = "{month:02d}_{day:02d}_{year:04d}_{hour:02d}_{minute:02d}"
-            date_str = date_str.format(month = date.tm_mon, day = date.tm_mday, year = date.tm_year, hour = date.tm_hour + 3, minute = date.tm_min)
-            np.save(self.path_results + '/' + 'bglasdi_' + date_str + '.npy', bglasdi_results)
-
         self.timer.end("finalize")
         self.timer.print()
@@ -267,4 +267,20 @@ def sample_fom(self):
                 self.X_train = torch.cat([self.X_train, new_X], dim = 0)
         else:
             # TODO(kevin): interface for offline FOM simulation
-            raise RuntimeError("Offline FOM simulation is not supported yet!")
\ No newline at end of file
+            raise RuntimeError("Offline FOM simulation is not supported yet!")
+
+    def export(self):
+        dict_ = {'X_train': self.X_train, 'lr': self.lr, 'n_iter': self.n_iter, 'n_samples' : self.n_samples,
+                 'n_greedy': self.n_greedy, 'sindy_weight': self.sindy_weight, 'coef_weight': self.coef_weight,
+                 'restart_iter': self.restart_iter, 'timer': self.timer.export(), 'optimizer': self.optimizer.state_dict()
+                 }
+        return dict_
+
+    def load(self, dict_):
+        self.X_train = dict_['X_train']
+        self.restart_iter = dict_['restart_iter']
+        self.timer.load(dict_['timer'])
+        self.optimizer.load_state_dict(dict_['optimizer'])
+        if (self.device != 'cpu'):
+            optimizer_to(self.optimizer, self.device)
+        return
\ No newline at end of file
diff --git a/src/lasdi/latent_dynamics/__init__.py b/src/lasdi/latent_dynamics/__init__.py
index 2c35a86..6efa6d3 100644
--- a/src/lasdi/latent_dynamics/__init__.py
+++ b/src/lasdi/latent_dynamics/__init__.py
@@ -56,4 +56,11 @@ def sample(self, coefs_sample, z0_sample, t_grid):
 
     def export(self):
         param_dict = {'dim': self.dim, 'ncoefs': self.ncoefs}
return param_dict + + # SINDy does not need to load parameters. + # Other latent dynamics might need to. + def load(self, dict_): + assert(self.dim == dict_['dim']) + assert(self.ncoefs == dict_['ncoefs']) + return \ No newline at end of file diff --git a/src/lasdi/latent_space.py b/src/lasdi/latent_space.py index 2b153db..e3c7f11 100644 --- a/src/lasdi/latent_space.py +++ b/src/lasdi/latent_space.py @@ -174,4 +174,12 @@ def apply_attention(self, x, layer): x, _ = a(x, x, x) # apply attention x = x.squeeze(1) # Remove sequence dimension - return x \ No newline at end of file + return x + + def export(self): + dict_ = {'autoencoder_param': self.cpu().state_dict()} + return dict_ + + def load(self, dict_): + self.load_state_dict(dict_['autoencoder_param']) + return \ No newline at end of file diff --git a/src/lasdi/param.py b/src/lasdi/param.py index 19000b4..a656b73 100644 --- a/src/lasdi/param.py +++ b/src/lasdi/param.py @@ -22,6 +22,7 @@ def create_uniform_1dspace(config): class ParameterSpace: param_list = [] param_name = [] + n_param = 0 train_space = None test_space = None n_test = 0 @@ -35,6 +36,7 @@ def __init__(self, config): parser = InputParser(config['parameter_space'], name="param_space_input") self.param_list = parser.getInput(['parameters'], datatype=list) + self.n_param = len(self.param_list) self.param_name = [] for param in self.param_list: @@ -131,13 +133,23 @@ def appendTrainSpace(self, param): return def export(self): - dict_ = {'final_param_train': self.train_space, - 'param_grid': self.test_space, + dict_ = {'train_space': self.train_space, + 'test_space': self.test_space, 'test_grid_sizes': self.test_grid_sizes, 'test_meshgrid': self.test_meshgrid, 'n_init': self.n_init} return dict_ - def loadTrainSpace(self): - raise RuntimeError("ParameterSpace.loadTrainSpace is not implemented yet!") + def load(self, dict_): + self.train_space = dict_['train_space'] + self.test_space = dict_['test_space'] + self.test_grid_sizes = dict_['test_grid_sizes'] + self.test_meshgrid = dict_['test_meshgrid'] + + assert(self.n_init == dict_['n_init']) + assert(self.train_space.shape[1] == self.n_param) + assert(self.test_space.shape[1] == self.n_param) + + self.n_train = self.train_space.shape[0] + self.n_test = self.test_space.shape[0] return diff --git a/src/lasdi/timing.py b/src/lasdi/timing.py index ce4af72..da7d7c3 100644 --- a/src/lasdi/timing.py +++ b/src/lasdi/timing.py @@ -40,8 +40,22 @@ def print(self): return def export(self): + for start in self.starts: + if (start is not None): + raise RuntimeError('Timer.export: cannot export while Timer is still ticking!') + param_dict = {} param_dict["names"] = self.names param_dict["calls"] = self.calls param_dict["times"] = self.times - return param_dict \ No newline at end of file + return param_dict + + def load(self, dict_): + self.names = dict_['names'] + self.calls = dict_['calls'] + self.times = dict_['times'] + + assert(len(self.names) == len(self.calls)) + assert(len(self.names) == len(self.times)) + self.starts = [None] * len(self.names) + return \ No newline at end of file diff --git a/src/lasdi/workflow.py b/src/lasdi/workflow.py index 136549f..b62449e 100644 --- a/src/lasdi/workflow.py +++ b/src/lasdi/workflow.py @@ -9,6 +9,7 @@ from .latent_dynamics.sindy import SINDy from .physics.burgers1d import Burgers1D from .param import ParameterSpace +from .inputs import InputParser trainer_dict = {'gplasdi': BayesianGLaSDI} @@ -29,63 +30,102 @@ def main(): with open(args.config_file, 'r') as f: config = yaml.safe_load(f) - - trainer = 
initialize_trainer(config)
-
-    if ('restart_file' in config):
-        restart_file = np.load(config['restart_file'], allow_pickle=True).item()
-        next_step = restart_file['next_step']
+    import os
+    import time
+    from pathlib import Path
+
+    cfg_parser = InputParser(config, name='main')
+
+    use_restart = cfg_parser.getInput(['workflow', 'use_restart'], fallback=False)
+    if (use_restart):
+        restart_filename = cfg_parser.getInput(['workflow', 'restart_file'], datatype=str)
+        # make sure the directory of the restart file exists.
+        Path(os.path.dirname(restart_filename)).mkdir(parents=True, exist_ok=True)
+
+    if (use_restart and (os.path.isfile(restart_filename))):
+        # TODO(kevin): in the long term, we should switch to hdf5 format.
+        restart_file = np.load(restart_filename, allow_pickle=True).item()
+        current_step = restart_file['next_step']
         result = restart_file['result']
     else:
-        next_step = NextStep.Initial
+        restart_file = None
+        current_step = NextStep.Initial
         result = Result.Unexecuted
+
+    trainer, param_space, physics, latent_space, latent_dynamics = initialize_trainer(config, restart_file)
+
+    result, next_step = step(trainer, current_step, config, use_restart)
 
     if (result is Result.Fail):
         raise RuntimeError('Previous step has failed. Stopping the workflow.')
+    elif (result is Result.Success):
+        print("Previous step succeeded. Preparing for the next step.")
+        result = Result.Unexecuted
     elif (result is Result.Complete):
         print("Workflow is finished.")
-        return result
 
-    result = step(trainer, next_step, config)
+    # save the restart (or final) file.
+    date = time.localtime()
+    date_str = "{month:02d}_{day:02d}_{year:04d}_{hour:02d}_{minute:02d}"
+    date_str = date_str.format(month = date.tm_mon, day = date.tm_mday, year = date.tm_year, hour = date.tm_hour, minute = date.tm_min)
+    if (use_restart):
+        # rename the old restart file if one exists.
+        if (os.path.isfile(restart_filename)):
+            old_timestamp = restart_file['timestamp']
+            os.rename(restart_filename, restart_filename + '.' + old_timestamp)
+        save_file = restart_filename
+    else:
+        save_file = 'lasdi_' + date_str + '.npy'
+
+    save_dict = {'parameters': param_space.export(),
+                 'physics': physics.export(),
+                 'latent_space': latent_space.export(),
+                 'latent_dynamics': latent_dynamics.export(),
+                 'trainer': trainer.export(),
+                 'timestamp': date_str,
+                 'next_step': next_step,
+                 'result': result, # TODO(kevin): do we need to save result?
+                 }
+
+    np.save(save_file, save_dict)
 
     return result
 
-def step(trainer, next_step, config):
-    # TODO(kevin): implement save/load workflow.
-    continue_workflow = True
+def step(trainer, next_step, config, use_restart=False):
 
     if (next_step is NextStep.Initial):
 
         result, next_step = initial_step(trainer, config)
-        if (continue_workflow):
-            result = step(trainer, next_step, config)
 
     elif (next_step is NextStep.Train):
 
         result, next_step = trainer.train()
-        if (result is Result.Complete):
-            return result
-        else:
-            assert(next_step is NextStep.RunSample)
-            result = step(trainer, next_step, config)
 
     elif (next_step is NextStep.RunSample):
 
         trainer.sample_fom()
         # TODO(kevin): currently no offline fom simulation. skip CollectSample.
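+        # hand control back to the caller: in restart mode, main() saves a
+        # restart file after this step and the next invocation resumes from
+        # NextStep.Train.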
result, next_step = Result.Success, NextStep.Train
-        result = step(trainer, next_step, config)
 
     elif (next_step is NextStep.CollectSample):
         import warnings
         warnings.warn("Collecting sample from offline FOM simulation is not implemented yet")
+        result, next_step = Result.Success, NextStep.RunSample
 
     else:
         raise RuntimeError("Unknown next step!")
+
+    # on failure or completion, stop stepping regardless of use_restart.
+    if ((result is Result.Fail) or (result is Result.Complete)):
+        return result, next_step
+
+    # without restart mode, keep stepping recursively until the workflow
+    # completes or fails.
+    if (not use_restart):
+        result, next_step = step(trainer, next_step, config)
 
-    return result
+    return result, next_step
 
-def initialize_trainer(config):
+def initialize_trainer(config, restart_file=None):
     '''
         Initialize a LaSDI class with a latent space model according to config file.
         Currently only 'gplasdi' is available.
 
@@ -93,23 +133,31 @@
-    # TODO(kevin): load parameter train space from a restart file.
     param_space = ParameterSpace(config)
+    if (restart_file is not None):
+        param_space.load(restart_file['parameters'])
 
     physics = initialize_physics(param_space, config)
     latent_space = initialize_latent_space(physics, config)
+    if (restart_file is not None):
+        latent_space.load(restart_file['latent_space'])
 
     # do we need a separate routine for latent dynamics initialization?
     ld_type = config['latent_dynamics']['type']
     assert(ld_type in config['latent_dynamics'])
     assert(ld_type in ld_dict)
     latent_dynamics = ld_dict[ld_type](latent_space.n_z, physics.nt, config['latent_dynamics'])
+    if (restart_file is not None):
+        latent_dynamics.load(restart_file['latent_dynamics'])
 
     trainer_type = config['lasdi']['type']
     assert(trainer_type in config['lasdi'])
     assert(trainer_type in trainer_dict)
 
     trainer = trainer_dict[trainer_type](physics, latent_space, latent_dynamics, config['lasdi'][trainer_type])
+    if (restart_file is not None):
+        trainer.load(restart_file['trainer'])
 
-    return trainer
+    return trainer, param_space, physics, latent_space, latent_dynamics
 
 def initialize_latent_space(physics, config):
     '''
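
For context, here is how the pieces added in this patch compose end to end. This is a minimal sketch, not part of the patch: it assumes it is run from the examples directory (so burgers1d.yml and the restarts/ path from the config above resolve), and it uses only names introduced by this change.

    import yaml
    import numpy as np

    from lasdi.workflow import initialize_trainer, step

    with open('burgers1d.yml', 'r') as f:
        config = yaml.safe_load(f)

    # the restart file written by a previous invocation of main()
    restart_file = np.load('restarts/burgers1d.restart.npy', allow_pickle=True).item()
    trainer, param_space, physics, latent_space, latent_dynamics = \
        initialize_trainer(config, restart_file)

    # with use_restart=True, step() executes exactly one stage and returns;
    # the real driver (main) then writes an updated restart file and exits.
    result, next_step = step(trainer, restart_file['next_step'], config, use_restart=True)

Each invocation of the driver therefore advances the workflow by one stage, so an interrupted greedy-sampling run can be resumed simply by re-running it with workflow.use_restart: true.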
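The optimizer_to helper above is needed because torch.optim keeps per-parameter state buffers (e.g. Adam's exp_avg and exp_avg_sq) that Module.to() does not move: after BayesianGLaSDI.load restores the CPU tensors from a restart file, the buffers must be relocated by hand before training continues on a GPU. A self-contained sketch of the intended usage, with a stand-in nn.Linear instead of the LaSDI autoencoder:

    import torch

    from lasdi.gplasdi import optimizer_to

    model = torch.nn.Linear(4, 4)                 # stand-in for the autoencoder
    optim = torch.optim.Adam(model.parameters())

    # take one optimization step so Adam materializes its state buffers (on CPU)
    model(torch.randn(2, 4)).sum().backward()
    optim.step()

    # round-trip through a plain state dict, as BayesianGLaSDI.export/load do
    optim.load_state_dict(optim.state_dict())

    if torch.cuda.is_available():
        model.to('cuda')             # moves parameters, but not the optimizer state
        optimizer_to(optim, 'cuda')  # relocates exp_avg / exp_avg_sq as well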