Skip to content
Merged
10 changes: 0 additions & 10 deletions simpeg_drivers-assets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,3 @@
# (see LICENSE file at the root of this source code package). '
# '
# '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''

#
# This file is part of simpeg-drivers.
#
#
# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
#
# This file is part of simpeg_drivers package.
#
# All rights reserved.
10 changes: 8 additions & 2 deletions simpeg_drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,16 @@ def assets_path() -> Path:
),
"magnetic scalar": (
"simpeg_drivers.potential_fields.magnetic_scalar.driver",
{"inversion": "MagneticScalarDriver"},
{
"forward": "MagneticScalarForwardDriver",
"inversion": "MagneticScalarInversionDriver",
},
),
"magnetic vector": (
"simpeg_drivers.potential_fields.magnetic_vector.driver",
{"inversion": "MagneticVectorDriver"},
{
"forward": "MagneticScalarForwardDriver",
"inversion": "MagneticVectorInversionDriver",
},
),
}
1 change: 1 addition & 0 deletions simpeg_drivers/components/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import numpy as np
from discretize import TreeMesh
from geoh5py.shared.utils import fetch_active_workspace
from scipy.spatial import cKDTree
from simpeg import maps
from simpeg.electromagnetics.static.utils.static_utils import geometric_factor
Expand Down
12 changes: 5 additions & 7 deletions simpeg_drivers/components/factories/source_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,11 @@ def assemble_keyword_arguments( # pylint: disable=arguments-differ
_ = (receivers, frequency)
kwargs = {}
if self.factory_type in ["magnetic scalar", "magnetic vector"]:
kwargs = dict(
zip(
["amplitude", "inclination", "declination"],
self.params.inducing_field_aid(),
strict=False,
)
)
kwargs = {
"amplitude": self.params.inducing_field_strength,
"inclination": self.params.inducing_field_inclination,
"declination": self.params.inducing_field_declination,
}
if self.factory_type in ["magnetotellurics", "tipper"]:
background = deepcopy(self.params.background_conductivity)

Expand Down
10 changes: 5 additions & 5 deletions simpeg_drivers/components/locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,12 @@ def set_z_from_topo(self, locs: np.ndarray):
if locs is None:
return None

topo = self.get_locations(self.params.topography_object)
if self.params.topography is not None:
if isinstance(self.params.topography, Entity):
z = self.params.topography.values
topo = self.get_locations(self.params.active_cells.topography_object)
if self.params.active_cells.topography is not None:
if isinstance(self.params.active_cells.topography, Entity):
z = self.params.active_cells.topography.values
else:
z = np.ones_like(topo[:, 2]) * self.params.topography
z = np.ones_like(topo[:, 2]) * self.params.active_cells.topography

topo[:, 2] = z

Expand Down
8 changes: 6 additions & 2 deletions simpeg_drivers/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ def mesh_cannot_be_rotated(cls, value: Octree):
def out_group_if_none(cls, data) -> SimPEGGroup:
group = data.get("out_group", None)

if isinstance(group, SimPEGGroup):
return data

if isinstance(group, UIJsonGroup | type(None)):
name = cls.title if group is None else group.name
with fetch_active_workspace(data["geoh5"], mode="r+") as geoh5:
Expand All @@ -140,8 +143,9 @@ def out_group_if_none(cls, data) -> SimPEGGroup:
@model_validator(mode="after")
def update_out_group_options(self):
assert self.out_group is not None
self.out_group.options = self.serialize()
self.out_group.metadata = None
with fetch_active_workspace(self.geoh5):
self.out_group.options = self.serialize()
self.out_group.metadata = None
return self

@property
Expand Down
10 changes: 8 additions & 2 deletions simpeg_drivers/potential_fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,14 @@


from .gravity.params import GravityForwardParams, GravityInversionParams
from .magnetic_scalar.params import MagneticScalarParams
from .magnetic_vector.params import MagneticVectorParams
from .magnetic_scalar.params import (
MagneticScalarForwardParams,
MagneticScalarInversionParams,
)
from .magnetic_vector.params import (
MagneticVectorForwardParams,
MagneticVectorInversionParams,
)

# pylint: disable=unused-import
# flake8: noqa
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''


from .params import MagneticScalarParams
from .params import MagneticScalarForwardParams, MagneticScalarInversionParams

# pylint: disable=unused-import
# flake8: noqa
18 changes: 11 additions & 7 deletions simpeg_drivers/potential_fields/magnetic_scalar/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,18 @@
from __future__ import annotations

from simpeg_drivers.driver import InversionDriver
from simpeg_drivers.potential_fields.magnetic_scalar.constants import validations
from simpeg_drivers.potential_fields.magnetic_scalar.params import (
MagneticScalarForwardParams,
MagneticScalarInversionParams,
)

from .constants import validations
from .params import MagneticScalarParams


class MagneticScalarDriver(InversionDriver):
_params_class = MagneticScalarParams
class MagneticScalarForwardDriver(InversionDriver):
_params_class = MagneticScalarForwardParams
_validations = validations

def __init__(self, params: MagneticScalarParams):
super().__init__(params)

class MagneticScalarInversionDriver(InversionDriver):
_params_class = MagneticScalarInversionParams
_validations = validations
Loading