Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 30 additions & 21 deletions simpeg_drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,43 +46,43 @@ def assets_path() -> Path:
"direct current 3d": (
"simpeg_drivers.electricals.direct_current.three_dimensions.driver",
{
"forward": "DirectCurrent3DForwardDriver",
"inversion": "DirectCurrent3DInversionDriver",
"forward": "DC3DForwardDriver",
"inversion": "DC3DInversionDriver",
},
),
"direct current 2d": (
"simpeg_drivers.electricals.direct_current.two_dimensions.driver",
{
"forward": "DirectCurrent2DForwardDriver",
"inversion": "DirectCurrent2DInversionDriver",
"forward": "DC2DForwardDriver",
"inversion": "DC2DInversionDriver",
},
),
"direct current pseudo 3d": (
"simpeg_drivers.electricals.direct_current.pseudo_three_dimensions.driver",
{
"forward": "DirectCurrentPseudo3DForwardDriver",
"inversion": "DirectCurrentPseudo3DInversionDriver",
"forward": "DCBatch2DForwardDriver",
"inversion": "DCBatch2DInversionDriver",
},
),
"induced polarization 3d": (
"simpeg_drivers.electricals.induced_polarization.three_dimensions.driver",
{
"forward": "InducedPolarization3DForwardDriver",
"inversion": "InducedPolarization3DInversionDriver",
"forward": "IP3DForwardDriver",
"inversion": "IP3DInversionDriver",
},
),
"induced polarization 2d": (
"simpeg_drivers.electricals.induced_polarization.two_dimensions.driver",
{
"forward": "InducedPolarization2DForwardDriver",
"inversion": "InducedPolarization2DInversionDriver",
"forward": "IP2DForwardDriver",
"inversion": "IP2DInversionDriver",
},
),
"induced polarization pseudo 3d": (
"simpeg_drivers.electricals.induced_polarization.pseudo_three_dimensions.driver",
{
"forward": "InducedPolarizationPseudo3DForwardDriver",
"inversion": "InducedPolarizationPseudo3DInversionDriver",
"forward": "IPBatch2DForwardDriver",
"inversion": "IPBatch2DInversionDriver",
},
),
"joint surveys": (
Expand All @@ -91,40 +91,49 @@ def assets_path() -> Path:
),
"fem": (
"simpeg_drivers.electromagnetics.frequency_domain.driver",
{"inversion": "FrequencyDomainElectromagneticsDriver"},
{
"forward": "FrequenceyDomainElectromagneticsForwardDriver",
"inversion": "FDEMInversionDriver",
},
),
"joint cross gradient": (
"simpeg_drivers.joint.joint_cross_gradient.driver",
{"inversion": "JointCrossGradientDriver"},
),
"tdem": (
"simpeg_drivers.electromagnetics.time_domain.driver",
{"inversion": "TimeDomainElectromagneticsDriver"},
{
"forward": "TDEMForwardDriver",
"inversion": "TDEMInversionDriver",
},
),
"magnetotellurics": (
"simpeg_drivers.natural_sources.magnetotellurics.driver",
{"inversion": "MagnetotelluricsDriver"},
{
"forward": "MTForwardDriver",
"inversion": "MTInversionDriver",
},
),
"tipper": (
"simpeg_drivers.natural_sources.tipper.driver",
{"inversion": "TipperDriver"},
{"forward": "TipperForwardDriver", "inversion": "TipperInversionDriver"},
),
"gravity": (
"simpeg_drivers.potential_fields.gravity.driver",
{"inversion": "GravityInversionDriver", "forward": "GravityForwardDriver"},
{"forward": "GravityForwardDriver", "inversion": "GravityInversionDriver"},
),
"magnetic scalar": (
"simpeg_drivers.potential_fields.magnetic_scalar.driver",
{
"forward": "MagneticScalarForwardDriver",
"inversion": "MagneticScalarInversionDriver",
"forward": "MagneticForwardDriver",
"inversion": "MagneticInversionDriver",
},
),
"magnetic vector": (
"simpeg_drivers.potential_fields.magnetic_vector.driver",
{
"forward": "MagneticScalarForwardDriver",
"inversion": "MagneticVectorInversionDriver",
"forward": "MagneticForwardDriver",
"inversion": "MVIInversionDriver",
},
),
}
19 changes: 10 additions & 9 deletions simpeg_drivers/components/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
if TYPE_CHECKING:
from geoh5py.workspace import Workspace

from simpeg_drivers.params import InversionBaseParams
from simpeg_drivers.params import InversionBaseOptions

from copy import deepcopy
from re import findall
Expand Down Expand Up @@ -78,18 +78,17 @@ class InversionData(InversionLocations):

"""

def __init__(self, workspace: Workspace, params: InversionBaseParams):
def __init__(self, workspace: Workspace, params: InversionBaseOptions):
"""
:param: workspace: :obj`geoh5py.workspace.Workspace` workspace object containing location based data.
:param: params: Params object containing location based data parameters.
:param: params: Options object containing location based data parameters.
"""
super().__init__(workspace, params)
self.locations: np.ndarray | None = None
self.mask: np.ndarray | None = None
self.indices: np.ndarray | None = None
self.vector: bool | None = None
self.n_blocks: int | None = None
self.components: list[str] | None = None
self.observed: dict[str, np.ndarray] = {}
self.predicted: dict[str, np.ndarray] = {}
self.uncertainties: dict[str, np.ndarray] = {}
Expand All @@ -106,8 +105,10 @@ def _initialize(self) -> None:
"""Extract data from the workspace using params data."""
self.vector = True if self.params.inversion_type == "magnetic vector" else False
self.n_blocks = 3 if self.params.inversion_type == "magnetic vector" else 1
self.components, self.observed, self.uncertainties = self.get_data()
self.has_tensor = InversionData.check_tensor(self.components)
self.components = self.params.active_components
self.observed = self.params.data
self.uncertainties = self.params.uncertainties
self.has_tensor = InversionData.check_tensor(self.params.components)
self.locations = super().get_locations(self.params.data_object)

if "2d" in self.params.inversion_type:
Expand Down Expand Up @@ -280,7 +281,7 @@ def normalize(
"""
d = deepcopy(data)
for chan in getattr(self.params.data_object, "channels", [None]):
for comp in self.components:
for comp in self.params.active_components:
if isinstance(d[comp], dict):
if d[comp][chan] is not None:
d[comp][chan] *= self.normalizations[chan][comp]
Expand All @@ -298,7 +299,7 @@ def get_normalizations(self):
normalizations = {}
for chan in getattr(self.params.data_object, "channels", [None]):
normalizations[chan] = {}
for comp in self.components:
for comp in self.params.active_components:
normalizations[chan][comp] = np.ones(self.mask.sum())
if comp in ["potential", "chargeability"]:
normalizations[chan][comp] = 1
Expand Down Expand Up @@ -488,7 +489,7 @@ def survey(self):
@property
def n_data(self):
n_data = 0
for comp in self.components:
for comp in self.params.active_components:
if isinstance(self.observed[comp], dict):
for channel in self.observed[comp]:
n_data += len(self.observed[comp][channel])
Expand Down
6 changes: 3 additions & 3 deletions simpeg_drivers/components/factories/misfit_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
from geoapps_utils.driver.params import BaseParams

from simpeg_drivers.components.data import InversionData
from simpeg_drivers.params import BaseOptions

import numpy as np
from geoh5py.objects import Octree
from scipy.sparse import csr_matrix
from simpeg import data, data_misfit, maps, meta, objective_function

from simpeg_drivers.components.factories.simpeg_factory import SimPEGFactory
Expand All @@ -30,9 +30,9 @@
class MisfitFactory(SimPEGFactory):
"""Build SimPEG global misfit function."""

def __init__(self, params: BaseParams, models=None):
def __init__(self, params: BaseParams | BaseOptions, models=None):
"""
:param params: Params object containing SimPEG object parameters.
:param params: Options object containing SimPEG object parameters.
"""
super().__init__(params)
self.simpeg_object = self.concrete_object()
Expand Down
6 changes: 4 additions & 2 deletions simpeg_drivers/components/factories/receiver_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
if TYPE_CHECKING:
from geoapps_utils.driver.params import BaseParams

from simpeg_drivers.params import BaseOptions

import numpy as np
from geoapps_utils.utils.transformations import rotate_xyz

Expand All @@ -29,9 +31,9 @@
class ReceiversFactory(SimPEGFactory):
"""Build SimPEG receivers objects based on factory type."""

def __init__(self, params: BaseParams):
def __init__(self, params: BaseParams | BaseOptions):
"""
:param params: Params object containing SimPEG object parameters.
:param params: Options object containing SimPEG object parameters.

"""
super().__init__(params)
Expand Down
4 changes: 3 additions & 1 deletion simpeg_drivers/components/factories/simpeg_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
if TYPE_CHECKING:
from geoapps_utils.driver.params import BaseParams

from simpeg_drivers.params import BaseOptions

# TODO Redesign simpeg factory to avoid pylint arguments-differ complaint


Expand Down Expand Up @@ -62,7 +64,7 @@ class SimPEGFactory(ABC):
"joint cross gradient",
]

def __init__(self, params: BaseParams):
def __init__(self, params: BaseParams | BaseOptions):
"""
:param params: Driver parameters object.
"""
Expand Down
6 changes: 4 additions & 2 deletions simpeg_drivers/components/factories/simulation_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
if TYPE_CHECKING:
from geoapps_utils.driver.params import BaseParams

from simpeg_drivers.params import BaseOptions

from pathlib import Path

import numpy as np
Expand All @@ -29,9 +31,9 @@


class SimulationFactory(SimPEGFactory):
def __init__(self, params: BaseParams):
def __init__(self, params: BaseParams | BaseOptions):
"""
:param params: Params object containing SimPEG object parameters.
:param params: Options object containing SimPEG object parameters.

"""
super().__init__(params)
Expand Down
6 changes: 4 additions & 2 deletions simpeg_drivers/components/factories/source_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
if TYPE_CHECKING:
from geoapps_utils.driver.params import BaseParams

from simpeg_drivers.params import BaseOptions

from copy import deepcopy

import numpy as np
Expand All @@ -29,9 +31,9 @@
class SourcesFactory(SimPEGFactory):
"""Build SimPEG sources objects based on factory type."""

def __init__(self, params: BaseParams):
def __init__(self, params: BaseParams | BaseOptions):
"""
:param params: Params object containing SimPEG object parameters.
:param params: Options object containing SimPEG object parameters.

"""
super().__init__(params)
Expand Down
6 changes: 4 additions & 2 deletions simpeg_drivers/components/factories/survey_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
if TYPE_CHECKING:
from geoapps_utils.driver.params import BaseParams

from simpeg_drivers.params import BaseOptions

import numpy as np
import simpeg.electromagnetics.time_domain as tdem
from geoh5py.objects.surveys.electromagnetics.ground_tem import (
Expand Down Expand Up @@ -69,9 +71,9 @@ class SurveyFactory(SimPEGFactory):

dummy = -999.0

def __init__(self, params: BaseParams):
def __init__(self, params: BaseParams | BaseOptions):
"""
:param params: Params object containing SimPEG object parameters.
:param params: Options object containing SimPEG object parameters.
"""
super().__init__(params)
self.simpeg_object = self.concrete_object()
Expand Down
20 changes: 16 additions & 4 deletions simpeg_drivers/components/locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
if TYPE_CHECKING:
from geoh5py.workspace import Workspace

from simpeg_drivers.params import InversionBaseParams
from simpeg_drivers.params import (
BaseForwardOptions,
BaseInversionOptions,
InversionBaseParams,
)

import numpy as np
from geoh5py.objects import ObjectBase, Points
Expand Down Expand Up @@ -47,13 +51,19 @@ class InversionLocations:

"""

def __init__(self, workspace: Workspace, params: InversionBaseParams):
def __init__(
self,
workspace: Workspace,
params: InversionBaseParams | BaseForwardOptions | BaseInversionOptions,
):
"""
:param workspace: Geoh5py workspace object containing location based data.
:param params: Params object containing location based data parameters.
:param params: Options object containing location based data parameters.
"""
self.workspace = workspace
self._params: InversionBaseParams = params
self._params: (
InversionBaseParams | BaseForwardOptions | BaseInversionOptions
) = params
self.mask: np.ndarray | None = None
self.locations: np.ndarray | None = None

Expand Down Expand Up @@ -121,6 +131,8 @@ def get_locations(self, entity: ObjectBase) -> np.ndarray:
return locations

def _filter(self, a, mask):
if a is None:
return None
for k, v in a.items():
if not isinstance(v, np.ndarray):
a.update({k: self._filter(v, mask)})
Expand Down
10 changes: 7 additions & 3 deletions simpeg_drivers/components/meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
from octree_creation_app.params import OctreeParams
from octree_creation_app.utils import octree_2_treemesh, treemesh_2_octree

from simpeg_drivers.params import InversionBaseParams
from simpeg_drivers.params import (
BaseForwardOptions,
BaseInversionOptions,
InversionBaseParams,
)
from simpeg_drivers.utils.meshes import auto_mesh_parameters
from simpeg_drivers.utils.utils import drape_2_tensor

Expand Down Expand Up @@ -77,11 +81,11 @@ class InversionMesh:
def __init__(
self,
workspace: Workspace,
params: InversionBaseParams,
params: InversionBaseParams | BaseForwardOptions | BaseInversionOptions,
) -> None:
"""
:param workspace: Workspace object containing mesh data.
:param params: Params object containing mesh parameters.
:param params: Options object containing mesh parameters.
"""
self.workspace = workspace
self.params = params
Expand Down
Loading
Loading