Fix BucketBatchSampler cache alignment in DreamBooth scripts #13353
Changes from all commits
```diff
@@ -974,7 +974,13 @@ def collate_fn(examples, with_prior_preservation=False):


 class BucketBatchSampler(BatchSampler):
-    def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
+    def __init__(
+        self,
+        dataset: DreamBoothDataset,
+        batch_size: int,
+        drop_last: bool = False,
+        shuffle_batches_each_epoch: bool = True,
+    ):
         if not isinstance(batch_size, int) or batch_size <= 0:
             raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
         if not isinstance(drop_last, bool):
```
```diff
@@ -983,6 +989,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
         self.dataset = dataset
         self.batch_size = batch_size
         self.drop_last = drop_last
+        self.shuffle_batches_each_epoch = shuffle_batches_each_epoch

         # Group indices by bucket
         self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
```
```diff
@@ -1004,9 +1011,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
             self.batches.append(batch)
             self.sampler_len += 1  # Count the number of batches

+        if not self.shuffle_batches_each_epoch:
+            # Shuffle the precomputed batches once to mix buckets while keeping
+            # the order stable across epochs for step-indexed caches.
+            random.shuffle(self.batches)
+
     def __iter__(self):
-        # Shuffle the order of the batches each epoch
-        random.shuffle(self.batches)
+        if self.shuffle_batches_each_epoch:
+            random.shuffle(self.batches)
         for batch in self.batches:
             yield batch
```
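For readers skimming the diff, here is a minimal standalone sketch of the sampler's behavior with the new flag. It takes precomputed `bucket_indices` (one list of dataset indices per bucket) directly instead of deriving them from a `DreamBoothDataset`; that simplification is an assumption for illustration, not the script's actual signature:

```python
import random

from torch.utils.data import BatchSampler


class BucketBatchSampler(BatchSampler):
    # Simplified sketch: `bucket_indices` is passed in directly instead of
    # being built from a DreamBoothDataset (an assumption for this example).
    def __init__(self, bucket_indices, batch_size, drop_last=False, shuffle_batches_each_epoch=True):
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError(f"batch_size should be a positive integer value, but got batch_size={batch_size}")
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.shuffle_batches_each_epoch = shuffle_batches_each_epoch

        # Precompute batches so every batch draws from a single bucket
        # (i.e. all images in a batch share one resolution bucket).
        self.batches = []
        for indices in bucket_indices:
            random.shuffle(indices)  # shuffle within each bucket
            for i in range(0, len(indices), batch_size):
                batch = indices[i : i + batch_size]
                if len(batch) < batch_size and drop_last:
                    continue
                self.batches.append(batch)

        if not self.shuffle_batches_each_epoch:
            # Mix buckets once; the batch order then stays fixed across epochs.
            random.shuffle(self.batches)

    def __iter__(self):
        if self.shuffle_batches_each_epoch:
            random.shuffle(self.batches)
        yield from self.batches

    def __len__(self):
        return len(self.batches)
```

With `shuffle_batches_each_epoch=False`, the one-time shuffle in `__init__` still mixes buckets, but every epoch replays the same batch order, so a cache built during a first pass over the dataloader stays aligned with step indices.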
```diff
@@ -1468,7 +1480,13 @@ def load_model_hook(models, input_dir):
         center_crop=args.center_crop,
         buckets=buckets,
     )
-    batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
+    has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
+    batch_sampler = BucketBatchSampler(
+        train_dataset,
+        batch_size=args.train_batch_size,
+        drop_last=True,
+        shuffle_batches_each_epoch=not has_step_indexed_caches,
+    )
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
         batch_sampler=batch_sampler,
```
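The alignment problem this wiring avoids, sketched with hypothetical stand-in values: a step-indexed cache is built from one pass over the batches, and a later reshuffle pairs `cache[step]` with a different batch than the one it was computed from:

```python
import random

random.seed(0)

# Hypothetical stand-ins: each "batch" is a list of dataset indices, and the
# cache stores one value per step, computed from the step-th batch.
batches = [[0, 1], [2, 3], [4, 5], [6, 7]]
cache = [sum(batch) for batch in batches]  # pass 1: build cache in batch order

# Pass 2: a per-epoch reshuffle (the old sampler behavior) changes which batch
# appears at each step, so cache[step] no longer matches the batch it sees.
random.shuffle(batches)
for step, batch in enumerate(batches):
    status = "aligned" if cache[step] == sum(batch) else "MISALIGNED"
    print(step, batch, status)
```

Disabling the per-epoch shuffle whenever such caches exist keeps both passes in the same order; when no caches are built, the sampler keeps its old reshuffling behavior.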
```diff
@@ -1585,7 +1603,6 @@ def _encode_single(prompt: str):
     # if cache_latents is set to True, we encode images to latents and store them.
     # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
     # we encode them in advance as well.
-    precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
     if precompute_latents:
         prompt_embeds_cache = []
         text_ids_cache = []
```

Member: I think `has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts`

Contributor (Author): Updated, thanks. I restored the
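One subtlety in the suggested chained assignment: Python's `or` returns its second operand, so when `args.cache_latents` is falsy and custom prompts exist, both names are bound to the prompts themselves (truthy but not a bool); `not has_step_indexed_caches` still produces a proper bool for the sampler flag. A quick illustration with hypothetical values:

```python
cache_latents = False
custom_instance_prompts = ["a photo of sks dog"]  # hypothetical example value

# Chained assignment binds both names to the same object in one statement.
has_step_indexed_caches = precompute_latents = cache_latents or custom_instance_prompts

print(has_step_indexed_caches)      # ['a photo of sks dog'] -- truthy, not a bool
print(not has_step_indexed_caches)  # False, so shuffle_batches_each_epoch=False
```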