diff --git a/motion_magnification_pytorch.ipynb b/motion_magnification_pytorch.ipynb index 88ee77b..7bf60d5 100644 --- a/motion_magnification_pytorch.ipynb +++ b/motion_magnification_pytorch.ipynb @@ -569,13 +569,13 @@ }, { "cell_type": "code", - "execution_count": 158, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "## Get DFT of all video frames\n", "frames_tensor = torch.tensor(np.array(frames)).type(torch.float32).to(device)\n", - "video_dft = torch.fft.fftshift(torch.fft.fft2(frames_tensor, dim=(1,2))).type(torch.complex64).to(device)" + "video_dft = torch.fft.fftshift(torch.fft.fft2(frames_tensor, dim=(1,2)), dim=(1,2)).type(torch.complex64).to(device)" ] }, { @@ -631,7 +631,7 @@ }, { "cell_type": "code", - "execution_count": 163, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -660,8 +660,7 @@ " _delta = torch.angle(curr_pyr) - ref_phase \n", "\n", " # get phase delta wrapped to [-pi, pi]\n", - " phase_deltas[:, vid_idx, :, :] = ((torch.pi + _delta) \\\n", - " % 2*torch.pi) - torch.pi\n", + " phase_deltas[:, vid_idx, :, :] = ((torch.pi + _delta) % (2*torch.pi)) - torch.pi\n", " \n", " ## Temporally Filter the phase deltas\n", " # Filter in Frequency Domain and convert back to phase space\n", diff --git a/phase_based_processing.py b/phase_based_processing.py index ed984b6..1a8a66c 100644 --- a/phase_based_processing.py +++ b/phase_based_processing.py @@ -119,8 +119,7 @@ def process_single_channel(self, _delta = torch.angle(curr_pyr) - ref_phase # get phase delta wrapped to [-pi, pi] - phase_deltas[:, vid_idx, :, :] = ((torch.pi + _delta) \ - % 2*torch.pi) - torch.pi + phase_deltas[:, vid_idx, :, :] = ((torch.pi + _delta) % (2*torch.pi)) - torch.pi ## Temporally Filter the phase deltas # Filter in Frequency Domain and convert back to phase space diff --git a/phase_utils.py b/phase_utils.py index 3134cb5..12d87cc 100644 --- a/phase_utils.py +++ b/phase_utils.py @@ -147,7 +147,7 @@ def create_gif_from_numpy(save_path, images): ## Misc utils def get_fft2_batch(tensor_in): - return torch.fft.fftshift(torch.fft.fft2(tensor_in, dim=(1,2))).type(torch.complex64) + return torch.fft.fftshift(torch.fft.fft2(tensor_in, dim=(1, 2)), dim=(1, 2)).type(torch.complex64) def bandpass_filter(freq_lo, freq_hi, fs, num_taps, device):