diff --git a/simpeg_drivers/electricals/direct_current/pseudo_three_dimensions/driver.py b/simpeg_drivers/electricals/direct_current/pseudo_three_dimensions/driver.py index b34a0e21..0ea8333f 100644 --- a/simpeg_drivers/electricals/direct_current/pseudo_three_dimensions/driver.py +++ b/simpeg_drivers/electricals/direct_current/pseudo_three_dimensions/driver.py @@ -27,7 +27,7 @@ class DCBatch2DForwardDriver(BaseBatch2DDriver): _params_class = DCBatch2DForwardOptions _params_2d_class = DC2DForwardOptions - _validations = {} + _validations = None class DCBatch2DInversionDriver(BaseBatch2DDriver): @@ -35,4 +35,4 @@ class DCBatch2DInversionDriver(BaseBatch2DDriver): _params_class = DCBatch2DInversionOptions _params_2d_class = DC2DInversionOptions - _validations = {} + _validations = None diff --git a/simpeg_drivers/electricals/direct_current/pseudo_three_dimensions/params.py b/simpeg_drivers/electricals/direct_current/pseudo_three_dimensions/params.py index 7f0901b6..6d8e6dd2 100644 --- a/simpeg_drivers/electricals/direct_current/pseudo_three_dimensions/params.py +++ b/simpeg_drivers/electricals/direct_current/pseudo_three_dimensions/params.py @@ -11,6 +11,7 @@ from __future__ import annotations +from pathlib import Path from typing import ClassVar from geoh5py.data import FloatData @@ -41,7 +42,7 @@ class DCBatch2DForwardOptions(BaseForwardOptions): name: ClassVar[str] = "Direct Current Pseudo 3D Forward" title: ClassVar[str] = "Direct Current (DC) 2D Batch Forward" - default_ui_json: ClassVar[str] = ( + default_ui_json: ClassVar[Path] = ( assets_path() / "uijson/direct_current_batch2d_forward.ui.json" ) @@ -76,7 +77,7 @@ class DCBatch2DInversionOptions(BaseInversionOptions): name: ClassVar[str] = "Direct Current Pseudo 3D Inversion" title: ClassVar[str] = "Direct Current (DC) 2D Batch Inversion" - default_ui_json: ClassVar[str] = ( + default_ui_json: ClassVar[Path] = ( assets_path() / "uijson/direct_current_batch2d_inversion.ui.json" ) diff --git a/simpeg_drivers/electricals/direct_current/three_dimensions/driver.py b/simpeg_drivers/electricals/direct_current/three_dimensions/driver.py index fdda2261..a714498d 100644 --- a/simpeg_drivers/electricals/direct_current/three_dimensions/driver.py +++ b/simpeg_drivers/electricals/direct_current/three_dimensions/driver.py @@ -20,11 +20,11 @@ class DC3DForwardDriver(InversionDriver): """Direct Current 3D forward driver.""" _params_class = DC3DForwardOptions - _validation = {} + _validation = None class DC3DInversionDriver(InversionDriver): """Direct Current 3D inversion driver.""" _params_class = DC3DInversionOptions - _validation = {} + _validation = None diff --git a/simpeg_drivers/electricals/direct_current/two_dimensions/driver.py b/simpeg_drivers/electricals/direct_current/two_dimensions/driver.py index 0f5a6561..effb412c 100644 --- a/simpeg_drivers/electricals/direct_current/two_dimensions/driver.py +++ b/simpeg_drivers/electricals/direct_current/two_dimensions/driver.py @@ -20,11 +20,11 @@ class DC2DForwardDriver(Base2DDriver): """Direct Current 2D forward driver.""" _params_class = DC2DForwardOptions - _validations = {} + _validations = None class DC2DInversionDriver(Base2DDriver): """Direct Current 2D inversion driver.""" _params_class = DC2DInversionOptions - _validations = {} + _validations = None diff --git a/simpeg_drivers/electricals/direct_current/two_dimensions/params.py b/simpeg_drivers/electricals/direct_current/two_dimensions/params.py index 2875ec77..b0206fe4 100644 --- a/simpeg_drivers/electricals/direct_current/two_dimensions/params.py +++ b/simpeg_drivers/electricals/direct_current/two_dimensions/params.py @@ -11,6 +11,7 @@ from __future__ import annotations +from pathlib import Path from typing import ClassVar from geoh5py.data import DataAssociationEnum, FloatData, ReferencedData @@ -38,7 +39,7 @@ class DC2DForwardOptions(BaseForwardOptions): name: ClassVar[str] = "Direct Current 2D Forward" title: ClassVar[str] = "Direct Current 2D Forward" - default_ui_json: ClassVar[str] = ( + default_ui_json: ClassVar[Path] = ( assets_path() / "uijson/direct_current_2d_forward.ui.json" ) @@ -66,7 +67,7 @@ class DC2DInversionOptions(BaseInversionOptions): name: ClassVar[str] = "Direct Current 2D Inversion" title: ClassVar[str] = "Direct Current 2D Inversion" - default_ui_json: ClassVar[str] = ( + default_ui_json: ClassVar[Path] = ( assets_path() / "uijson/direct_current_2d_inversion.ui.json" ) diff --git a/simpeg_drivers/electricals/driver.py b/simpeg_drivers/electricals/driver.py index d27fb85a..9058c604 100644 --- a/simpeg_drivers/electricals/driver.py +++ b/simpeg_drivers/electricals/driver.py @@ -77,9 +77,9 @@ def create_drape_mesh(self) -> DrapeModel: class BaseBatch2DDriver(LineSweepDriver): """Base class for batch 2D DC and IP forward and inversion drivers.""" - _params_class: type(BaseForwardOptions, BaseInversionOptions) - _params_2d_class: type(BaseForwardOptions, BaseInversionOptions) - _validations: dict + _params_class: type[BaseForwardOptions | BaseInversionOptions] + _params_2d_class: type[BaseForwardOptions | BaseInversionOptions] + _validations = None _model_list: list[str] = [] def __init__(self, params): diff --git a/simpeg_drivers/electricals/induced_polarization/pseudo_three_dimensions/driver.py b/simpeg_drivers/electricals/induced_polarization/pseudo_three_dimensions/driver.py index b09813a3..2347059c 100644 --- a/simpeg_drivers/electricals/induced_polarization/pseudo_three_dimensions/driver.py +++ b/simpeg_drivers/electricals/induced_polarization/pseudo_three_dimensions/driver.py @@ -27,7 +27,7 @@ class IPBatch2DForwardDriver(BaseBatch2DDriver): _params_class = IPBatch2DForwardOptions _params_2d_class = IP2DForwardOptions - _validations = {} + _validations = None _model_list = ["conductivity_model"] @@ -36,5 +36,5 @@ class IPBatch2DInversionDriver(BaseBatch2DDriver): _params_class = IPBatch2DInversionOptions _params_2d_class = IP2DInversionOptions - _validations = {} + _validations = None _model_list = ["conductivity_model"] diff --git a/simpeg_drivers/electricals/induced_polarization/pseudo_three_dimensions/params.py b/simpeg_drivers/electricals/induced_polarization/pseudo_three_dimensions/params.py index 95085371..c465b0e8 100644 --- a/simpeg_drivers/electricals/induced_polarization/pseudo_three_dimensions/params.py +++ b/simpeg_drivers/electricals/induced_polarization/pseudo_three_dimensions/params.py @@ -11,6 +11,7 @@ from __future__ import annotations +from pathlib import Path from typing import ClassVar from geoh5py.data import FloatData @@ -41,7 +42,7 @@ class IPBatch2DForwardOptions(BaseForwardOptions): name: ClassVar[str] = "Induced Polarization Pseudo 3D Forward" title: ClassVar[str] = "Induced Polarization (IP) 2D Batch Forward" - default_ui_json: ClassVar[str] = ( + default_ui_json: ClassVar[Path] = ( assets_path() / "uijson/induced_polarization_batch2d_forward.ui.json" ) @@ -75,7 +76,7 @@ class IPBatch2DInversionOptions(BaseInversionOptions): name: ClassVar[str] = "Induced Polarization Pseudo 3D Inversion" title: ClassVar[str] = "Induced Polarization (IP) 2D Batch Inversion" - default_ui_json: ClassVar[str] = ( + default_ui_json: ClassVar[Path] = ( assets_path() / "uijson/induced_polarization_batch2d_inversion.ui.json" ) diff --git a/simpeg_drivers/electricals/induced_polarization/three_dimensions/driver.py b/simpeg_drivers/electricals/induced_polarization/three_dimensions/driver.py index 9cb0ba58..eef7788e 100644 --- a/simpeg_drivers/electricals/induced_polarization/three_dimensions/driver.py +++ b/simpeg_drivers/electricals/induced_polarization/three_dimensions/driver.py @@ -23,11 +23,11 @@ class IP3DForwardDriver(InversionDriver): """Induced Polarization 3D forward driver.""" _params_class = IP3DForwardOptions - _validations = {} + _validations = None class IP3DInversionDriver(InversionDriver): """Induced Polarization 3D inversion driver.""" _params_class = IP3DInversionOptions - _validations = {} + _validations = None diff --git a/simpeg_drivers/electricals/induced_polarization/two_dimensions/driver.py b/simpeg_drivers/electricals/induced_polarization/two_dimensions/driver.py index 09345225..b46167fd 100644 --- a/simpeg_drivers/electricals/induced_polarization/two_dimensions/driver.py +++ b/simpeg_drivers/electricals/induced_polarization/two_dimensions/driver.py @@ -23,11 +23,11 @@ class IP2DForwardDriver(Base2DDriver): """Induced Polarization 2D forward driver.""" _params_class = IP2DForwardOptions - _validations = {} + _validations = None class IP2DInversionDriver(Base2DDriver): """Induced Polarization 2D inversion driver.""" _params_class = IP2DInversionOptions - _validations = {} + _validations = None diff --git a/simpeg_drivers/electricals/induced_polarization/two_dimensions/params.py b/simpeg_drivers/electricals/induced_polarization/two_dimensions/params.py index af179692..5559dec2 100644 --- a/simpeg_drivers/electricals/induced_polarization/two_dimensions/params.py +++ b/simpeg_drivers/electricals/induced_polarization/two_dimensions/params.py @@ -11,6 +11,7 @@ from __future__ import annotations +from pathlib import Path from typing import ClassVar from geoh5py.data import FloatData @@ -37,7 +38,7 @@ class IP2DForwardOptions(BaseForwardOptions): name: ClassVar[str] = "Induced Polarization 2D Forward" title: ClassVar[str] = "Induced Polarization 2D Forward" - default_ui_json: ClassVar[str] = ( + default_ui_json: ClassVar[Path] = ( assets_path() / "uijson/induced_polarization_2d_forward.ui.json" ) @@ -67,7 +68,7 @@ class IP2DInversionOptions(BaseInversionOptions): name: ClassVar[str] = "Induced Polarization 2D Inversion" title: ClassVar[str] = "Induced Polarization 2D Inversion" - default_ui_json: ClassVar[str] = ( + default_ui_json: ClassVar[Path] = ( assets_path() / "uijson/induced_polarization_2d_inversion.ui.json" ) diff --git a/simpeg_drivers/electromagnetics/frequency_domain/driver.py b/simpeg_drivers/electromagnetics/frequency_domain/driver.py index e517f174..19fbe0cb 100644 --- a/simpeg_drivers/electromagnetics/frequency_domain/driver.py +++ b/simpeg_drivers/electromagnetics/frequency_domain/driver.py @@ -23,7 +23,7 @@ class FDEMForwardDriver(InversionDriver): """Frequency Domain Electromagnetic forward driver.""" _params_class = FDEMForwardOptions - _validations = {} + _validations = None def __init__(self, params: FDEMForwardOptions): super().__init__(params) @@ -33,4 +33,4 @@ class FDEMInversionDriver(InversionDriver): """Frequency Domain Electromagnetic inversion driver.""" _params_class = FDEMInversionOptions - _validations = {} + _validations = None diff --git a/simpeg_drivers/electromagnetics/time_domain/driver.py b/simpeg_drivers/electromagnetics/time_domain/driver.py index 17e840de..2c2ff36b 100644 --- a/simpeg_drivers/electromagnetics/time_domain/driver.py +++ b/simpeg_drivers/electromagnetics/time_domain/driver.py @@ -29,7 +29,7 @@ class TDEMForwardDriver(InversionDriver): """Time Domain Electromagnetic forward driver.""" _params_class = TDEMForwardOptions - _validations = {} + _validations = None def get_tiles(self) -> list[np.ndarray]: """ @@ -88,7 +88,7 @@ class TDEMInversionDriver(InversionDriver): """Time Domain Electromagnetic inversion driver.""" _params_class = TDEMInversionOptions - _validations = {} + _validations = None def get_tiles(self) -> list[np.ndarray]: """ diff --git a/simpeg_drivers/joint/constants.py b/simpeg_drivers/joint/constants.py deleted file mode 100644 index d61346a6..00000000 --- a/simpeg_drivers/joint/constants.py +++ /dev/null @@ -1,89 +0,0 @@ -# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' -# Copyright (c) 2025 Mira Geoscience Ltd. ' -# ' -# This file is part of simpeg-drivers package. ' -# ' -# simpeg-drivers is distributed under the terms and conditions of the MIT License ' -# (see LICENSE file at the root of this source code package). ' -# ' -# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' - - -from __future__ import annotations - -from simpeg_drivers import default_ui_json as base_default_ui_json -from simpeg_drivers.constants import validations as base_validations - - -default_ui_json = { - "title": "SimPEG Joint Surveys Inversion", - "inversion_type": "joint surveys", - "mesh": { - "group": "Mesh and Models", - "main": True, - "label": "Mesh", - "meshType": "4EA87376-3ECE-438B-BF12-3479733DED46", - "value": None, - "enabled": False, - "optional": True, - }, - "group_a": { - "main": True, - "group": "Joint", - "label": "Group A", - "groupType": "{55ed3daf-c192-4d4b-a439-60fa987fe2b8}", - "value": "", - }, - "group_a_multiplier": { - "min": 0.0, - "main": True, - "group": "Joint", - "label": "Misfit A Scale", - "value": 1.0, - "tooltip": "Constant multiplier for the data misfit function for Group A.", - }, - "group_b": { - "main": True, - "group": "Joint", - "label": "Group B", - "groupType": "{55ed3daf-c192-4d4b-a439-60fa987fe2b8}", - "value": "", - }, - "group_b_multiplier": { - "min": 0.0, - "main": True, - "group": "Joint", - "label": "Misfit B Scale", - "value": 1.0, - "tooltip": "Constant multiplier for the data misfit function for Group B.", - }, - "group_c": { - "main": True, - "group": "Joint", - "label": "Group C", - "groupType": "{55ed3daf-c192-4d4b-a439-60fa987fe2b8}", - "optional": True, - "enabled": False, - "value": "", - }, - "group_c_multiplier": { - "min": 0.0, - "main": True, - "group": "Joint", - "label": "Misfit C Scale", - "value": 1.0, - "dependency": "group_c", - "dependencyType": "enabled", - "tooltip": "Constant multiplier for the data misfit function for Group C.", - }, -} -default_ui_json = dict(base_default_ui_json, **default_ui_json) -validations = { - "inversion_type": { - "required": True, - "values": ["joint surveys"], - }, -} - -validations = dict(base_validations, **validations) -app_initializer = {} diff --git a/simpeg_drivers/joint/driver.py b/simpeg_drivers/joint/driver.py index bb721fe7..33788924 100644 --- a/simpeg_drivers/joint/driver.py +++ b/simpeg_drivers/joint/driver.py @@ -29,12 +29,12 @@ from simpeg_drivers.components.factories import SaveDataGeoh5Factory from simpeg_drivers.driver import InversionDriver -from simpeg_drivers.joint.params import BaseJointParams +from simpeg_drivers.joint.params import BaseJointOptions from simpeg_drivers.utils.utils import simpeg_group_to_driver class BaseJointDriver(InversionDriver): - def __init__(self, params: BaseJointParams): + def __init__(self, params: BaseJointOptions): self._directives = None self._drivers = None self._wires = None diff --git a/simpeg_drivers/joint/joint_cross_gradient/__init__.py b/simpeg_drivers/joint/joint_cross_gradient/__init__.py index e46957ca..0d9bf0f2 100644 --- a/simpeg_drivers/joint/joint_cross_gradient/__init__.py +++ b/simpeg_drivers/joint/joint_cross_gradient/__init__.py @@ -9,7 +9,7 @@ # ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' -from .params import JointCrossGradientParams +from .params import JointCrossGradientOptions # pylint: disable=unused-import # flake8: noqa diff --git a/simpeg_drivers/joint/joint_cross_gradient/constants.py b/simpeg_drivers/joint/joint_cross_gradient/constants.py index 0991cbc7..280fa2d2 100644 --- a/simpeg_drivers/joint/joint_cross_gradient/constants.py +++ b/simpeg_drivers/joint/joint_cross_gradient/constants.py @@ -12,8 +12,6 @@ from __future__ import annotations import simpeg_drivers -from simpeg_drivers.constants import validations as base_validations -from simpeg_drivers.joint.constants import default_ui_json as joint_default_ui_json inversion_defaults = { @@ -79,128 +77,3 @@ "generate_sweep": False, "distributed_workers": None, } -default_ui_json = { - "title": "SimPEG Joint Cross Gradient Inversion", - "inversion_type": "joint surveys", - "cross_gradient_weight_a_b": { - "min": 0.0, - "group": "Joint", - "label": "A x B Coupling Scale", - "value": 1.0, - "main": True, - "lineEdit": False, - "tooltip": "Weight applied to the cross gradient regularizations (1: equal weight with the standard Smallness and Smoothness terms.)", - }, - "cross_gradient_weight_c_a": { - "min": 0.0, - "group": "Joint", - "label": "A x C Coupling Scale", - "value": 1.0, - "main": True, - "lineEdit": False, - "dependency": "group_c", - "dependencyType": "enabled", - "tooltip": "Weight applied to the cross gradient regularizations (1: equal weight with the standard Smallness and Smoothness terms.)", - }, - "cross_gradient_weight_c_b": { - "min": 0.0, - "group": "Joint", - "label": "B x C Coupling Scale", - "value": 1.0, - "main": True, - "lineEdit": False, - "dependency": "group_c", - "dependencyType": "enabled", - "tooltip": "Weight applied to the cross gradient regularizations (1: equal weight with the standard Smallness and Smoothness terms.)", - }, - "alpha_s": { - "min": 0.0, - "group": "Regularization", - "groupOptional": True, - "label": "Smallness weight", - "value": 1.0, - "tooltip": "Constant ratio compared to other weights. Larger values result in models that remain close to the reference model", - "enabled": False, - }, - "length_scale_x": { - "min": 0.0, - "group": "Regularization", - "label": "X-smoothness weight", - "tooltip": "Larger values relative to other smoothness weights will result in x biased smoothness", - "value": 1.0, - "enabled": False, - }, - "length_scale_y": { - "min": 0.0, - "group": "Regularization", - "label": "Y-smoothness weight", - "tooltip": "Larger values relative to other smoothness weights will result in y biased smoothness", - "value": 1.0, - "enabled": False, - }, - "length_scale_z": { - "min": 0.0, - "group": "Regularization", - "label": "Z-smoothness weight", - "tooltip": "Larger values relative to other smoothness weights will result in z biased smoothess", - "value": 1.0, - "enabled": False, - }, - "s_norm": { - "min": 0.0, - "max": 2.0, - "group": "Regularization", - "label": "Smallness norm", - "value": 0.0, - "precision": 2, - "lineEdit": False, - "enabled": False, - }, - "x_norm": { - "min": 0.0, - "max": 2.0, - "group": "Regularization", - "label": "X-smoothness norm", - "value": 2.0, - "precision": 2, - "lineEdit": False, - "enabled": False, - }, - "y_norm": { - "min": 0.0, - "max": 2.0, - "group": "Regularization", - "label": "Y-smoothness norm", - "value": 2.0, - "precision": 2, - "lineEdit": False, - "enabled": False, - }, - "z_norm": { - "min": 0.0, - "max": 2.0, - "group": "Regularization", - "label": "Z-smoothness norm", - "value": 2.0, - "precision": 2, - "lineEdit": False, - "enabled": False, - }, - "gradient_type": { - "choiceList": ["total", "components"], - "group": "Regularization", - "label": "Gradient type", - "value": "total", - "verbose": 3, - "enabled": False, - }, -} -default_ui_json = dict(joint_default_ui_json, **default_ui_json) -validations = { - "inversion_type": { - "required": True, - "values": ["joint cross gradient"], - }, -} -validations = dict(base_validations, **validations) -app_initializer = {} diff --git a/simpeg_drivers/joint/joint_cross_gradient/driver.py b/simpeg_drivers/joint/joint_cross_gradient/driver.py index 05c3dc40..9eb7bcfe 100644 --- a/simpeg_drivers/joint/joint_cross_gradient/driver.py +++ b/simpeg_drivers/joint/joint_cross_gradient/driver.py @@ -27,15 +27,14 @@ ) from simpeg_drivers.joint.driver import BaseJointDriver -from .constants import validations -from .params import JointCrossGradientParams +from .params import JointCrossGradientOptions class JointCrossGradientDriver(BaseJointDriver): - _params_class = JointCrossGradientParams - _validations = validations + _params_class = JointCrossGradientOptions + _validations = None - def __init__(self, params: JointCrossGradientParams): + def __init__(self, params: JointCrossGradientOptions): self._wires = None self._directives = None diff --git a/simpeg_drivers/joint/joint_cross_gradient/params.py b/simpeg_drivers/joint/joint_cross_gradient/params.py index 5403b2a9..d8bf0cca 100644 --- a/simpeg_drivers/joint/joint_cross_gradient/params.py +++ b/simpeg_drivers/joint/joint_cross_gradient/params.py @@ -11,60 +11,39 @@ from __future__ import annotations -from copy import deepcopy +from pathlib import Path +from typing import ClassVar -from simpeg_drivers.joint.params import BaseJointParams +from geoh5py.data import FloatData +from geoh5py.objects import Octree -from .constants import default_ui_json, inversion_defaults, validations +from simpeg_drivers import assets_path +from simpeg_drivers.joint.params import BaseJointOptions -class JointCrossGradientParams(BaseJointParams): +class JointCrossGradientOptions(BaseJointOptions): """ - Parameter class for joint cross-gradient inversion. + Joint Cross Gradient inversion options. + + :param cross_gradient_weight_a_b: Weight applied to the cross gradient + regularizations. + :param cross_gradient_weight_c_a: Weight applied to the cross gradient + regularizations. + :param cross_gradient_weight_c_b: Weight applied to the cross gradient + regularizations. """ - _physical_property = [""] + name: ClassVar[str] = "Joint Cross Gradient Inversion" + title: ClassVar[str] = "Joint Cross Gradient Inversion" + default_ui_json: ClassVar[Path] = ( + assets_path() / "uijson/joint_cross_gradient_inversion.ui.json" + ) - def __init__(self, input_file=None, forward_only=False, **kwargs): - self._default_ui_json = deepcopy(default_ui_json) - self._inversion_defaults = deepcopy(inversion_defaults) - self._inversion_type = "joint cross gradient" - self._validations = validations - self._cross_gradient_weight_a_b = 1.0 - self._cross_gradient_weight_c_a = 1.0 - self._cross_gradient_weight_c_b = 1.0 + inversion_type: str = "joint cross gradient" - super().__init__(input_file=input_file, forward_only=forward_only, **kwargs) - - @property - def cross_gradient_weight_a_b(self): - return self._cross_gradient_weight_a_b - - @cross_gradient_weight_a_b.setter - def cross_gradient_weight_a_b(self, val): - self.setter_validator("cross_gradient_weight_a_b", val) - - @property - def cross_gradient_weight_c_a(self): - return self._cross_gradient_weight_c_a - - @cross_gradient_weight_c_a.setter - def cross_gradient_weight_c_a(self, val): - self.setter_validator("cross_gradient_weight_c_a", val) - - @property - def cross_gradient_weight_c_b(self): - return self._cross_gradient_weight_c_b - - @cross_gradient_weight_c_b.setter - def cross_gradient_weight_c_b(self, val): - self.setter_validator("cross_gradient_weight_c_b", val) - - @property - def physical_property(self): - """Physical property to invert.""" - return self._physical_property - - @physical_property.setter - def physical_property(self, val: list[str]): - self._physical_property = val + data_object: None = None + mesh: Octree | None = None + starting_model: None = None + cross_gradient_weight_a_b: float = 1.0 + cross_gradient_weight_c_a: float | None = None + cross_gradient_weight_c_b: float | None = None diff --git a/simpeg_drivers/joint/joint_surveys/__init__.py b/simpeg_drivers/joint/joint_surveys/__init__.py index 101990b9..08da2366 100644 --- a/simpeg_drivers/joint/joint_surveys/__init__.py +++ b/simpeg_drivers/joint/joint_surveys/__init__.py @@ -9,7 +9,7 @@ # ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' -from .params import JointSurveysParams +from .params import JointSurveysOptions # pylint: disable=unused-import # flake8: noqa diff --git a/simpeg_drivers/joint/joint_surveys/constants.py b/simpeg_drivers/joint/joint_surveys/constants.py index 7cb55b50..3a9859f9 100644 --- a/simpeg_drivers/joint/joint_surveys/constants.py +++ b/simpeg_drivers/joint/joint_surveys/constants.py @@ -12,8 +12,6 @@ from __future__ import annotations import simpeg_drivers -from simpeg_drivers.constants import validations as base_validations -from simpeg_drivers.joint.constants import default_ui_json as joint_default_ui_json ################# defaults ################## @@ -83,78 +81,3 @@ "generate_sweep": False, "distributed_workers": None, } - -default_ui_json = { - "title": "SimPEG Joint Surveys Inversion", - "inversion_type": "joint surveys", - "model_type": { - "choiceList": ["Conductivity (S/m)", "Resistivity (Ohm-m)"], - "main": True, - "group": "Mesh and Models", - "label": "Model units", - "tooltip": "Select the units of the model.", - "value": "Conductivity (S/m)", - }, - "starting_model": { - "association": "Cell", - "dataType": "Float", - "group": "Mesh and Models", - "main": True, - "isValue": False, - "parent": "mesh", - "label": "Initial model", - "property": None, - "optional": True, - "enabled": False, - "value": 1e-4, - }, - "lower_bound": { - "association": "Cell", - "main": True, - "dataType": "Float", - "group": "Mesh and Models", - "isValue": False, - "parent": "mesh", - "label": "Lower bound)", - "property": None, - "optional": True, - "value": -10.0, - "enabled": False, - }, - "upper_bound": { - "association": "Cell", - "main": True, - "dataType": "Float", - "group": "Mesh and Models", - "isValue": False, - "parent": "mesh", - "label": "Upper bound", - "property": None, - "optional": True, - "value": 10.0, - "enabled": False, - }, - "reference_model": { - "association": "Cell", - "main": True, - "dataType": "Float", - "group": "Mesh and Models", - "isValue": False, - "parent": "mesh", - "label": "Reference", - "property": None, - "optional": True, - "value": 1e-4, - "enabled": False, - }, -} -default_ui_json = dict(joint_default_ui_json, **default_ui_json) -validations = { - "inversion_type": { - "required": True, - "values": ["joint surveys"], - }, -} - -validations = dict(base_validations, **validations) -app_initializer = {} diff --git a/simpeg_drivers/joint/joint_surveys/driver.py b/simpeg_drivers/joint/joint_surveys/driver.py index d2b6a816..288c0ca4 100644 --- a/simpeg_drivers/joint/joint_surveys/driver.py +++ b/simpeg_drivers/joint/joint_surveys/driver.py @@ -20,18 +20,19 @@ from simpeg_drivers.components.factories import DirectivesFactory, SaveModelGeoh5Factory from simpeg_drivers.joint.driver import BaseJointDriver -from .constants import validations -from .params import JointSurveysParams +from .params import JointSurveysOptions logger = getLogger(__name__) class JointSurveyDriver(BaseJointDriver): - _params_class = JointSurveysParams - _validations = validations + """Joint surveys inversion driver""" - def __init__(self, params: JointSurveysParams): + _params_class = JointSurveysOptions + _validations = None + + def __init__(self, params: JointSurveysOptions): super().__init__(params) with fetch_active_workspace(self.workspace, mode="r+"): diff --git a/simpeg_drivers/joint/joint_surveys/params.py b/simpeg_drivers/joint/joint_surveys/params.py index d24f141f..463d4173 100644 --- a/simpeg_drivers/joint/joint_surveys/params.py +++ b/simpeg_drivers/joint/joint_surveys/params.py @@ -11,49 +11,32 @@ from __future__ import annotations -from copy import deepcopy +from pathlib import Path +from typing import ClassVar -from simpeg_drivers.joint.params import BaseJointParams +from pydantic import model_validator -from .constants import default_ui_json, inversion_defaults, validations +from simpeg_drivers import assets_path +from simpeg_drivers.joint.params import BaseJointOptions -class JointSurveysParams(BaseJointParams): - """ - Parameter class for gravity->density inversion. - """ +class JointSurveysOptions(BaseJointOptions): + """Joint Surveys inversion options.""" - _physical_property = "" + name: ClassVar[str] = "Joint Surveys Inversion" + title: ClassVar[str] = "Joint Surveys Inversion" + default_ui_json: ClassVar[Path] = ( + assets_path() / "uijson/joint_surveys_inversion.ui.json" + ) - def __init__(self, input_file=None, forward_only=False, **kwargs): - self._default_ui_json = deepcopy(default_ui_json) - self._inversion_defaults = deepcopy(inversion_defaults) - self._inversion_type = "joint surveys" - self._validations = validations - self._model_type = "Conductivity (S/m)" + inversion_type: str = "joint surveys" - super().__init__(input_file=input_file, forward_only=forward_only, **kwargs) - - @property - def model_type(self): - """Model units.""" - return self._model_type - - @model_type.setter - def model_type(self, val): - self.setter_validator("model_type", val) - - @property - def physical_property(self): - """Physical property to invert.""" - return self._physical_property - - @physical_property.setter - def physical_property(self, val: list[str]): - unique_properties = list(set(val)) - if len(unique_properties) > 1: + @model_validator(mode="after") + def all_groups_same_physical_property(self): + physical_properties = [k.options["physical_property"] for k in self.groups] + if len(list(set(physical_properties))) > 1: raise ValueError( "All physical properties must be the same. " - f"Provided SimPEG groups for {unique_properties}." + f"Provided SimPEG groups for {physical_properties}." ) - self._physical_property = unique_properties[0] + return self diff --git a/simpeg_drivers/joint/params.py b/simpeg_drivers/joint/params.py index 1cb68812..a74fcfb6 100644 --- a/simpeg_drivers/joint/params.py +++ b/simpeg_drivers/joint/params.py @@ -13,92 +13,31 @@ from geoh5py.groups import SimPEGGroup -from simpeg_drivers.params import InversionBaseParams +from simpeg_drivers.params import BaseInversionOptions -class BaseJointParams(InversionBaseParams): +class BaseJointOptions(BaseInversionOptions): """ - Parameter class for gravity->density inversion. + Base Joint Options. + + :param group_a: First SimPEGGroup with options set for inversion. + :param group_a_multiplier: Multiplier for the data misfit function for Group A. + :param group_b: Second SimPEGGroup with options set for inversion. + :param group_b_multiplier: Multiplier for the data misfit function for Group B. + :param group_c: Third SimPEGGroup with options set for inversion. + :param group_c_multiplier: Multiplier for the data misfit function for Group C. """ - _physical_property = "" - - def __init__(self, input_file=None, forward_only=False, **kwargs): - self._group_a = None - self._group_b = None - self._group_c = None - self._group_a_multiplier = None - self._group_b_multiplier = None - self._group_c_multiplier = None - - super().__init__(input_file=input_file, forward_only=forward_only, **kwargs) + physical_property: str = "" + data_object: None = None + group_a: SimPEGGroup + group_a_multiplier: float = 1.0 + group_b: SimPEGGroup + group_b_multiplier: float = 1.0 + group_c: SimPEGGroup | None = None + group_c_multiplier: float | None = None @property - def groups(self): + def groups(self) -> list[SimPEGGroup]: """List all active groups.""" return [k for k in [self.group_a, self.group_b, self.group_c] if k is not None] - - @property - def group_a(self): - """First SimPEGGroup inversion.""" - return self._group_a - - @group_a.setter - def group_a(self, val: SimPEGGroup): - self.setter_validator("group_a", val, fun=self._uuid_promoter) - - @property - def group_a_multiplier(self): - """Multiplier for the first SimPEGGroup inversion.""" - return self._group_a_multiplier - - @group_a_multiplier.setter - def group_a_multiplier(self, value): - self.setter_validator("group_a_multiplier", value) - - @property - def group_b(self): - """Second SimPEGGroup inversion.""" - return self._group_b - - @group_b.setter - def group_b(self, val: SimPEGGroup): - self.setter_validator("group_b", val, fun=self._uuid_promoter) - - @property - def group_b_multiplier(self): - """Multiplier for the second SimPEGGroup inversion.""" - return self._group_b_multiplier - - @group_b_multiplier.setter - def group_b_multiplier(self, value): - self.setter_validator("group_b_multiplier", value) - - @property - def group_c(self): - """Third SimPEGGroup inversion.""" - return self._group_c - - @group_c.setter - def group_c(self, val: SimPEGGroup): - self.setter_validator("group_c", val, fun=self._uuid_promoter) - - @property - def group_c_multiplier(self): - """Multiplier for the third SimPEGGroup inversion.""" - return self._group_c_multiplier - - @group_c_multiplier.setter - def group_c_multiplier(self, value): - self.setter_validator("group_c_multiplier", value) - - @property - def forward_only(self): - return self._forward_only - - @forward_only.setter - def forward_only(self, val: bool): - if val: - raise ValueError("Joint inversion does not support forward only.") - - self.setter_validator("forward_only", val) diff --git a/simpeg_drivers/natural_sources/magnetotellurics/driver.py b/simpeg_drivers/natural_sources/magnetotellurics/driver.py index 9a8e56ac..609c3b56 100644 --- a/simpeg_drivers/natural_sources/magnetotellurics/driver.py +++ b/simpeg_drivers/natural_sources/magnetotellurics/driver.py @@ -20,11 +20,11 @@ class MTForwardDriver(InversionDriver): """Magnetotellurics forward driver.""" _params_class = MTForwardOptions - _validations = {} + _validations = None class MTInversionDriver(InversionDriver): """Magnetotellurics inversion driver.""" _params_class = MTInversionOptions - _validations = {} + _validations = None diff --git a/simpeg_drivers/natural_sources/tipper/driver.py b/simpeg_drivers/natural_sources/tipper/driver.py index 6b2a836c..a79774fd 100644 --- a/simpeg_drivers/natural_sources/tipper/driver.py +++ b/simpeg_drivers/natural_sources/tipper/driver.py @@ -20,11 +20,11 @@ class TipperForwardDriver(InversionDriver): """Tipper forward driver.""" _params_class = TipperForwardOptions - _validations = {} + _validations = None class TipperInversionDriver(InversionDriver): """Tipper inversion driver.""" _params_class = TipperInversionOptions - _validations = {} + _validations = None diff --git a/simpeg_drivers/params.py b/simpeg_drivers/params.py index f43a69d1..0bda5c67 100644 --- a/simpeg_drivers/params.py +++ b/simpeg_drivers/params.py @@ -23,7 +23,7 @@ from geoapps_utils.driver.params import BaseParams from geoh5py.data import BooleanData, FloatData, NumericData from geoh5py.groups import PropertyGroup, SimPEGGroup, UIJsonGroup -from geoh5py.objects import Octree, Points +from geoh5py.objects import DrapeModel, Octree, Points from geoh5py.shared.utils import fetch_active_workspace from geoh5py.ui_json import InputFile from pydantic import BaseModel, ConfigDict, field_validator, model_validator @@ -99,7 +99,7 @@ class CoreOptions(BaseData): physical_property: str data_object: Points z_from_topo: bool = False - mesh: Octree | None + mesh: Octree | DrapeModel | None starting_model: float | FloatData active_cells: ActiveCellsOptions tile_spatial: int = 1 @@ -322,12 +322,12 @@ class BaseInversionOptions(CoreOptions): alpha_s: float | FloatData | None = 1.0 length_scale_x: float | FloatData = 1.0 - length_scale_y: float | FloatData = 1.0 + length_scale_y: float | FloatData | None = 1.0 length_scale_z: float | FloatData = 1.0 s_norm: float | FloatData | None = 0.0 x_norm: float | FloatData = 2.0 - y_norm: float | FloatData = 2.0 + y_norm: float | FloatData | None = 2.0 z_norm: float | FloatData = 2.0 gradient_type: str = "total" max_irls_iterations: int = 25 diff --git a/simpeg_drivers/potential_fields/gravity/driver.py b/simpeg_drivers/potential_fields/gravity/driver.py index 2eb81e52..805a77f3 100644 --- a/simpeg_drivers/potential_fields/gravity/driver.py +++ b/simpeg_drivers/potential_fields/gravity/driver.py @@ -22,11 +22,11 @@ class GravityForwardDriver(InversionDriver): """Gravity forward driver.""" _params_class = GravityForwardOptions - _validation = {} + _validations = None class GravityInversionDriver(InversionDriver): """Gravity inversion driver.""" _params_class = GravityInversionOptions - _validations = {} + _validations = None diff --git a/simpeg_drivers/potential_fields/magnetic_scalar/driver.py b/simpeg_drivers/potential_fields/magnetic_scalar/driver.py index 36daa6c8..c546212c 100644 --- a/simpeg_drivers/potential_fields/magnetic_scalar/driver.py +++ b/simpeg_drivers/potential_fields/magnetic_scalar/driver.py @@ -22,11 +22,11 @@ class MagneticForwardDriver(InversionDriver): """Magnetic forward driver.""" _params_class = MagneticForwardOptions - _validations = {} + _validations = None class MagneticInversionDriver(InversionDriver): """Magnetic inversion driver.""" _params_class = MagneticInversionOptions - _validations = {} + _validations = None diff --git a/simpeg_drivers/potential_fields/magnetic_vector/driver.py b/simpeg_drivers/potential_fields/magnetic_vector/driver.py index 57d1a5ef..2db7ba5c 100644 --- a/simpeg_drivers/potential_fields/magnetic_vector/driver.py +++ b/simpeg_drivers/potential_fields/magnetic_vector/driver.py @@ -22,14 +22,14 @@ class MVIForwardDriver(InversionDriver): """Magnetic Vector forward driver.""" _params_class = MVIForwardOptions - _validations = {} + _validations = None class MVIInversionDriver(InversionDriver): """Magnetic Vector inversion driver.""" _params_class = MVIInversionOptions - _validations = {} + _validations = None @property def mapping(self) -> list[maps.Projection] | None: diff --git a/simpeg_drivers/utils/write_default_uijson.py b/simpeg_drivers/utils/write_default_uijson.py deleted file mode 100644 index b442fb83..00000000 --- a/simpeg_drivers/utils/write_default_uijson.py +++ /dev/null @@ -1,99 +0,0 @@ -# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' -# Copyright (c) 2025 Mira Geoscience Ltd. ' -# ' -# This file is part of simpeg-drivers package. ' -# ' -# simpeg-drivers is distributed under the terms and conditions of the MIT License ' -# (see LICENSE file at the root of this source code package). ' -# ' -# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' - - -from __future__ import annotations - -import argparse -from pathlib import Path - -from simpeg_drivers.joint.joint_cross_gradient import JointCrossGradientParams -from simpeg_drivers.joint.joint_surveys import JointSurveysParams - - -active_data_channels = [ - "z_real_channel", - "z_imag_channel", - "zxx_real_channel", - "zxx_imag_channel", - "zxy_real_channel", - "zxy_imag_channel", - "zyx_real_channel", - "zyx_imag_channel", - "zyy_real_channel", - "zyy_imag_channel", - "txz_real_channel", - "txz_imag_channel", - "tyz_real_channel", - "tyz_imag_channel", - "gz_channel", - "tmi_channel", - "z_channel", -] - - -def write_default_uijson(path: str | Path): - filedict = { - "joint_surveys_inversion.ui.json": JointSurveysParams( - forward_only=False, validate=False - ), - "joint_cross_gradient_inversion.ui.json": JointCrossGradientParams( - forward_only=False, validate=False - ), - } - - for filename, params in filedict.items(): - validation_options = { - "update_enabled": (True if params.geoh5 is not None else False) - } - params.input_file.validation_options = validation_options - if hasattr(params, "forward_only"): - if params.forward_only: - for form in params.input_file.ui_json.values(): - if isinstance(form, dict): - group = form.get("group", None) - if group == "Data": - form["group"] = "Survey" - for param in [ - "starting_model", - "starting_inclination", - "starting_declination", - ]: - if param in params.input_file.ui_json: - form = params.input_file.ui_json[param] - - # Exception for forward sigma models - if "model_type" in params.input_file.ui_json: - form["label"] = "Value(s)" - else: - form["label"] = ( - form["label"].replace("Initial ", "").capitalize() - ) - elif params.data_object is None: - for channel in active_data_channels: - form = params.input_file.ui_json.get(channel, None) - if form: - form["enabled"] = True - - ifile = params.input_file - ifile.validate = False - ifile.ui_json["topography_object"]["enabled"] = True - ifile.write_ui_json(name=filename, path=path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Write defaulted ui.json files.") - parser.add_argument( - "path", - type=Path, - help="Path to folder where default ui.json files will be written.", - ) - args = parser.parse_args() - write_default_uijson(args.path) diff --git a/tests/run_tests/driver_joint_cross_gradient_test.py b/tests/run_tests/driver_joint_cross_gradient_test.py index 0c2d341c..6a99e3ad 100644 --- a/tests/run_tests/driver_joint_cross_gradient_test.py +++ b/tests/run_tests/driver_joint_cross_gradient_test.py @@ -23,7 +23,7 @@ DC3DForwardDriver, DC3DInversionDriver, ) -from simpeg_drivers.joint.joint_cross_gradient import JointCrossGradientParams +from simpeg_drivers.joint.joint_cross_gradient import JointCrossGradientOptions from simpeg_drivers.joint.joint_cross_gradient.driver import JointCrossGradientDriver from simpeg_drivers.params import ActiveCellsOptions from simpeg_drivers.potential_fields import ( @@ -68,11 +68,9 @@ def test_joint_cross_gradient_fwr_run( ) active_cells = ActiveCellsOptions(topography_object=topography) params = GravityForwardOptions( - forward_only=True, geoh5=geoh5, mesh=model.parent, active_cells=active_cells, - resolution=0.0, z_from_topo=False, data_object=survey, starting_model=model, @@ -98,12 +96,10 @@ def test_joint_cross_gradient_fwr_run( inducing_field_strength=inducing_field[0], inducing_field_inclination=inducing_field[1], inducing_field_declination=inducing_field[2], - resolution=0.0, z_from_topo=False, data_object=survey, starting_model=model, ) - # params.workpath = tmp_path fwr_driver_b = MVIForwardDriver(params) _, _, model, survey, _ = setup_inversion_workspace( @@ -239,14 +235,15 @@ def test_joint_cross_gradient_inv_run( drivers.append(MVIInversionDriver(params)) # Run the inverse - joint_params = JointCrossGradientParams( + joint_params = JointCrossGradientOptions( geoh5=geoh5, - topography_object=topography.uid, + active_cells=ActiveCellsOptions(topography_object=topography), group_a=drivers[0].params.out_group, group_a_multiplier=1.0, group_b=drivers[1].params.out_group, group_b_multiplier=1.0, group_c=drivers[2].params.out_group, + group_c_multiplier=1.0, max_global_iterations=max_iterations, initial_beta_ratio=1e1, cross_gradient_weight_a_b=1e0, diff --git a/tests/run_tests/driver_joint_surveys_test.py b/tests/run_tests/driver_joint_surveys_test.py index ae95c4cb..ba26aa11 100644 --- a/tests/run_tests/driver_joint_surveys_test.py +++ b/tests/run_tests/driver_joint_surveys_test.py @@ -14,7 +14,7 @@ from geoh5py.objects import Octree from geoh5py.workspace import Workspace -from simpeg_drivers.joint.joint_surveys import JointSurveysParams +from simpeg_drivers.joint.joint_surveys import JointSurveysOptions from simpeg_drivers.joint.joint_surveys.driver import JointSurveyDriver from simpeg_drivers.params import ActiveCellsOptions from simpeg_drivers.potential_fields import ( @@ -48,11 +48,9 @@ def test_joint_surveys_fwr_run( ) active_cells = ActiveCellsOptions(topography_object=topography) params = GravityForwardOptions( - forward_only=True, geoh5=geoh5, mesh=model.parent, active_cells=active_cells, - resolution=0.0, z_from_topo=False, data_object=survey, starting_model=model, @@ -74,11 +72,9 @@ def test_joint_surveys_fwr_run( ) active_cells = ActiveCellsOptions(topography_object=topography) params = GravityForwardOptions( - forward_only=True, geoh5=geoh5, mesh=model.parent, active_cells=active_cells, - resolution=0.0, z_from_topo=False, data_object=survey, starting_model=model, @@ -144,9 +140,9 @@ def test_joint_surveys_inv_run( active_model = drivers[0].params.mesh.get_entity("active_cells")[0] # Run the inverse - joint_params = JointSurveysParams( + joint_params = JointSurveysOptions( geoh5=geoh5, - activate_model=active_model, + active_cells=ActiveCellsOptions(active_model=active_model), mesh=drivers[0].params.mesh, group_a=drivers[0].params.out_group, group_b=drivers[1].params.out_group,