2929import argparse
3030
3131
32- def run (n , backend , benchmark_mode , correctness_test ):
32+ def run (n , backend , datatype , benchmark_mode ):
3333 if backend == "ddpt" :
3434 import ddptensor as np
3535 from ddptensor .numpy import fromfunction
@@ -64,8 +64,11 @@ def info(s):
6464
6565 info (f"Using backend: { backend } " )
6666
67- if correctness_test :
68- n = 10
67+ dtype = {
68+ "f64" : np .float64 ,
69+ "f32" : np .float32 ,
70+ }[datatype ]
71+ info (f"Datatype: { datatype } " )
6972
7073 # constants
7174 g = 9.81
@@ -92,20 +95,16 @@ def info(s):
9295 t_end = 1.0
9396
9497 # coordinate arrays
95- x_t_2d = fromfunction (
96- lambda i , j : xmin + i * dx + dx / 2 , (nx , ny ), dtype = np .float64
97- )
98- y_t_2d = fromfunction (
99- lambda i , j : ymin + j * dy + dy / 2 , (nx , ny ), dtype = np .float64
100- )
101- x_u_2d = fromfunction (lambda i , j : xmin + i * dx , (nx + 1 , ny ), dtype = np .float64 )
98+ x_t_2d = fromfunction (lambda i , j : xmin + i * dx + dx / 2 , (nx , ny ), dtype = dtype )
99+ y_t_2d = fromfunction (lambda i , j : ymin + j * dy + dy / 2 , (nx , ny ), dtype = dtype )
100+ x_u_2d = fromfunction (lambda i , j : xmin + i * dx , (nx + 1 , ny ), dtype = dtype )
102101 y_u_2d = fromfunction (
103- lambda i , j : ymin + j * dy + dy / 2 , (nx + 1 , ny ), dtype = np . float64
102+ lambda i , j : ymin + j * dy + dy / 2 , (nx + 1 , ny ), dtype = dtype
104103 )
105104 x_v_2d = fromfunction (
106- lambda i , j : xmin + i * dx + dx / 2 , (nx , ny + 1 ), dtype = np . float64
105+ lambda i , j : xmin + i * dx + dx / 2 , (nx , ny + 1 ), dtype = dtype
107106 )
108- y_v_2d = fromfunction (lambda i , j : ymin + j * dy , (nx , ny + 1 ), dtype = np . float64 )
107+ y_v_2d = fromfunction (lambda i , j : ymin + j * dy , (nx , ny + 1 ), dtype = dtype )
109108
110109 T_shape = (nx , ny )
111110 U_shape = (nx + 1 , ny )
@@ -122,32 +121,32 @@ def info(s):
122121 info (f"Total DOFs: { dofs_T + dofs_U + dofs_V } " )
123122
124123 # prognostic variables: elevation, (u, v) velocity
125- e = np .full (T_shape , 0.0 , np . float64 )
126- u = np .full (U_shape , 0.0 , np . float64 )
127- v = np .full (V_shape , 0.0 , np . float64 )
124+ e = np .full (T_shape , 0.0 , dtype )
125+ u = np .full (U_shape , 0.0 , dtype )
126+ v = np .full (V_shape , 0.0 , dtype )
128127
129128 # potential vorticity
130- q = np .full (F_shape , 0.0 , np . float64 )
129+ q = np .full (F_shape , 0.0 , dtype )
131130
132131 # bathymetry
133- h = np .full (T_shape , 0.0 , np . float64 )
132+ h = np .full (T_shape , 0.0 , dtype )
134133
135- hu = np .full (U_shape , 0.0 , np . float64 )
136- hv = np .full (V_shape , 0.0 , np . float64 )
134+ hu = np .full (U_shape , 0.0 , dtype )
135+ hv = np .full (V_shape , 0.0 , dtype )
137136
138- dudy = np .full (F_shape , 0.0 , np . float64 )
139- dvdx = np .full (F_shape , 0.0 , np . float64 )
137+ dudy = np .full (F_shape , 0.0 , dtype )
138+ dvdx = np .full (F_shape , 0.0 , dtype )
140139
141140 # vector invariant form
142- H_at_f = np .full (F_shape , 0.0 , np . float64 )
141+ H_at_f = np .full (F_shape , 0.0 , dtype )
143142
144143 # auxiliary variables for RK time integration
145- e1 = np .full (T_shape , 0.0 , np . float64 )
146- u1 = np .full (U_shape , 0.0 , np . float64 )
147- v1 = np .full (V_shape , 0.0 , np . float64 )
148- e2 = np .full (T_shape , 0.0 , np . float64 )
149- u2 = np .full (U_shape , 0.0 , np . float64 )
150- v2 = np .full (V_shape , 0.0 , np . float64 )
144+ e1 = np .full (T_shape , 0.0 , dtype )
145+ u1 = np .full (U_shape , 0.0 , dtype )
146+ v1 = np .full (V_shape , 0.0 , dtype )
147+ e2 = np .full (T_shape , 0.0 , dtype )
148+ u2 = np .full (U_shape , 0.0 , dtype )
149+ v2 = np .full (V_shape , 0.0 , dtype )
151150
152151 def exact_solution (t , x_t_2d , y_t_2d , x_u_2d , y_u_2d , x_v_2d , y_v_2d ):
153152 """
@@ -176,7 +175,7 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
176175 Water depth at rest
177176 """
178177 bath = 1.0
179- return bath * np .full (T_shape , 1.0 , np . float64 )
178+ return bath * np .full (T_shape , 1.0 , dtype )
180179
181180 # inital elevation
182181 u0 , v0 , e0 = exact_solution (0 , x_t_2d , y_t_2d , x_u_2d , y_u_2d , x_v_2d , y_v_2d )
@@ -200,10 +199,6 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
200199 dt = 1e-5
201200 nt = 100
202201 t_export = dt * 25
203- if correctness_test :
204- dt = 0.02
205- nt = 10
206- t_export = dt * 2
207202
208203 info (f"Time step: { dt } s" )
209204 info (f"Total run time: { t_end } s, { nt } time steps" )
@@ -381,20 +376,22 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
381376 err_L2 = math .sqrt (float (np .sum (err2 , all_axes )))
382377 info (f"L2 error: { err_L2 :7.15e} " )
383378
384- if correctness_test :
385- assert numpy .allclose (err_L2 , 3.687334565903038e-04 ), "L2 error does not match"
386- info ("SUCCESS" )
387- elif nx < 128 or ny < 128 :
379+ if nx < 128 or ny < 128 :
388380 info ("Skipping correctness test due to small problem size." )
389381 elif not benchmark_mode :
390- tolerance_ene = 1e-8
382+ tolerance_ene = 1e-7 if datatype == "f32" else 1e-9
391383 assert (
392384 diff_e < tolerance_ene
393385 ), f"Energy error exceeds tolerance: { diff_e } > { tolerance_ene } "
394386 if nx == 128 and ny == 128 :
395- assert numpy .allclose (
396- err_L2 , 4.315799035627906e-05
397- ), "L2 error does not match"
387+ if datatype == "f32" :
388+ assert numpy .allclose (
389+ err_L2 , 4.3127859e-05 , rtol = 1e-5
390+ ), "L2 error does not match"
391+ else :
392+ assert numpy .allclose (
393+ err_L2 , 4.315799035627906e-05
394+ ), "L2 error does not match"
398395 else :
399396 tolerance_l2 = 1e-4
400397 assert (
@@ -423,12 +420,6 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
423420 action = "store_true" ,
424421 help = "Run a fixed number of time steps." ,
425422 )
426- parser .add_argument (
427- "-ct" ,
428- "--correctness-test" ,
429- action = "store_true" ,
430- help = "Run a minimal correctness test." ,
431- )
432423 parser .add_argument (
433424 "-b" ,
434425 "--backend" ,
@@ -437,5 +428,18 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
437428 choices = ["ddpt" , "numpy" ],
438429 help = "Backend to use." ,
439430 )
431+ parser .add_argument (
432+ "-d" ,
433+ "--datatype" ,
434+ type = str ,
435+ default = "f64" ,
436+ choices = ["f32" , "f64" ],
437+ help = "Datatype for model state variables" ,
438+ )
440439 args = parser .parse_args ()
441- run (args .resolution , args .backend , args .benchmark_mode , args .correctness_test )
440+ run (
441+ args .resolution ,
442+ args .backend ,
443+ args .datatype ,
444+ args .benchmark_mode ,
445+ )
0 commit comments