examples/dreambooth/train_dreambooth_lora_flux2.py (27 changes: 22 additions & 5 deletions)

@@ -974,7 +974,13 @@ def collate_fn(examples, with_prior_preservation=False):


class BucketBatchSampler(BatchSampler):
-    def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
+    def __init__(
+        self,
+        dataset: DreamBoothDataset,
+        batch_size: int,
+        drop_last: bool = False,
+        shuffle_batches_each_epoch: bool = True,
+    ):
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):

@@ -983,6 +989,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
+        self.shuffle_batches_each_epoch = shuffle_batches_each_epoch

        # Group indices by bucket
        self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]

@@ -1004,9 +1011,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
                self.batches.append(batch)
                self.sampler_len += 1  # Count the number of batches

+        if not self.shuffle_batches_each_epoch:
+            # Shuffle the precomputed batches once to mix buckets while keeping
+            # the order stable across epochs for step-indexed caches.
+            random.shuffle(self.batches)
+
    def __iter__(self):
-        # Shuffle the order of the batches each epoch
-        random.shuffle(self.batches)
+        if self.shuffle_batches_each_epoch:
+            random.shuffle(self.batches)
        for batch in self.batches:
            yield batch

@@ -1468,7 +1480,13 @@ def load_model_hook(models, input_dir):
        center_crop=args.center_crop,
        buckets=buckets,
    )
-    batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
+    has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
+    batch_sampler = BucketBatchSampler(
+        train_dataset,
+        batch_size=args.train_batch_size,
+        drop_last=True,
+        shuffle_batches_each_epoch=not has_step_indexed_caches,
+    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,

@@ -1585,7 +1603,6 @@ def _encode_single(prompt: str):
    # if cache_latents is set to True, we encode images to latents and store them.
    # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
    # we encode them in advance as well.
-    precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
    if precompute_latents:
        prompt_embeds_cache = []
        text_ids_cache = []

Review thread on this hunk:

Member: I think precompute_latents is a better variable name here. So, perhaps we could do:

    has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts

Contributor (author): Updated, thanks. I restored the precompute_latents name for the caching path while keeping has_step_indexed_caches for the sampler behavior:

    has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
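Note on the sampler change (it applies identically to all five scripts in this PR): the new shuffle_batches_each_epoch flag decides whether the precomputed batches are reshuffled on every pass or only once at construction time. Below is a minimal, self-contained sketch of that behavior; SimplifiedBucketSampler and the toy bucket indices are illustrative stand-ins, not the scripts' BucketBatchSampler, and drop_last handling is omitted.

```python
import random


class SimplifiedBucketSampler:
    """Simplified sketch of the sampler's shuffle behavior, not the training scripts' class."""

    def __init__(self, bucket_indices, batch_size, shuffle_batches_each_epoch=True):
        self.shuffle_batches_each_epoch = shuffle_batches_each_epoch
        # Pre-generate batches bucket by bucket, exactly once.
        self.batches = []
        for indices in bucket_indices:
            random.shuffle(indices)  # shuffle *within* each bucket
            for i in range(0, len(indices), batch_size):
                self.batches.append(indices[i : i + batch_size])

        if not self.shuffle_batches_each_epoch:
            # One-time shuffle mixes the buckets but keeps the epoch-to-epoch
            # order stable, so a cache indexed by step position stays aligned.
            random.shuffle(self.batches)

    def __iter__(self):
        if self.shuffle_batches_each_epoch:
            random.shuffle(self.batches)  # previous behavior: new order every epoch
        yield from self.batches


# Two toy buckets; with the flag off, every epoch yields the same batch order.
sampler = SimplifiedBucketSampler([[0, 1, 2], [3, 4, 5]], batch_size=2, shuffle_batches_each_epoch=False)
epoch_1 = list(sampler)
epoch_2 = list(sampler)
assert epoch_1 == epoch_2
```

With the flag left at its default of True, the shuffle moves into __iter__ and the order changes every epoch, matching the scripts' previous behavior.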
examples/dreambooth/train_dreambooth_lora_flux2_img2img.py (27 changes: 22 additions & 5 deletions)

@@ -972,7 +972,13 @@ def collate_fn(examples):


class BucketBatchSampler(BatchSampler):
-    def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
+    def __init__(
+        self,
+        dataset: DreamBoothDataset,
+        batch_size: int,
+        drop_last: bool = False,
+        shuffle_batches_each_epoch: bool = True,
+    ):
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):

@@ -981,6 +987,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
+        self.shuffle_batches_each_epoch = shuffle_batches_each_epoch

        # Group indices by bucket
        self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]

@@ -1002,9 +1009,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
                self.batches.append(batch)
                self.sampler_len += 1  # Count the number of batches

+        if not self.shuffle_batches_each_epoch:
+            # Shuffle the precomputed batches once to mix buckets while keeping
+            # the order stable across epochs for step-indexed caches.
+            random.shuffle(self.batches)
+
    def __iter__(self):
-        # Shuffle the order of the batches each epoch
-        random.shuffle(self.batches)
+        if self.shuffle_batches_each_epoch:
+            random.shuffle(self.batches)
        for batch in self.batches:
            yield batch

@@ -1415,7 +1427,13 @@ def load_model_hook(models, input_dir):
        center_crop=args.center_crop,
        buckets=buckets,
    )
-    batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
+    has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
+    batch_sampler = BucketBatchSampler(
+        train_dataset,
+        batch_size=args.train_batch_size,
+        drop_last=True,
+        shuffle_batches_each_epoch=not has_step_indexed_caches,
+    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,

@@ -1518,7 +1536,6 @@ def _encode_single(prompt: str):
    # if cache_latents is set to True, we encode images to latents and store them.
    # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
    # we encode them in advance as well.
-    precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
    if precompute_latents:
        prompt_embeds_cache = []
        text_ids_cache = []
examples/dreambooth/train_dreambooth_lora_flux2_klein.py (27 changes: 22 additions & 5 deletions)

@@ -969,7 +969,13 @@ def collate_fn(examples, with_prior_preservation=False):


class BucketBatchSampler(BatchSampler):
-    def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
+    def __init__(
+        self,
+        dataset: DreamBoothDataset,
+        batch_size: int,
+        drop_last: bool = False,
+        shuffle_batches_each_epoch: bool = True,
+    ):
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):

@@ -978,6 +984,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
+        self.shuffle_batches_each_epoch = shuffle_batches_each_epoch

        # Group indices by bucket
        self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]

@@ -999,9 +1006,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
                self.batches.append(batch)
                self.sampler_len += 1  # Count the number of batches

+        if not self.shuffle_batches_each_epoch:
+            # Shuffle the precomputed batches once to mix buckets while keeping
+            # the order stable across epochs for step-indexed caches.
+            random.shuffle(self.batches)
+
    def __iter__(self):
-        # Shuffle the order of the batches each epoch
-        random.shuffle(self.batches)
+        if self.shuffle_batches_each_epoch:
+            random.shuffle(self.batches)
        for batch in self.batches:
            yield batch

@@ -1461,7 +1473,13 @@ def load_model_hook(models, input_dir):
        center_crop=args.center_crop,
        buckets=buckets,
    )
-    batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
+    has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
+    batch_sampler = BucketBatchSampler(
+        train_dataset,
+        batch_size=args.train_batch_size,
+        drop_last=True,
+        shuffle_batches_each_epoch=not has_step_indexed_caches,
+    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,

@@ -1528,7 +1546,6 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
    # if cache_latents is set to True, we encode images to latents and store them.
    # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
    # we encode them in advance as well.
-    precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
    if precompute_latents:
        prompt_embeds_cache = []
        text_ids_cache = []
examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py (27 changes: 22 additions & 5 deletions)

@@ -968,7 +968,13 @@ def collate_fn(examples):


class BucketBatchSampler(BatchSampler):
-    def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
+    def __init__(
+        self,
+        dataset: DreamBoothDataset,
+        batch_size: int,
+        drop_last: bool = False,
+        shuffle_batches_each_epoch: bool = True,
+    ):
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):

@@ -977,6 +983,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
+        self.shuffle_batches_each_epoch = shuffle_batches_each_epoch

        # Group indices by bucket
        self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]

@@ -998,9 +1005,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
                self.batches.append(batch)
                self.sampler_len += 1  # Count the number of batches

+        if not self.shuffle_batches_each_epoch:
+            # Shuffle the precomputed batches once to mix buckets while keeping
+            # the order stable across epochs for step-indexed caches.
+            random.shuffle(self.batches)
+
    def __iter__(self):
-        # Shuffle the order of the batches each epoch
-        random.shuffle(self.batches)
+        if self.shuffle_batches_each_epoch:
+            random.shuffle(self.batches)
        for batch in self.batches:
            yield batch

@@ -1409,7 +1421,13 @@ def load_model_hook(models, input_dir):
        center_crop=args.center_crop,
        buckets=buckets,
    )
-    batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
+    has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
+    batch_sampler = BucketBatchSampler(
+        train_dataset,
+        batch_size=args.train_batch_size,
+        drop_last=True,
+        shuffle_batches_each_epoch=not has_step_indexed_caches,
+    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,

@@ -1469,7 +1487,6 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
    # if cache_latents is set to True, we encode images to latents and store them.
    # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
    # we encode them in advance as well.
-    precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
    if precompute_latents:
        prompt_embeds_cache = []
        text_ids_cache = []
examples/dreambooth/train_dreambooth_lora_z_image.py (27 changes: 22 additions & 5 deletions)

@@ -963,7 +963,13 @@ def collate_fn(examples, with_prior_preservation=False):


class BucketBatchSampler(BatchSampler):
-    def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
+    def __init__(
+        self,
+        dataset: DreamBoothDataset,
+        batch_size: int,
+        drop_last: bool = False,
+        shuffle_batches_each_epoch: bool = True,
+    ):
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):

@@ -972,6 +978,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
+        self.shuffle_batches_each_epoch = shuffle_batches_each_epoch

        # Group indices by bucket
        self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]

@@ -993,9 +1000,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
                self.batches.append(batch)
                self.sampler_len += 1  # Count the number of batches

+        if not self.shuffle_batches_each_epoch:
+            # Shuffle the precomputed batches once to mix buckets while keeping
+            # the order stable across epochs for step-indexed caches.
+            random.shuffle(self.batches)
+
    def __iter__(self):
-        # Shuffle the order of the batches each epoch
-        random.shuffle(self.batches)
+        if self.shuffle_batches_each_epoch:
+            random.shuffle(self.batches)
        for batch in self.batches:
            yield batch

@@ -1449,7 +1461,13 @@ def load_model_hook(models, input_dir):
        center_crop=args.center_crop,
        buckets=buckets,
    )
-    batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
+    has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
+    batch_sampler = BucketBatchSampler(
+        train_dataset,
+        batch_size=args.train_batch_size,
+        drop_last=True,
+        shuffle_batches_each_epoch=not has_step_indexed_caches,
+    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,

@@ -1509,7 +1527,6 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
    # if cache_latents is set to True, we encode images to latents and store them.
    # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
    # we encode them in advance as well.
-    precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
    if precompute_latents:
        prompt_embeds_cache = []
        latents_cache = []
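Why the flag is derived from has_step_indexed_caches: when latents or prompt embeddings are precomputed by walking the dataloader once and later read back by step index, a per-epoch reshuffle silently misaligns the cache entries with the batches actually served. A tiny hypothetical illustration of that failure mode (toy numbers only, no diffusers code):

```python
import random


def epoch_batches():
    # Stand-in for a sampler that reshuffles its precomputed batches every epoch.
    batches = [[0, 1], [2, 3], [4, 5]]
    random.shuffle(batches)
    return batches


# Build a "cache" keyed by step index during a first pass (stand-in for cached latents).
cache = {step: sum(batch) for step, batch in enumerate(epoch_batches())}

# A later epoch may yield the batches in a different order, so step-indexed lookups
# can return a value computed from a different batch. (The list can be empty if the
# two shuffles happen to coincide.)
mismatched_steps = [
    step for step, batch in enumerate(epoch_batches()) if cache[step] != sum(batch)
]
print("steps reading a stale cache entry:", mismatched_steps)
```

Passing shuffle_batches_each_epoch=not has_step_indexed_caches keeps cache[step] aligned with the batch served at that step whenever such caches are in play, at the cost of a fixed batch order across epochs.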