diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py index 217053855445..28722ec25e7a 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2.py @@ -974,7 +974,13 @@ def collate_fn(examples, with_prior_preservation=False): class BucketBatchSampler(BatchSampler): - def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + def __init__( + self, + dataset: DreamBoothDataset, + batch_size: int, + drop_last: bool = False, + shuffle_batches_each_epoch: bool = True, + ): if not isinstance(batch_size, int) or batch_size <= 0: raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) if not isinstance(drop_last, bool): @@ -983,6 +989,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.dataset = dataset self.batch_size = batch_size self.drop_last = drop_last + self.shuffle_batches_each_epoch = shuffle_batches_each_epoch # Group indices by bucket self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] @@ -1004,9 +1011,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.batches.append(batch) self.sampler_len += 1 # Count the number of batches + if not self.shuffle_batches_each_epoch: + # Shuffle the precomputed batches once to mix buckets while keeping + # the order stable across epochs for step-indexed caches. + random.shuffle(self.batches) + def __iter__(self): - # Shuffle the order of the batches each epoch - random.shuffle(self.batches) + if self.shuffle_batches_each_epoch: + random.shuffle(self.batches) for batch in self.batches: yield batch @@ -1468,7 +1480,13 @@ def load_model_hook(models, input_dir): center_crop=args.center_crop, buckets=buckets, ) - batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True) + has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts + batch_sampler = BucketBatchSampler( + train_dataset, + batch_size=args.train_batch_size, + drop_last=True, + shuffle_batches_each_epoch=not has_step_indexed_caches, + ) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_sampler=batch_sampler, @@ -1585,7 +1603,6 @@ def _encode_single(prompt: str): # if cache_latents is set to True, we encode images to latents and store them. # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided # we encode them in advance as well. - precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts if precompute_latents: prompt_embeds_cache = [] text_ids_cache = [] diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 7976ad1da211..477697fadb64 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -972,7 +972,13 @@ def collate_fn(examples): class BucketBatchSampler(BatchSampler): - def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + def __init__( + self, + dataset: DreamBoothDataset, + batch_size: int, + drop_last: bool = False, + shuffle_batches_each_epoch: bool = True, + ): if not isinstance(batch_size, int) or batch_size <= 0: raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) if not isinstance(drop_last, bool): @@ -981,6 +987,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.dataset = dataset self.batch_size = batch_size self.drop_last = drop_last + self.shuffle_batches_each_epoch = shuffle_batches_each_epoch # Group indices by bucket self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] @@ -1002,9 +1009,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.batches.append(batch) self.sampler_len += 1 # Count the number of batches + if not self.shuffle_batches_each_epoch: + # Shuffle the precomputed batches once to mix buckets while keeping + # the order stable across epochs for step-indexed caches. + random.shuffle(self.batches) + def __iter__(self): - # Shuffle the order of the batches each epoch - random.shuffle(self.batches) + if self.shuffle_batches_each_epoch: + random.shuffle(self.batches) for batch in self.batches: yield batch @@ -1415,7 +1427,13 @@ def load_model_hook(models, input_dir): center_crop=args.center_crop, buckets=buckets, ) - batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True) + has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts + batch_sampler = BucketBatchSampler( + train_dataset, + batch_size=args.train_batch_size, + drop_last=True, + shuffle_batches_each_epoch=not has_step_indexed_caches, + ) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_sampler=batch_sampler, @@ -1518,7 +1536,6 @@ def _encode_single(prompt: str): # if cache_latents is set to True, we encode images to latents and store them. # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided # we encode them in advance as well. - precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts if precompute_latents: prompt_embeds_cache = [] text_ids_cache = [] diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py index f011150784a3..21cbc8a2c47b 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py @@ -969,7 +969,13 @@ def collate_fn(examples, with_prior_preservation=False): class BucketBatchSampler(BatchSampler): - def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + def __init__( + self, + dataset: DreamBoothDataset, + batch_size: int, + drop_last: bool = False, + shuffle_batches_each_epoch: bool = True, + ): if not isinstance(batch_size, int) or batch_size <= 0: raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) if not isinstance(drop_last, bool): @@ -978,6 +984,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.dataset = dataset self.batch_size = batch_size self.drop_last = drop_last + self.shuffle_batches_each_epoch = shuffle_batches_each_epoch # Group indices by bucket self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] @@ -999,9 +1006,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.batches.append(batch) self.sampler_len += 1 # Count the number of batches + if not self.shuffle_batches_each_epoch: + # Shuffle the precomputed batches once to mix buckets while keeping + # the order stable across epochs for step-indexed caches. + random.shuffle(self.batches) + def __iter__(self): - # Shuffle the order of the batches each epoch - random.shuffle(self.batches) + if self.shuffle_batches_each_epoch: + random.shuffle(self.batches) for batch in self.batches: yield batch @@ -1461,7 +1473,13 @@ def load_model_hook(models, input_dir): center_crop=args.center_crop, buckets=buckets, ) - batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True) + has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts + batch_sampler = BucketBatchSampler( + train_dataset, + batch_size=args.train_batch_size, + drop_last=True, + shuffle_batches_each_epoch=not has_step_indexed_caches, + ) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_sampler=batch_sampler, @@ -1528,7 +1546,6 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): # if cache_latents is set to True, we encode images to latents and store them. # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided # we encode them in advance as well. - precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts if precompute_latents: prompt_embeds_cache = [] text_ids_cache = [] diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py index a21bb85da7eb..63862eed9f1e 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py @@ -968,7 +968,13 @@ def collate_fn(examples): class BucketBatchSampler(BatchSampler): - def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + def __init__( + self, + dataset: DreamBoothDataset, + batch_size: int, + drop_last: bool = False, + shuffle_batches_each_epoch: bool = True, + ): if not isinstance(batch_size, int) or batch_size <= 0: raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) if not isinstance(drop_last, bool): @@ -977,6 +983,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.dataset = dataset self.batch_size = batch_size self.drop_last = drop_last + self.shuffle_batches_each_epoch = shuffle_batches_each_epoch # Group indices by bucket self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] @@ -998,9 +1005,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.batches.append(batch) self.sampler_len += 1 # Count the number of batches + if not self.shuffle_batches_each_epoch: + # Shuffle the precomputed batches once to mix buckets while keeping + # the order stable across epochs for step-indexed caches. + random.shuffle(self.batches) + def __iter__(self): - # Shuffle the order of the batches each epoch - random.shuffle(self.batches) + if self.shuffle_batches_each_epoch: + random.shuffle(self.batches) for batch in self.batches: yield batch @@ -1409,7 +1421,13 @@ def load_model_hook(models, input_dir): center_crop=args.center_crop, buckets=buckets, ) - batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True) + has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts + batch_sampler = BucketBatchSampler( + train_dataset, + batch_size=args.train_batch_size, + drop_last=True, + shuffle_batches_each_epoch=not has_step_indexed_caches, + ) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_sampler=batch_sampler, @@ -1469,7 +1487,6 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): # if cache_latents is set to True, we encode images to latents and store them. # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided # we encode them in advance as well. - precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts if precompute_latents: prompt_embeds_cache = [] text_ids_cache = [] diff --git a/examples/dreambooth/train_dreambooth_lora_z_image.py b/examples/dreambooth/train_dreambooth_lora_z_image.py index ee53ebe870a8..a54c84b0798f 100644 --- a/examples/dreambooth/train_dreambooth_lora_z_image.py +++ b/examples/dreambooth/train_dreambooth_lora_z_image.py @@ -963,7 +963,13 @@ def collate_fn(examples, with_prior_preservation=False): class BucketBatchSampler(BatchSampler): - def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + def __init__( + self, + dataset: DreamBoothDataset, + batch_size: int, + drop_last: bool = False, + shuffle_batches_each_epoch: bool = True, + ): if not isinstance(batch_size, int) or batch_size <= 0: raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) if not isinstance(drop_last, bool): @@ -972,6 +978,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.dataset = dataset self.batch_size = batch_size self.drop_last = drop_last + self.shuffle_batches_each_epoch = shuffle_batches_each_epoch # Group indices by bucket self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] @@ -993,9 +1000,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool self.batches.append(batch) self.sampler_len += 1 # Count the number of batches + if not self.shuffle_batches_each_epoch: + # Shuffle the precomputed batches once to mix buckets while keeping + # the order stable across epochs for step-indexed caches. + random.shuffle(self.batches) + def __iter__(self): - # Shuffle the order of the batches each epoch - random.shuffle(self.batches) + if self.shuffle_batches_each_epoch: + random.shuffle(self.batches) for batch in self.batches: yield batch @@ -1449,7 +1461,13 @@ def load_model_hook(models, input_dir): center_crop=args.center_crop, buckets=buckets, ) - batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True) + has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts + batch_sampler = BucketBatchSampler( + train_dataset, + batch_size=args.train_batch_size, + drop_last=True, + shuffle_batches_each_epoch=not has_step_indexed_caches, + ) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_sampler=batch_sampler, @@ -1509,7 +1527,6 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): # if cache_latents is set to True, we encode images to latents and store them. # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided # we encode them in advance as well. - precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts if precompute_latents: prompt_embeds_cache = [] latents_cache = []