diff --git a/simpeg_drivers/plate_simulation/match/driver.py b/simpeg_drivers/plate_simulation/match/driver.py
index 19b47365..00033909 100644
--- a/simpeg_drivers/plate_simulation/match/driver.py
+++ b/simpeg_drivers/plate_simulation/match/driver.py
@@ -22,10 +22,11 @@
 from geoapps_utils.utils.logger import get_logger
 from geoapps_utils.utils.numerical import inverse_weighted_operator
 from geoapps_utils.utils.plotting import symlog
-from geoapps_utils.utils.transformations import cartesian_to_polar
+from geoapps_utils.utils.transformations import cartesian_to_polar, rotate_xyz
 from geoh5py import Workspace
 from geoh5py.groups import PropertyGroup, SimPEGGroup
-from geoh5py.objects import AirborneTEMReceivers, Surface
+from geoh5py.objects import AirborneTEMReceivers, MaxwellPlate, Surface
+from geoh5py.objects.maxwell_plate import PlateGeometry
 from geoh5py.ui_json import InputFile
 from scipy import signal
 from scipy.sparse import csr_matrix
@@ -34,7 +35,7 @@
 from simpeg_drivers.driver import BaseDriver
 from simpeg_drivers.plate_simulation.match.options import PlateMatchOptions
-from simpeg_drivers.plate_simulation.options import PlateSimulationOptions
+from simpeg_drivers.plate_simulation.options import ModelOptions, PlateSimulationOptions
 
 logger = get_logger(name=__name__, level_name=False, propagate=False, add_name=False)
 
@@ -53,6 +54,12 @@ def __init__(
         self._drape_heights = self._get_drape_heights()
         self._template = self.get_template()
         self._time_mask, self._time_projection = self.time_mask_and_projection()
+        self._spatial_tree = cKDTree(self.params.survey.vertices[:, :2])
+
+    @property
+    def spatial_tree(self):
+        """KDTree for spatial locations of the survey."""
+        return self._spatial_tree
 
     def get_template(self) -> AirborneTEMReceivers:
         """
@@ -98,6 +105,27 @@ def time_mask_and_projection(self) -> tuple[np.ndarray, csr_matrix]:
         )
         return time_mask, time_projection
 
+    def spatial_mask_and_projection(
+        self, location: np.ndarray, strike_angle: float
+    ) -> tuple[np.ndarray, csr_matrix]:
+        """
+        Create a spatial mask and interpolation matrix from simulation to observation locations.
+
+        :param location: Query location (x, y, z).
+        :param strike_angle: Strike angle with respect to the plate orientation.
+
+        :return: Indices of the survey vertices within range and the spatial
+            interpolation matrix.
+        """
+        nearest = self.spatial_tree.query(location[:2], k=1)[1]
+        indices = self.params.survey.get_segment_indices(
+            nearest, self.params.max_distance
+        )
+        spatial_projection = self.spatial_interpolation(
+            indices,
+            np.abs(strike_angle),
+        )
+        return indices, spatial_projection
+
     @classmethod
     def start(cls, filepath: str | Path, mode="r+", **_) -> Self:
         """Start the parameter matching from a ui.json file."""
@@ -123,6 +151,51 @@ def start(cls, filepath: str | Path, mode="r+", **_) -> Self:
 
         return driver
 
+    def _create_plate_from_parameters(
+        self, index_center: int, model_options: ModelOptions, strike_angle: float
+    ) -> MaxwellPlate:
+        """
+        Create a MaxwellPlate object at the query location from the survey
+        geometry and model options.
+
+        :param index_center: Index of the center point in the survey vertices.
+        :param model_options: Model options containing plate geometry parameters.
+        :param strike_angle: Strike angle to correct the plate orientation.
+
+        :return: MaxwellPlate object created from the parameters.
+ """ + center = self.params.survey.vertices[index_center] + center[2] = ( + self._drape_heights[index_center] - model_options.overburden_model.thickness + ) + indices = self.params.survey.get_segment_indices( + index_center, self.params.max_distance + ) + segment = self.params.survey.vertices[indices] + delta = np.median(np.diff(segment, axis=0), axis=0) + azimuth = 90 - np.rad2deg(np.arctan2(delta[1], delta[0])) + + plate_geometry = PlateGeometry.model_validate( + { + "position": { + "x": center[0], + "y": center[1], + "z": center[2], + }, + "width": model_options.plate_model.dip_length, + "thickness": model_options.plate_model.width, + "length": model_options.plate_model.strike_length, + "dip": model_options.plate_model.dip, + "dip_direction": (azimuth + strike_angle) % 360, + } + ) + plate = MaxwellPlate.create( + self.params.geoh5, geometry=plate_geometry, parent=self.params.out_group + ) + plate.metadata = model_options.model_dump() + + return plate + def _get_drape_heights(self) -> np.ndarray: """Set drape heights based on topography object and optional topography data.""" @@ -159,6 +232,12 @@ def spatial_interpolation( origin=np.r_[self.params.survey.vertices[indices, :2].mean(axis=0), 0], ) local_polar[local_polar[:, 1] >= 180, 0] *= -1 # Wrap azimuths + + # Flip the line segment if the azimuth angle suggests the opposite direction + start_line = len(indices) // 2 + if np.median(local_polar[:start_line, 1]) < 180: + local_polar = local_polar[::-1, :] + local_polar[:, 1] = ( 0.0 if strike_angle is None else strike_angle ) # Align azimuths to zero @@ -186,75 +265,62 @@ def run(self): "Running %s . . .", self.params.title, ) - observed = normalized_data(self.params.data)[self._time_mask, :] - tree = cKDTree(self.params.survey.vertices[:, :2]) + observed = get_data_array(self.params.data)[self._time_mask, :] + strike_angle = ( + np.zeros(self.params.queries.n_vertices) + if self.params.strike_angles is None + else self.params.strike_angles.values + ) + names = [] results = [] for ii, query in enumerate(self.params.queries.vertices): # Find the nearest survey location to the query point - nearest = tree.query(query[:2], k=1)[1] - indices = self.params.survey.get_segment_indices( - nearest, self.params.max_distance + indices, spatial_projection = self.spatial_mask_and_projection( + query, strike_angle[ii] ) - spatial_projection = self.spatial_interpolation( - indices, - 0 - if self.params.strike_angles is None - else self.params.strike_angles.values[ii], - ) - file_split = np.array_split( - self.params.simulation_files, np.maximum(1, len(self.workers) * 10) - ) - - tasks = [] - for file_batch in file_split: - args = ( - file_batch, - spatial_projection, - self._time_projection, - observed[:, indices], - ) - - tasks.append( - self.client.submit(batch_files_score, *args) - if self.client - else batch_files_score(*args) - ) - - # Display progress bar - if isinstance(tasks[0], Future): - progress(tasks) - tasks = self.client.gather(tasks) - - scores = np.hstack(tasks) - ranked = np.argsort(scores)[::-1] - - # TODO: Return top N matches - # for rank in ranked[-1:][::-1]: + data, flip = prepare_data(observed[:, indices]) + print(data.min(), data.max()) + # Loop through files and compute scores and find the best match + scores, centers = self.run_scores(spatial_projection, data) + print(scores) + ranked = np.argsort(scores) + best = ranked[0] logger.info( "File: %s \nScore: %.4f", - self.params.simulation_files[ranked[0]].name, - scores[ranked[0]], + self.params.simulation_files[best].name, + 
+                scores[best],
             )
-            with Workspace(self.params.simulation_files[ranked[0]], mode="r") as ws:
+            with Workspace(self.params.simulation_files[best], mode="r") as ws:
                 survey = fetch_survey(ws)
                 ui_json = survey.parent.parent.options
                 ui_json["geoh5"] = ws
                 ifile = InputFile(ui_json=ui_json)
                 options = PlateSimulationOptions.build(ifile)
-                plate = survey.parent.parent.get_entity("plate")[0].copy(
-                    parent=self.params.out_group
-                )
-
-                # Set position of plate to query location
-                center = self.params.survey.vertices[nearest]
-                center[2] = self._drape_heights[nearest]
-                plate.vertices = plate.vertices + center
-                plate.metadata = options.model.model_dump()
-
-            results.append(self.params.simulation_files[ranked[0]].name)
+                dir_correction = strike_angle[ii] + 180 if flip else strike_angle[ii]
+
+                plate = self._create_plate_from_parameters(
+                    int(indices[int(centers[best])]), options.model, dir_correction
+                )
+                plate.name = f"Query [{ii}]"
+
+            names.append(self.params.simulation_files[best].name)
+            results.append(scores[best])
+
+        out = self.params.queries.copy(parent=self.params.out_group)
+        out.add_data(
+            {
+                "file": {
+                    "values": np.array(names, dtype="U"),
+                    "primitive_type": "TEXT",
+                },
+                "score": {
+                    "values": np.array(results),
+                },
+            }
+        )
 
-        return results
+        return out
 
     @classmethod
     def start_dask_run(
@@ -265,7 +331,6 @@ def start_dask_run(
         save_report: bool = True,
     ):
         """Overload configurations of BaseDriver Dask config settings."""
-        # Force distributed on 1D problems
 
         if n_workers is None:
             cpu_count = multiprocessing.cpu_count()
@@ -280,21 +345,91 @@ def start_dask_run(
             json_path, n_workers=n_workers, n_threads=n_threads, save_report=save_report
         )
 
+    def run_scores(self, spatial_projection, data) -> tuple[np.ndarray, np.ndarray]:
+        """
+        Run the scoring function for all simulation files in parallel using Dask.
+
+        :param spatial_projection: Spatial interpolation matrix for the current query.
+        :param data: Prepared observed data for the current query.
+
+        :return: Tuple of scores and corresponding center indices for each simulation file.
+        """
+        file_split = np.array_split(
+            self.params.simulation_files, np.maximum(1, len(self.workers) * 10)
+        )
+        tasks = []
+        for file_batch in file_split:
+            args = (
+                file_batch,
+                spatial_projection,
+                self._time_projection,
+                data,
+            )
+
+            tasks.append(
+                self.client.submit(batch_files_score, *args)
+                if self.client
+                else batch_files_score(*args)
+            )
 
-def normalized_data(property_group: PropertyGroup, threshold=5) -> np.ndarray:
+        # Display progress bar
+        if isinstance(tasks[0], Future):
+            progress(tasks)
+            tasks = self.client.gather(tasks)
+
+        scores, centers = np.vstack(tasks).T
+
+        return scores, centers
+
+
+def prepare_data(data: np.ndarray) -> tuple[np.ndarray, bool]:
     """
-    Return data from a property group with symlog scaling and zero mean.
+    Prepare data for scoring by normalizing and orienting the profiles down-dip.
+
+    :param data: Array of data channels per location.
+
+    :return: Tuple of the prepared data array and a flag indicating whether the
+        locations were reversed.
+    """
+    data_array = normalized_data(data)
+
+    # Guess the down-dip direction from the integral of each half of the profile
+    centered = data_array - np.min(data_array, axis=1)[:, None]
+    mid = centered.shape[1] // 2
+    left = np.sum(centered[:, :mid], axis=1)
+    right = np.sum(centered[:, mid:], axis=1)
 
-    :param property_group: Property group containing data channels.
+    # Mostly on the left suggests the peaks are migrating up-dip and should be reversed
+    if np.mean(left > right) > 0.5:
+        return data_array[:, ::-1], True
+
+    return data_array, False
+
+
+def get_data_array(property_group: PropertyGroup) -> np.ndarray:
+    """
+    Extract a data array from a property group.
+
+    :param property_group: Property group containing data values.
+
+    :return: Data array with shape (n_times, n_locations).
+    """
+    table = property_group.table()
+    return np.vstack(table.tolist()).T
+
+
+def normalized_data(data: np.ndarray, threshold=5) -> np.ndarray:
+    """
+    Return data with symlog scaling, zero-median and unit-maximum normalization.
+
+    :param data: Array of data channels per location.
     :param threshold: Percentile threshold for symlog normalization.
 
     :return: Normalized data array.
     """
-    table = property_group.table()
-    data_array = np.vstack([table[name] for name in table.dtype.names])
-    thresh = np.percentile(np.abs(data_array), threshold)
-    log_data = symlog(data_array, thresh)
-    return log_data - np.mean(log_data, axis=1)[:, None]
+    thresh = np.percentile(np.abs(data), threshold)
+    log_data = symlog(data, thresh)
+    centered_log = log_data - np.median(log_data)
+    return centered_log / np.abs(centered_log).max()
 
 
 def fetch_survey(workspace: Workspace) -> AirborneTEMReceivers | None:
@@ -310,10 +445,13 @@ def fetch_survey(workspace: Workspace) -> AirborneTEMReceivers | None:
 
 def batch_files_score(
     files: Path | list[Path], spatial_projection, time_projection, observed
-) -> list[float]:
+) -> list[tuple[float, int]]:
     """
    Process a batch of simulation files and compute scores against observed data.
 
+    Attempt to find the best collocation of the simulated and observed data by
+    finding the median index of the maximum correlation across channels.
+
     :param files: Simulation file or list of simulation files to process.
     :param spatial_projection: Spatial interpolation matrix.
     :param time_projection: Time interpolation matrix.
@@ -334,24 +472,29 @@ def batch_files_score(
             logger.warning("No survey found in %s, skipping.", sim_file)
             continue
 
-        simulated = normalized_data(survey.get_entity("Iteration_0_z")[0])
+        simulated = get_data_array(survey.get_entity("Iteration_0_z")[0])
         pred = time_projection @ (spatial_projection @ simulated.T).T
+        pred = normalized_data(pred)
 
         score = 0.0
-
-        # Metric: normalized cross-correlation
+        indices = []
+
+        # Align channels by cross-correlation and score by the misfit norm
         for obs, pre in zip(observed, pred, strict=True):
-            # Full cross-correlation
-            corr = signal.correlate(obs, pre, mode="full")
+            # Scale pre to the amplitude of obs
+            vals = pre / np.abs(pre).max() * np.abs(obs).max()
+
+            # Cross-correlation with same-length output
+            corr = signal.correlate(obs, vals, mode="same")
 
             # Normalize by energy to get correlation coefficient in [-1, 1]
-            denom = np.linalg.norm(pre) * np.linalg.norm(obs)
+            denom = np.linalg.norm(vals) * np.linalg.norm(obs)
             if denom == 0:
                 corr_norm = np.zeros_like(corr)
             else:
                 corr_norm = corr / denom
 
-            score += np.max(corr_norm)
+            score += np.linalg.norm(obs - vals)
+            indices.append(np.argmax(corr_norm))
 
-        scores.append(score)
+        scores.append((score, int(np.median(indices))))
 
     return scores
 
diff --git a/tests/plate_simulation/runtest/match_test.py b/tests/plate_simulation/runtest/match_test.py
index 731d1140..89b5ee3e 100644
--- a/tests/plate_simulation/runtest/match_test.py
+++ b/tests/plate_simulation/runtest/match_test.py
@@ -13,10 +13,12 @@
 import numpy as np
 import pytest
 from geoapps_utils.utils.importing import GeoAppsError
+from geoapps_utils.utils.transformations import rotate_xyz
 from geoh5py import Workspace
 from geoh5py.groups import PropertyGroup, SimPEGGroup
 from geoh5py.objects import Points
 from geoh5py.ui_json import InputFile
+from scipy import signal
 
 from simpeg_drivers import assets_path
 from simpeg_drivers.electromagnetics.time_domain.driver import TDEMForwardDriver
@@ -40,8 +42,14 @@ def generate_example(geoh5: Workspace, n_grid_points: int, refinement: tuple[int]):
     opts = SyntheticsComponentsOptions(
         method="airborne tdem",
-        survey=SurveyOptions(n_stations=n_grid_points, n_lines=1, drape=10.0),
-        mesh=MeshOptions(refinement=refinement, padding_distance=400.0),
+        survey=SurveyOptions(
+            n_stations=n_grid_points,
+            n_lines=1,
+            width=1000,
+            drape=40.0,
+            topography=lambda x, y: np.zeros(x.shape),
+        ),
+        mesh=MeshOptions(refinement=refinement),
         model=ModelOptions(background=0.001),
     )
     components = SyntheticsComponents(geoh5, options=opts)
@@ -110,7 +118,7 @@ def test_matching_driver(tmp_path: Path):
 
     # Generate simulation files
     with get_workspace(tmp_path / f"{__name__}.geoh5") as geoh5:
-        components = generate_example(geoh5, n_grid_points=5, refinement=(2,))
+        components = generate_example(geoh5, n_grid_points=32, refinement=(2,))
 
         params = TDEMForwardOptions.build(
             geoh5=geoh5,
@@ -132,6 +140,8 @@ def test_matching_driver(tmp_path: Path):
         ifile.data["simulation"] = fwr_driver.out_group
 
         plate_options = PlateSimulationOptions.build(ifile.data)
+        plate_options.model.overburden_model.thickness = 40.0
+        plate_options.model.plate_model.dip_length = 300.0
 
         driver = PlateSimulationDriver(plate_options)
         driver.run()
@@ -148,24 +158,58 @@ def test_matching_driver(tmp_path: Path):
         with Workspace(new_file) as sim_geoh5:
             survey = fetch_survey(sim_geoh5)
             prop_group = survey.get_entity("Iteration_0_z")[0]
-            scale = np.cos(np.linspace(-np.pi / ii, np.pi / ii, survey.n_vertices))
-            for uid in prop_group.properties:
-                child = survey.get_entity(uid)[0]
-                child.values = child.values * scale
+            # Alter the signal to simulate different plate models
+            scale = signal.windows.gaussian(survey.n_vertices, 2**ii)
 
-    # Random choice of file
+            for jj, uid in enumerate(prop_group.properties):
+                child = survey.get_entity(uid)[0]
+                child.values = child.values * np.roll(scale, jj)
+
+            # Downsample stations
+            mask = np.ones_like(child.values, dtype=bool)
+            mask[1::2] = False
+            survey.remove_vertices(mask)
+            indices = np.arange(survey.n_vertices)
+            survey.cells = np.c_[indices[:-1], indices[1:]]
+
+    # Run the matching driver
     with geoh5.open():
         survey = fetch_survey(geoh5)
+
+        # Rotate the survey to test matching
+        survey.vertices = rotate_xyz(survey.vertices, [0, 0, 0], 215.0)
+
+        # Flip the data to simulate up-dip measurements
+        prop_group = survey.get_entity("Iteration_0_z")[0]
+        for uid in prop_group.properties:
+            child = survey.get_entity(uid)[0]
+            child.values = child.values[::-1]
+
+        # Change the strike angle to simulate a different orientation
+        strikes = components.queries.add_data(
+            {
+                "strike": {
+                    "values": np.full(components.queries.n_vertices, -10.0),
+                }
+            }
+        )
+
         options = PlateMatchOptions(
             geoh5=geoh5,
             survey=survey,
-            data=survey.get_entity("Iteration_0_z")[0],
+            data=prop_group,
             queries=components.queries,
+            strike_angles=strikes,
             topography_object=components.topography,
             simulations=new_dir,
         )
 
         match_driver = PlateMatchDriver(options)
         results = match_driver.run()
 
-        assert results[0] == file.stem + f"_[{4}].geoh5"
+        assert isinstance(results, Points)
+
+        names = results.get_data("file")[0]
+        assert names.values[0] == file.stem + "_[4].geoh5"
+
+        assert geoh5.get_entity("Query [0]")[0].geometry.dip_direction == 45.0
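
Note on the collocation step in batch_files_score(): with mode="same", scipy.signal.correlate returns one lag per station, so the argmax of the energy-normalized correlation reads directly as the station index where the simulated anomaly best overlays the observed profile. A minimal, self-contained sketch on synthetic Gaussian profiles (not project data):

    import numpy as np
    from scipy import signal

    rng = np.random.default_rng(42)
    stations = np.arange(64)

    # Observed profile: anomaly centered at station 40, with a little noise
    obs = np.exp(-0.5 * ((stations - 40) / 4) ** 2) + 0.01 * rng.standard_normal(64)

    # Simulated profile of the same length, anomaly centered at station 32
    pre = np.exp(-0.5 * ((stations - 32) / 4) ** 2)

    # Scale the prediction to the observed amplitude, as in batch_files_score
    vals = pre / np.abs(pre).max() * np.abs(obs).max()

    # "same"-mode output has one entry per station, so the argmax is directly
    # the station index of best overlap
    corr = signal.correlate(obs, vals, mode="same")
    denom = np.linalg.norm(vals) * np.linalg.norm(obs)
    corr_norm = corr / denom if denom != 0 else np.zeros_like(corr)

    print(np.argmax(corr_norm))  # ~40: where the simulated anomaly collocates

The median of these per-channel argmax indices is what the driver uses as the plate-center index along the line segment.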
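Note on normalized_data(): the symlog transform compresses dB/dt amplitudes that span several decades and both signs before the median shift and unit-maximum scaling. The sketch below substitutes a common textbook definition for symlog; the actual geoapps_utils.utils.plotting.symlog may differ in detail:

    import numpy as np

    def symlog(values: np.ndarray, threshold: float) -> np.ndarray:
        # Hypothetical stand-in for geoapps_utils.utils.plotting.symlog:
        # sign-preserving, roughly linear below the threshold, log above it
        return np.sign(values) * np.log10(1.0 + np.abs(values) / threshold)

    def normalized_data(data: np.ndarray, threshold: float = 5) -> np.ndarray:
        # Tie the linear region of the symlog to the small-amplitude percentile
        thresh = np.percentile(np.abs(data), threshold)
        log_data = symlog(data, thresh)
        # Zero median and unit maximum so profiles from different files compare
        centered_log = log_data - np.median(log_data)
        return centered_log / np.abs(centered_log).max()

    # Decaying, sign-changing channels spanning several orders of magnitude
    times = np.linspace(0.0, 8.0, 16)
    stations = np.linspace(-1.0, 1.0, 32)
    data = 1e-3 * np.outer(np.exp(-times), np.exp(-(stations**2))) - 5e-6
    out = normalized_data(data)
    print(out.shape, out.min(), out.max())  # (16, 32), values scaled into [-1, 1]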
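Note on the orientation heuristic in prepare_data(): the driver assumes a down-dip ordering of the stations; if most channels carry more integrated amplitude in the left half of the line, the peaks are taken to be migrating up-dip and the station order is reversed. A standalone sketch of the same logic (guess_flip is a hypothetical helper, not part of the module):

    import numpy as np

    def guess_flip(data_array: np.ndarray) -> bool:
        # Shift each channel to be non-negative so the row sums act like integrals
        centered = data_array - data_array.min(axis=1)[:, None]
        mid = centered.shape[1] // 2
        left = centered[:, :mid].sum(axis=1)
        right = centered[:, mid:].sum(axis=1)
        # A majority of channels with more area on the left implies the profile
        # runs up-dip and should be reversed
        return np.mean(left > right) > 0.5

    # Ten channels with a peak drifting toward the left end of the line
    stations = np.arange(32)
    profiles = np.vstack(
        [np.exp(-0.5 * ((stations - (12 - t)) / 3) ** 2) for t in range(10)]
    )
    print(guess_flip(profiles))  # True: peaks sit left of center, so reverse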