Skip to content

Commit

Permalink
Patch just once (#2416)
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh authored Mar 19, 2024
1 parent eca20b9 commit 5b88e43
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions nvflare/app_opt/lightning/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,17 @@ def __init__(self):
self.__fl_meta__ = {"CUSTOM_VAR": "VALUE_OF_THE_VAR"}
"""
fl_callback = FLCallback(rank=trainer.global_rank, load_state_dict_strict=load_state_dict_strict)
callbacks = trainer.callbacks
if isinstance(callbacks, list):
if isinstance(callbacks, Callback):
callbacks = [callbacks]
elif not isinstance(callbacks, list):
callbacks = []

if not any(isinstance(cb, FLCallback) for cb in callbacks):
fl_callback = FLCallback(rank=trainer.global_rank, load_state_dict_strict=load_state_dict_strict)
callbacks.append(fl_callback)
elif isinstance(callbacks, Callback):
callbacks = [callbacks, fl_callback]
else:
callbacks = [fl_callback]

if restore_state:
if restore_state and not any(isinstance(cb, RestoreState) for cb in callbacks):
callbacks.append(RestoreState())

trainer.callbacks = callbacks
Expand Down

0 comments on commit 5b88e43

Please sign in to comment.