-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcreate_train.py
More file actions
134 lines (106 loc) · 5.48 KB
/
create_train.py
File metadata and controls
134 lines (106 loc) · 5.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
import argparse
import math
import pandas as pd
import numpy as np
import xarray as xr
import rioxarray
import rasterio
from rasterio.features import rasterize
import geopandas as gpd
import shapely
import matplotlib.pyplot as plt
from utils import get_phase_and_magnitude, flow_train_dataset
# Command-line interface of the training-set creation script.
# Each entry is (flag, value type, default, help text).
_CLI_OPTIONS = (
    ('--save', int, 0, "Save train dataset or not."),
    ('--random_state', int, 42, "Random or deterministic. Use -1 for randomness."),
    ('--path_train_data', str, "/media/maffe/sturellone/groundice/data/train", "Path to data"),
    ('--path_train_out', str, "/media/maffe/sturellone/groundice/data/train/train", "Path to where data to be saved"),
    ('--patch_size', int, 256, "Size of images"),
    ('--all_touched', int, 1, "Burn all touching pixels or use Bresenham line algorithm"),
)

parser = argparse.ArgumentParser()
for _flag, _value_type, _default, _help_text in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_value_type, default=_default, help=_help_text)

# Parsed options; read by the rest of the script as ``args.<name>``.
args = parser.parse_args()
# Make sure the output directory for this patch size exists.
# exist_ok=True makes the call idempotent and avoids the check-then-create
# race of testing os.path.exists() before calling os.makedirs().
os.makedirs(f'{args.path_train_out}/{args.patch_size}', exist_ok=True)
# Load the manually segmented grounding lines from the GeoPackage.
_gl_gpkg_path = f"{args.path_train_data}/all_manual_segm_gl.gpkg"
data_grounding_line_geometries = gpd.read_file(_gl_gpkg_path)

# Keep only the rows that actually carry a geometry. According to the original
# author's note, 44 rows are expected to be dropped here (2524 -> 2480).
_has_geometry = data_grounding_line_geometries.geometry.notna()
data_grounding_line_geometries = data_grounding_line_geometries[_has_geometry]

# Number of entries in the DInSAR directory (one per scene file).
_dinsar_dir = f'{args.path_train_data}/DInSAR'
num_training_scenes = len(os.listdir(_dinsar_dir))
# Debug switch: set to True to display a three-panel overview (phase,
# magnitude, burned mask) of every scene before it is handed to
# flow_train_dataset. It never changes inside the loop, so it lives here.
plot = False

# Loop over the .tif scenes found in the DInSAR directory.
for n, file_name_tif in enumerate(os.listdir(f'{args.path_train_data}/DInSAR')):

    # Select the grounding lines belonging to this scene: the 'source_file'
    # column is matched against the scene file name with '.tif' swapped for
    # '.shp'.
    data_gl_scene = data_grounding_line_geometries[
        data_grounding_line_geometries['source_file'] == file_name_tif.replace(".tif", ".shp")
    ]

    # Some scenes have no grounding lines in the dataset; skip those.
    if data_gl_scene.empty:
        print(f"{n} File {file_name_tif} has no grounding lines in dataset and will be skipped.")
        continue

    # Open the scene raster.
    # NOTE(review): the variable name suggests band 1/2 hold the real and
    # imaginary parts of the interferogram -- confirm against the data.
    tif_dinsar_re_im = rioxarray.open_rasterio(f'{args.path_train_data}/DInSAR/{file_name_tif}')
    phase, magnitude = get_phase_and_magnitude(tif_dinsar_re_im)

    # Phase/magnitude version of the scene: a deep copy whose first two bands
    # are overwritten. Within this script it is only used by the debug plot.
    tif_dinsar_phi_mag = tif_dinsar_re_im.copy(deep=True)
    tif_dinsar_phi_mag.values[0, :, :] = phase
    tif_dinsar_phi_mag.values[1, :, :] = magnitude

    # Single-band uint8 raster on the same y/x grid as the scene, initialised
    # with zeros; the grounding lines are burned into it below.
    tif_gl_mask = xr.DataArray(
        np.zeros((1, tif_dinsar_re_im.shape[1], tif_dinsar_re_im.shape[2]), dtype="uint8"),
        dims=tif_dinsar_re_im.dims,
        coords={
            "band": ["mask"],
            "y": tif_dinsar_re_im.coords["y"],
            "x": tif_dinsar_re_im.coords["x"],
        },
    )

    # Copy the geospatial metadata (CRS and affine transform) from the scene.
    tif_gl_mask.rio.write_crs(tif_dinsar_re_im.rio.crs, inplace=True)
    tif_gl_mask.rio.write_transform(tif_dinsar_re_im.rio.transform(), inplace=True)

    # The vector data and the raster must share the same projection, otherwise
    # the geometries would be burned at the wrong pixel locations.
    assert data_gl_scene.crs == tif_gl_mask.rio.crs, "Projection not as expected: EPSG:3031"

    # Burn the grounding-line geometries into a 2D array: 1 on the lines,
    # 0 elsewhere. The CLI passes all_touched as an int, while rasterize
    # documents it as a boolean, hence the explicit conversion.
    mask_array = rasterize(
        ((geom, 1) for geom in data_gl_scene.geometry),
        out_shape=(tif_gl_mask.shape[1], tif_gl_mask.shape[2]),
        transform=tif_gl_mask.rio.transform(),
        all_touched=bool(args.all_touched),
        fill=0,
        dtype='uint8',
    )

    # Plug the burned mask into the georeferenced mask raster.
    tif_gl_mask.values[0] = mask_array

    # At this point we have two aligned xarrays:
    #   tif_dinsar_re_im -- the DInSAR scene
    #   tif_gl_mask      -- the grounding-line mask
    if plot:
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(13, 4))

        # Overlay the grounding lines on each panel.
        data_gl_scene.plot(ax=ax1, ec='black', fc='none', lw=2)
        data_gl_scene.plot(ax=ax2, ec='red', fc='none', lw=2)
        data_gl_scene.plot(ax=ax3, ec='red', fc='none', lw=2, alpha=0.5)

        im1 = tif_dinsar_phi_mag.sel(band=1).plot(ax=ax1, cmap='hsv', vmin=-np.pi, vmax=np.pi, add_colorbar=False)
        im2 = tif_dinsar_phi_mag.sel(band=2).plot(ax=ax2, cmap='gray', add_colorbar=False)
        im3 = tif_gl_mask.plot(ax=ax3, cmap='gray', add_colorbar=False)

        for ax in (ax1, ax2, ax3):
            ax.set_aspect('equal')

        ax1.set_title("Phase")
        ax2.set_title("Coherence")
        # Fixed: the "Mask" title used to be set on ax2, which overwrote
        # "Coherence" and left the third panel untitled.
        ax3.set_title("Mask")

        # Add a horizontal colorbar below each panel.
        for image, axis, label in (
            (im1, ax1, "Phase [rad]"),
            (im2, ax2, "Coherence"),
            (im3, ax3, "Mask"),
        ):
            colorbar = plt.colorbar(image, ax=axis, orientation='horizontal', pad=0.1, fraction=0.046)
            colorbar.set_label(label)

        plt.tight_layout()
        plt.show()

    # Hand the scene, its mask and its grounding lines to flow_train_dataset
    # (imported from utils). Only the second returned value is used here; it
    # is reported below as the number of images produced.
    _, num_images = flow_train_dataset(
        scene_dinsar=tif_dinsar_re_im,
        scene_mask=tif_gl_mask,
        scene_ground_lines=data_gl_scene,
        scene_name=file_name_tif.replace(".tif", ""),
        patch_size=args.patch_size,
        random_state=args.random_state,
        path_out=args.path_train_out,
        save=args.save,
    )

    print(f"From scene {n}/{num_training_scenes} we have produced {num_images} images of shape {args.patch_size}x{args.patch_size}.")
# Terminate the script. Raising SystemExit directly replaces the `exit`
# helper, which is injected by the `site` module for interactive use and is
# missing when Python runs with -S or in some embedded/frozen setups.
# NOTE: a string argument is printed to stderr and yields exit status 1,
# exactly as the previous exit("EXIT.") call did.
raise SystemExit("EXIT.")