diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index f418133e49f..a1e1cd3d882 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -337,9 +337,13 @@ def _apply_standard_conditioning( ) cross_attention_kwargs["percent_through"] = step_index / total_step_count + uncond_embeds = conditioning_data.uncond_text.embeds.to(x.device) + cond_embeds = conditioning_data.cond_text.embeds.to(x.device) both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch( - conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds + uncond_embeds, cond_embeds ) + if added_cond_kwargs is not None: + added_cond_kwargs = {k: v.to(x.device) for k, v in added_cond_kwargs.items()} both_results = self.model_forward_callback( x_twice, sigma_twice, @@ -428,11 +432,15 @@ def _apply_standard_conditioning_sequentially( ) cross_attention_kwargs["percent_through"] = step_index / total_step_count + uncond_embeds = conditioning_data.uncond_text.embeds.to(x.device) + if added_cond_kwargs is not None: + added_cond_kwargs = {k: v.to(x.device) for k, v in added_cond_kwargs.items()} + # Run unconditioned UNet denoising (i.e. negative prompt). unconditioned_next_x = self.model_forward_callback( x, sigma, - conditioning_data.uncond_text.embeds, + uncond_embeds, cross_attention_kwargs=cross_attention_kwargs, down_block_additional_residuals=uncond_down_block, mid_block_additional_residual=uncond_mid_block, @@ -465,8 +473,8 @@ def _apply_standard_conditioning_sequentially( added_cond_kwargs = None if conditioning_data.is_sdxl(): added_cond_kwargs = { - "text_embeds": conditioning_data.cond_text.pooled_embeds, - "time_ids": conditioning_data.cond_text.add_time_ids, + "text_embeds": conditioning_data.cond_text.pooled_embeds.to(x.device), + "time_ids": conditioning_data.cond_text.add_time_ids.to(x.device), } # Prepare prompt regions for the conditioned pass. @@ -476,11 +484,12 @@ def _apply_standard_conditioning_sequentially( ) cross_attention_kwargs["percent_through"] = step_index / total_step_count + cond_embeds = conditioning_data.cond_text.embeds.to(x.device) # Run conditioned UNet denoising (i.e. positive prompt). conditioned_next_x = self.model_forward_callback( x, sigma, - conditioning_data.cond_text.embeds, + cond_embeds, cross_attention_kwargs=cross_attention_kwargs, down_block_additional_residuals=cond_down_block, mid_block_additional_residual=cond_mid_block,