Skip to content

Commit 04856ac

Browse files
authored
Allow negative batch_index on ImageFromBatch and LatentFromBatch (CORE-195) (Comfy-Org#13857)
1 parent 77e2ed5 commit 04856ac

2 files changed

Lines changed: 8 additions & 4 deletions

File tree

comfy_extras/nodes_images.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def define_schema(cls):
136136
category="image/batch",
137137
inputs=[
138138
IO.Image.Input("image"),
139-
IO.Int.Input("batch_index", default=0, min=0, max=4095),
139+
IO.Int.Input("batch_index", default=0, min=-MAX_RESOLUTION, max=MAX_RESOLUTION),
140140
IO.Int.Input("length", default=1, min=1, max=4096),
141141
],
142142
outputs=[IO.Image.Output()],
@@ -145,7 +145,9 @@ def define_schema(cls):
145145
@classmethod
146146
def execute(cls, image, batch_index, length) -> IO.NodeOutput:
147147
s_in = image
148-
batch_index = min(s_in.shape[0] - 1, batch_index)
148+
if batch_index < 0:
149+
batch_index += s_in.shape[0]
150+
batch_index = max(0, min(s_in.shape[0] - 1, batch_index))
149151
length = min(s_in.shape[0] - batch_index, length)
150152
s = s_in[batch_index:batch_index + length].clone()
151153
return IO.NodeOutput(s)

nodes.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,7 +1221,7 @@ class LatentFromBatch:
12211221
@classmethod
12221222
def INPUT_TYPES(s):
12231223
return {"required": { "samples": ("LATENT",),
1224-
"batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
1224+
"batch_index": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION}),
12251225
"length": ("INT", {"default": 1, "min": 1, "max": 64}),
12261226
}}
12271227
RETURN_TYPES = ("LATENT",)
@@ -1232,7 +1232,9 @@ def INPUT_TYPES(s):
12321232
def frombatch(self, samples, batch_index, length):
12331233
s = samples.copy()
12341234
s_in = samples["samples"]
1235-
batch_index = min(s_in.shape[0] - 1, batch_index)
1235+
if batch_index < 0:
1236+
batch_index += s_in.shape[0]
1237+
batch_index = max(0, min(s_in.shape[0] - 1, batch_index))
12361238
length = min(s_in.shape[0] - batch_index, length)
12371239
s["samples"] = s_in[batch_index:batch_index + length].clone()
12381240
if "noise_mask" in samples:

0 commit comments

Comments
 (0)