@@ -31,6 +31,7 @@ def nearest_brenier_potential_fit(
3131 its = 100 ,
3232 log = False ,
3333 init_method = "barycentric" ,
34+ solver = None ,
3435):
3536 r"""
3637 Computes optimal values and gradients at X for a strongly convex potential :math:`\varphi` with Lipschitz gradients
@@ -87,7 +88,10 @@ def nearest_brenier_potential_fit(
8788 log : bool, optional
8889 record log if true
8990 init_method : str, optional
90- 'target' initialises G=V, 'barycentric' initialises at the image of X by the barycentric projection
91+ 'target' initialises G=V, 'barycentric' initialises at the image of X by
92+ the barycentric projection
93+ solver : str, optional
94+ The CVXPY solver to use
9195
9296 Returns
9397 -------
@@ -173,7 +177,7 @@ def nearest_brenier_potential_fit(
173177 - c3 * (G [j ] - G [i ]).T @ (X [j ] - X [i ])
174178 ]
175179 problem = cvx .Problem (objective , constraints )
176- problem .solve (solver = cvx . ECOS )
180+ problem .solve (solver = solver )
177181 phi_val , G_val = phi .value , G .value
178182 it_log_dict = {
179183 "solve_time" : problem .solver_stats .solve_time ,
@@ -231,6 +235,7 @@ def nearest_brenier_potential_predict_bounds(
231235 strongly_convex_constant = 0.6 ,
232236 gradient_lipschitz_constant = 1.4 ,
233237 log = False ,
238+ solver = None ,
234239):
235240 r"""
236241 Compute the values of the lower and upper bounding potentials at the input points Y, using the potential optimal
@@ -290,6 +295,8 @@ def nearest_brenier_potential_predict_bounds(
290295 constant for the Lipschitz property of the input gradient G, defaults to 1.4
291296 log : bool, optional
292297 record log if true
298+ solver : str, optional
299+ The CVXPY solver to use
293300
294301 Returns
295302 -------
@@ -368,7 +375,7 @@ def nearest_brenier_potential_predict_bounds(
368375 - c3 * (G [j ] - G_l_y ).T @ (X [j ] - Y [y_idx ])
369376 ]
370377 problem = cvx .Problem (objective , constraints )
371- problem .solve (solver = cvx . ECOS )
378+ problem .solve (solver = solver )
372379 phi_lu [0 , y_idx ] = phi_l_y .value
373380 G_lu [0 , y_idx ] = G_l_y .value
374381 if log :
@@ -395,7 +402,7 @@ def nearest_brenier_potential_predict_bounds(
395402 - c3 * (G_u_y - G [i ]).T @ (Y [y_idx ] - X [i ])
396403 ]
397404 problem = cvx .Problem (objective , constraints )
398- problem .solve (solver = cvx . ECOS )
405+ problem .solve (solver = solver )
399406 phi_lu [1 , y_idx ] = phi_u_y .value
400407 G_lu [1 , y_idx ] = G_u_y .value
401408 if log :
0 commit comments