invokeai/app/invocations/composition-nodes.py (126 changes: 54 additions & 72 deletions)
@@ -14,22 +14,27 @@
 from torchvision.transforms.functional import to_pil_image as pil_image_from_tensor
 
 from invokeai.app.invocations.primitives import ImageOutput
-from invokeai.backend.image_util.composition import (
-    CIELAB_TO_UPLAB_ICC_PATH,
-    MAX_FLOAT,
-    equivalent_achromatic_lightness,
-    gamut_clip_tensor,
+from invokeai.backend.image_util.color_conversion import (
     hsl_from_srgb,
     linear_srgb_from_oklab,
+    linear_srgb_from_oklch,
     linear_srgb_from_srgb,
     okhsl_from_srgb,
     okhsv_from_srgb,
     oklab_from_linear_srgb,
-    remove_nans,
+    oklab_from_oklch,
+    oklch_from_oklab,
     srgb_from_hsl,
-    srgb_from_linear_srgb,
     srgb_from_okhsl,
     srgb_from_okhsv,
+)
+from invokeai.backend.image_util.composition import (
+    CIELAB_TO_UPLAB_ICC_PATH,
+    MAX_FLOAT,
+    equivalent_achromatic_lightness,
+    gamut_clip_tensor,
+    remove_nans,
+    srgb_from_linear_srgb,
     tensor_from_pil_image,
 )
 from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
@@ -136,20 +141,20 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
 
         if space == "hsv":
             hsv_tensor = image_resized_to_grid_as_tensor(image_in.convert("HSV"), normalize=False, multiple_of=1)
-            hsv_tensor[0, :, :] = torch.remainder(torch.add(hsv_tensor[0, :, :], torch.div(self.degrees, 360.0)), 1.0)
+            hsv_tensor[0, :, :] = torch.remainder(torch.add(hsv_tensor[0, :, :] * 360.0, self.degrees), 360.0) / 360.0
             image_out = pil_image_from_tensor(hsv_tensor, mode="HSV").convert("RGB")
 
         elif space == "okhsl":
             rgb_tensor = image_resized_to_grid_as_tensor(image_in.convert("RGB"), normalize=False, multiple_of=1)
             hsl_tensor = okhsl_from_srgb(rgb_tensor, steps=(3 if self.ok_high_precision else 1))
-            hsl_tensor[0, :, :] = torch.remainder(torch.add(hsl_tensor[0, :, :], torch.div(self.degrees, 360.0)), 1.0)
+            hsl_tensor[0, :, :] = torch.remainder(torch.add(hsl_tensor[0, :, :], self.degrees), 360.0)
             rgb_tensor = srgb_from_okhsl(hsl_tensor, alpha=0.0)
             image_out = pil_image_from_tensor(rgb_tensor, mode="RGB")
 
         elif space == "okhsv":
             rgb_tensor = image_resized_to_grid_as_tensor(image_in.convert("RGB"), normalize=False, multiple_of=1)
             hsv_tensor = okhsv_from_srgb(rgb_tensor, steps=(3 if self.ok_high_precision else 1))
-            hsv_tensor[0, :, :] = torch.remainder(torch.add(hsv_tensor[0, :, :], torch.div(self.degrees, 360.0)), 1.0)
+            hsv_tensor[0, :, :] = torch.remainder(torch.add(hsv_tensor[0, :, :], self.degrees), 360.0)
             rgb_tensor = srgb_from_okhsv(hsv_tensor, alpha=0.0)
             image_out = pil_image_from_tensor(rgb_tensor, mode="RGB")

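Note: all three branches above switch hue rotation to degree units: the hue channel is shifted by self.degrees and wrapped at 360, rather than being scaled into the old 0..1 range. A minimal sketch of the shared pattern (rotate_hue is an illustrative name, not a helper in this PR):

import torch

def rotate_hue(hue_degrees: torch.Tensor, degrees: float) -> torch.Tensor:
    # Shift a hue channel expressed in degrees, wrapping back into [0, 360).
    return torch.remainder(hue_degrees + degrees, 360.0)

The PIL-backed "hsv" branch still stores H in 0..1 after tensor conversion, hence its extra * 360.0 and / 360.0 on either side of the wrap; the okhsl/okhsv tensors appear to carry hue in degrees already under the new convention.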
@@ -197,24 +202,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
         rgb_tensor = image_resized_to_grid_as_tensor(image_in.convert("RGB"), normalize=False, multiple_of=1)
 
         linear_srgb_tensor = linear_srgb_from_srgb(rgb_tensor)
-
-        lab_tensor = oklab_from_linear_srgb(linear_srgb_tensor)
-
-        # L*a*b* to L*C*h
-        c_tensor = torch.sqrt(torch.add(torch.pow(lab_tensor[1, :, :], 2.0), torch.pow(lab_tensor[2, :, :], 2.0)))
-        h_tensor = torch.atan2(lab_tensor[2, :, :], lab_tensor[1, :, :])
-
-        # Rotate h
-        rot_rads = (self.degrees / 180.0) * PI
-
-        h_rot = torch.add(h_tensor, rot_rads)
-        h_rot = torch.remainder(torch.add(h_rot, 2 * PI), 2 * PI)
-
-        # L*C*h to L*a*b*
-        lab_tensor[1, :, :] = torch.mul(c_tensor, torch.cos(h_rot))
-        lab_tensor[2, :, :] = torch.mul(c_tensor, torch.sin(h_rot))
-
-        linear_srgb_tensor = linear_srgb_from_oklab(lab_tensor)
+        oklch_tensor = oklch_from_oklab(oklab_from_linear_srgb(linear_srgb_tensor))
+        oklch_tensor[2, :, :] = torch.remainder(torch.add(oklch_tensor[2, :, :], self.degrees), 360.0)
+        linear_srgb_tensor = linear_srgb_from_oklch(oklch_tensor)
 
         rgb_tensor = srgb_from_linear_srgb(
             linear_srgb_tensor, alpha=self.ok_adaptive_gamut, steps=(3 if self.ok_high_precision else 1)
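Note: the new helpers replace the inline Oklab <-> LCh trigonometry deleted above, with hue now kept in degrees (hence the wrap at 360.0 on channel 2). A sketch of what oklch_from_oklab and oklab_from_oklch presumably do, reconstructed from the removed math; the actual color_conversion implementations may differ in detail:

import torch

def oklch_from_oklab(lab: torch.Tensor) -> torch.Tensor:
    # [L, a, b] (3xHxW) -> [L, C, h] with h in degrees in [0, 360).
    l, a, b = lab[0], lab[1], lab[2]
    c = torch.sqrt(a.pow(2.0) + b.pow(2.0))  # chroma
    h = torch.remainder(torch.rad2deg(torch.atan2(b, a)), 360.0)  # hue angle
    return torch.stack([l, c, h])

def oklab_from_oklch(lch: torch.Tensor) -> torch.Tensor:
    # [L, C, h] with h in degrees -> [L, a, b].
    l, c, h_rad = lch[0], lch[1], torch.deg2rad(lch[2])
    return torch.stack([l, c * torch.cos(h_rad), c * torch.sin(h_rad)])

On that reading, linear_srgb_from_oklch(t) would be linear_srgb_from_oklab(oklab_from_oklch(t)), matching the "Oklch" reassembly lambda later in this diff.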
@@ -602,14 +592,14 @@ def prepare_tensors_from_images(
         image_hsv_upper, image_hsv_lower = image_upper.convert("HSV"), image_lower.convert("HSV")
         upper_hsv_tensor = torch.stack(
             [
-                tensor_from_pil_image(image_hsv_upper.getchannel("H"), normalize=False)[0, :, :],
+                tensor_from_pil_image(image_hsv_upper.getchannel("H"), normalize=False)[0, :, :] * 360.0,
                 tensor_from_pil_image(image_hsv_upper.getchannel("S"), normalize=False)[0, :, :],
                 tensor_from_pil_image(image_hsv_upper.getchannel("V"), normalize=False)[0, :, :],
             ]
         )
         lower_hsv_tensor = torch.stack(
             [
-                tensor_from_pil_image(image_hsv_lower.getchannel("H"), normalize=False)[0, :, :],
+                tensor_from_pil_image(image_hsv_lower.getchannel("H"), normalize=False)[0, :, :] * 360.0,
                 tensor_from_pil_image(image_hsv_lower.getchannel("S"), normalize=False)[0, :, :],
                 tensor_from_pil_image(image_hsv_lower.getchannel("V"), normalize=False)[0, :, :],
             ]
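Note: the * 360.0 factor assumes tensor_from_pil_image maps PIL's 8-bit hue channel (0..255 spanning the full hue circle) into 0..1, so multiplying recovers degrees. An illustrative check, with torchvision's to_tensor standing in for the InvokeAI helper:

import torch
from PIL import Image
from torchvision.transforms.functional import to_tensor

img = Image.new("HSV", (1, 1), (128, 255, 255))  # hue about half way around the circle
h01 = to_tensor(img.getchannel("H"))[0, 0, 0]    # roughly 0.502 after the 1/255 scaling
print(float(h01) * 360.0)                        # roughly 180.7 degrees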
@@ -655,29 +645,8 @@ def prepare_tensors_from_images(
         if "oklch" in required:
             upper_oklab_tensor = oklab_from_linear_srgb(upper_rgb_l_tensor)
             lower_oklab_tensor = oklab_from_linear_srgb(lower_rgb_l_tensor)
-
-            upper_oklch_tensor = torch.stack(
-                [
-                    upper_oklab_tensor[0, :, :],
-                    torch.sqrt(
-                        torch.add(
-                            torch.pow(upper_oklab_tensor[1, :, :], 2.0), torch.pow(upper_oklab_tensor[2, :, :], 2.0)
-                        )
-                    ),
-                    torch.atan2(upper_oklab_tensor[2, :, :], upper_oklab_tensor[1, :, :]),
-                ]
-            )
-            lower_oklch_tensor = torch.stack(
-                [
-                    lower_oklab_tensor[0, :, :],
-                    torch.sqrt(
-                        torch.add(
-                            torch.pow(lower_oklab_tensor[1, :, :], 2.0), torch.pow(lower_oklab_tensor[2, :, :], 2.0)
-                        )
-                    ),
-                    torch.atan2(lower_oklab_tensor[2, :, :], lower_oklab_tensor[1, :, :]),
-                ]
-            )
+            upper_oklch_tensor = oklch_from_oklab(upper_oklab_tensor)
+            lower_oklch_tensor = oklch_from_oklab(lower_oklab_tensor)
 
         return (
             upper_rgb_l_tensor,
@@ -736,7 +705,17 @@ def adaptive_clipped(rgb_tensor: torch.Tensor, clamp: bool = True, replace_with:
             "HSL": lambda t: linear_srgb_from_srgb(srgb_from_hsl(t)),
             "HSV": lambda t: linear_srgb_from_srgb(
                 tensor_from_pil_image(
-                    pil_image_from_tensor(t.clamp(0.0, 1.0), mode="HSV").convert("RGB"), normalize=False
+                    pil_image_from_tensor(
+                        torch.stack(
+                            [
+                                torch.remainder(t[0, :, :], 360.0) / 360.0,
+                                t[1, :, :].clamp(0.0, 1.0),
+                                t[2, :, :].clamp(0.0, 1.0),
+                            ]
+                        ),
+                        mode="HSV",
+                    ).convert("RGB"),
+                    normalize=False,
                 )
             ),
             "Okhsl": lambda t: linear_srgb_from_srgb(
@@ -745,15 +724,7 @@
             "Okhsv": lambda t: linear_srgb_from_srgb(
                 srgb_from_okhsv(t, alpha=self.adaptive_gamut, steps=(3 if self.high_precision else 1))
             ),
-            "Oklch": lambda t: linear_srgb_from_oklab(
-                torch.stack(
-                    [
-                        t[0, :, :],
-                        torch.mul(t[1, :, :], torch.cos(t[2, :, :])),
-                        torch.mul(t[1, :, :], torch.sin(t[2, :, :])),
-                    ]
-                )
-            ),
+            "Oklch": lambda t: linear_srgb_from_oklab(oklab_from_oklch(t)),
             "LCh": lambda t: linear_srgb_from_srgb(
                 tensor_from_pil_image(
                     self.image_convert_with_xform(
@@ -784,9 +755,9 @@ def adaptive_clipped(rgb_tensor: torch.Tensor, clamp: bool = True, replace_with:
             alpha_upper_tensor,
             alpha_lower_tensor,
             mask_tensor,
-            upper_hsv_tensor,  # h_rgb, s_hsv, v_hsv
+            upper_hsv_tensor,  # h_hsv_degrees, s_hsv, v_hsv
             lower_hsv_tensor,
-            upper_hsl_tensor,  # , s_hsl, l_hsl
+            upper_hsl_tensor,  # h_hsl_degrees, s_hsl, l_hsl
             lower_hsl_tensor,
             upper_lab_tensor,  # l_lab, a_lab, b_lab
             lower_lab_tensor,
@@ -796,11 +767,11 @@
             lower_l_eal_tensor,
             upper_oklab_tensor,  # l_oklab, a_oklab, b_oklab
             lower_oklab_tensor,
-            upper_oklch_tensor,  # , c_oklab, h_oklab
+            upper_oklch_tensor,  # l_oklab, c_oklab, h_oklab_degrees
             lower_oklch_tensor,
-            upper_okhsv_tensor,  # h_okhsv, s_okhsv, v_okhsv
+            upper_okhsv_tensor,  # h_okhsv_degrees, s_okhsv, v_okhsv
             lower_okhsv_tensor,
-            upper_okhsl_tensor,  # h_okhsl, s_okhsl, l_r_oklab
+            upper_okhsl_tensor,  # h_okhsl_degrees, s_okhsl, l_r_oklab
             lower_okhsl_tensor,
         ) = image_tensors
 
@@ -850,6 +821,17 @@ def adaptive_clipped(rgb_tensor: torch.Tensor, clamp: bool = True, replace_with:
             "LCh": 2,
         }[color_space]
 
+        hue_period = {
+            "RGB": None,
+            "Linear": None,
+            "HSL": 360.0,
+            "HSV": 360.0,
+            "Okhsl": 360.0,
+            "Okhsv": 360.0,
+            "Oklch": 360.0,
+            "LCh": 2.0 * PI,
+        }[color_space]
+
         if blend_mode == "Normal":
             upper_rgb_l_tensor = reassembly_function(upper_space_tensor)
 
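Note: hue_period complements the hue_index lookup just above it. The arithmetic blend modes below wrap the hue channel by this period instead of the old hard-coded 1.0, since hue is now stored in degrees everywhere except LCh, which the table tracks with a period of 2 * pi (radians). A quick illustrative check of the wrap semantics:

import torch

h = torch.tensor([350.0])
print(torch.remainder(h + 20.0, 360.0))  # tensor([10.]) -- rotation wraps past 360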
@@ -982,19 +964,19 @@ def adaptive_clipped(rgb_tensor: torch.Tensor, clamp: bool = True, replace_with:
         elif blend_mode == "Linear Dodge (Add)":
             lower_space_tensor = torch.add(lower_space_tensor, upper_space_tensor)
             if hue_index is not None:
-                lower_space_tensor[hue_index, :, :] = torch.remainder(lower_space_tensor[hue_index, :, :], 1.0)
+                lower_space_tensor[hue_index, :, :] = torch.remainder(lower_space_tensor[hue_index, :, :], hue_period)
             upper_rgb_l_tensor = adaptive_clipped(reassembly_function(lower_space_tensor))
 
         elif blend_mode == "Color Dodge":
             lower_space_tensor = torch.div(lower_space_tensor, torch.add(torch.mul(upper_space_tensor, -1.0), 1.0))
             if hue_index is not None:
-                lower_space_tensor[hue_index, :, :] = torch.remainder(lower_space_tensor[hue_index, :, :], 1.0)
+                lower_space_tensor[hue_index, :, :] = torch.remainder(lower_space_tensor[hue_index, :, :], hue_period)
             upper_rgb_l_tensor = adaptive_clipped(reassembly_function(lower_space_tensor))
 
         elif blend_mode == "Divide":
             lower_space_tensor = torch.div(lower_space_tensor, upper_space_tensor)
             if hue_index is not None:
-                lower_space_tensor[hue_index, :, :] = torch.remainder(lower_space_tensor[hue_index, :, :], 1.0)
+                lower_space_tensor[hue_index, :, :] = torch.remainder(lower_space_tensor[hue_index, :, :], hue_period)
             upper_rgb_l_tensor = adaptive_clipped(reassembly_function(lower_space_tensor))
 
         elif blend_mode == "Linear Burn":
@@ -1088,7 +1070,7 @@ def adaptive_clipped(rgb_tensor: torch.Tensor, clamp: bool = True, replace_with:
         elif blend_mode == "Subtract":
             lower_space_tensor = torch.sub(lower_space_tensor, upper_space_tensor)
             if hue_index is not None:
-                lower_space_tensor[hue_index, :, :] = torch.remainder(lower_space_tensor[hue_index, :, :], 1.0)
+                lower_space_tensor[hue_index, :, :] = torch.remainder(lower_space_tensor[hue_index, :, :], hue_period)
             upper_rgb_l_tensor = adaptive_clipped(reassembly_function(lower_space_tensor))
 
         elif blend_mode == "Difference":