feat: split shards per node #39

Open · wants to merge 3 commits into main
95 changes: 68 additions & 27 deletions src/training/data/multimodal.py
@@ -40,32 +40,56 @@
 _TEXT_EXTENSIONS = ['txt', 'text', 'caption']
 
 
-def expand_urls(urls, weights=None):
-    if weights is None:
-        expanded_urls = wds.shardlists.expand_urls(urls)
-        return expanded_urls, None
+def expand_urls(urls, weights=None, nodes=None):
+
     if isinstance(urls, str):
-        urllist = urls.split('::')
-        weights = weights.split('::')
-        assert len(weights) == len(urllist), (
-            f'Expected the number of data components ({len(urllist)}) and weights '
-            f'({len(weights)}) to match.'
-        )
-        weights = [float(weight) for weight in weights]
-        all_urls, all_weights = [], []
-        for url, weight in zip(urllist, weights):
-            expanded_url = list(braceexpand.braceexpand(url))
-            expanded_weights = [weight for _ in expanded_url]
-            all_urls.extend(expanded_url)
-            all_weights.extend(expanded_weights)
-        return all_urls, all_weights
+        urllist = list(urls.split('::'))
     else:
-        all_urls = list(urls)
-        return all_urls, weights
+        urllist = list(urls)
+
+    weightlist = None
+    if weights is not None:
+        if isinstance(weights, str):
+            weightlist = [float(weight) for weight in weights.split('::')]
+        else:
+            weightlist = list(weights)
+        assert len(urllist) == len(weightlist), (
+            f'Number of urls {len(urllist)} and weights {len(weightlist)} '
+            f'should match.'
+        )
+
+    nodelist = None
+    if nodes is not None:
+        if isinstance(nodes, str):
+            nodelist = [int(node) for node in nodes.split('::')]
+        else:
+            nodelist = list(nodes)
+        assert len(urllist) == len(nodelist), (
+            f'Number of urls {len(urllist)} and nodes {len(nodelist)} '
+            f'should match.'
+        )
+
+    _all_urls, _all_weights, _all_nodes = [], [], []
+    for i, url in enumerate(urllist):
+        _expanded_urls = list(braceexpand.braceexpand(url))
+        _all_urls.extend(_expanded_urls)
+        if weightlist is not None:
+            _expanded_weights = [weightlist[i] for _ in _expanded_urls]
+            _all_weights.extend(_expanded_weights)
+        if nodelist is not None:
+            _expanded_nodes = [nodelist[i] for _ in _expanded_urls]
+            _all_nodes.extend(_expanded_nodes)
+
+    if len(_all_weights) == 0:
+        _all_weights = None
+    if len(_all_nodes) == 0:
+        _all_nodes = None
+
+    return _all_urls, _all_weights, _all_nodes
 
 
 def get_dataset_size(shards):
-    shards_list, _ = expand_urls(shards)
+    shards_list, _, _ = expand_urls(shards)
     dir_path = os.path.dirname(shards_list[0])
     sizes_filename = os.path.join(dir_path, 'sizes.json')
     len_filename = os.path.join(dir_path, '__len__')
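For reference, a minimal sketch of what the reworked expand_urls returns; the shard patterns, weights, and node IDs below are made up for illustration:

urls = 'data/a-{000..001}.tar::data/b-{000..002}.tar'
all_urls, all_weights, all_nodes = expand_urls(urls, weights='1.0::2.0', nodes='0::1')
# all_urls    -> ['data/a-000.tar', 'data/a-001.tar',
#                 'data/b-000.tar', 'data/b-001.tar', 'data/b-002.tar']
# all_weights -> [1.0, 1.0, 2.0, 2.0, 2.0]  (per-source weight repeated per shard)
# all_nodes   -> [0, 0, 1, 1, 1]            (per-source node ID repeated per shard)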
@@ -232,20 +256,30 @@ def __init__(
         self,
         urls: str,
         weights: Optional[str] = None,
+        nodes: Optional[str] = None,
         nshards: int = sys.maxsize,
         worker_seed: Optional[Callable] = None,
         deterministic: bool = False,
         epoch: Union[int, _SharedEpoch] = -1,
     ):
         super().__init__()
-        urls, weights = expand_urls(urls, weights)
+        urls, weights, nodes = expand_urls(urls, weights, nodes)
         self.urls = urls
         self.weights = weights
-        if self.weights is not None:
-            assert len(self.urls) == len(self.weights), (
-                f'Number of urls {len(self.urls)} and weights {len(self.weights)} '
-                f'should match.'
-            )
+        self.nodes = nodes
+        self.current_node = int(os.environ.get('GROUP_RANK', '0'))
+        logger.debug(f'Current node: {self.current_node}')
+
+        if self.nodes is not None:
+            self.urls = [
+                url for node, url in zip(self.nodes, self.urls)
+                if node == self.current_node
+            ]
+            if self.weights is not None:
+                self.weights = [
+                    weight for node, weight in zip(self.nodes, self.weights)
+                    if node == self.current_node
+                ]
         assert isinstance(self.urls[0], str)
         self.nshards = nshards
         self.rng = random.Random()
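The filtering above means each worker group only ever samples from the shards assigned to it. A standalone sketch of the same selection logic, assuming GROUP_RANK is exported per node by torchrun (shard names are illustrative):

import os

urls = ['a-000.tar', 'a-001.tar', 'b-000.tar']
nodes = [0, 0, 1]

# torchrun sets GROUP_RANK to the node's index within the job.
current_node = int(os.environ.get('GROUP_RANK', '0'))

# Keep only the shards whose assigned group matches this node;
# with GROUP_RANK=1 this leaves ['b-000.tar'].
urls = [url for node, url in zip(nodes, urls) if node == current_node]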
@@ -430,6 +464,7 @@ def get_wds_dataset(
     is_train: bool = False,
     tokenizer: Any = None,
     upsampling_factors: Optional[str] = None,
+    group_ids: Optional[str] = None,
     images_pairs: bool = False,
     use_long_captions: bool = False,
     workers: int = 1,
@@ -463,13 +498,19 @@
             'Upsampling factors are only supported when sampling with '
             'replacement (with --dataset-resampled).'
         )
+    if is_train and group_ids is not None:
+        assert resampled, (
+            'Group IDs are only supported when sampling with '
+            'replacement (with --dataset-resampled).'
+        )
 
     if is_train:
         if resampled:
             _shard_pipeline = [
                 _ResampledShards(
                     shards,
                     weights=upsampling_factors,
+                    nodes=group_ids,
                     deterministic=True,
                     epoch=shared_epoch,
                 ),
2 changes: 2 additions & 0 deletions src/training/factory.py
@@ -177,6 +177,7 @@ def _create_multimodal_dataloader(
         is_train=True,
         tokenizer=tokenizer,
         upsampling_factors=args.train_data_upsampling_factors,
+        group_ids=args.train_data_assigned_groupids,
         use_long_captions=False,
         workers=args.workers,
         batch_size=batch_size,
@@ -290,6 +291,7 @@ def _create_images_dataloader(
         is_train=True,
         tokenizer=tokenizer,
         upsampling_factors=args.train_imgdata_upsampling_factors,
+        group_ids=args.train_imgdata_assigned_groupids,
         images_pairs=True,
         workers=args.workers,
         batch_size=batch_size,
19 changes: 19 additions & 0 deletions src/training/params.py
@@ -69,6 +69,19 @@ def parse_args(args):
             'dataset sizes.'
         ),
     )
+    parser.add_argument(
+        '--train-data-assigned-groupids',
+        type=str,
+        default=None,
+        help=(
+            'When using multiple data sources with webdataset and sampling with '
+            'replacement, this can be used to assign each data source to a specific '
+            'worker group. Similar to --train-data, this should be a string '
+            'with as many numbers as there are data sources, separated by `::` '
+            '(e.g. 0::1::0). By default, all data sources are assigned to all worker '
+            'groups.'
+        ),
+    )
     parser.add_argument(
         '--train-txtdata',
         type=str,
@@ -113,6 +126,12 @@
         default=None,
         help='Similar to --train-data-upsampling-factors, but for --train-imgdata.',
     )
+    parser.add_argument(
+        '--train-imgdata-assigned-groupids',
+        type=str,
+        default=None,
+        help='Similar to --train-data-assigned-groupids, but for --train-imgdata.',
+    )
     parser.add_argument(
         '--train-mtldata',
         type=str,
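Putting the new flags together, a hypothetical two-node launch that pins the first two data sources to worker group 0 and the third to group 1; the entry point and shard paths are placeholders, not taken from this repo:

torchrun --nnodes=2 --nproc_per_node=8 -m training.main \
    --dataset-resampled \
    --train-data 'data/a-{000..099}.tar::data/b-{000..099}.tar::data/c-{000..099}.tar' \
    --train-data-upsampling-factors '1::1::2' \
    --train-data-assigned-groupids '0::0::1'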