diff --git a/geoapps_utils/base.py b/geoapps_utils/base.py index 2cb2325f..bc6cf70a 100644 --- a/geoapps_utils/base.py +++ b/geoapps_utils/base.py @@ -168,6 +168,17 @@ def update_monitoring_directory( copy_children=copy_children, ) + @classmethod + def get_default_ui_json(cls) -> Path | None: + """ + Get the default ui.json file path for the application. + + :return: Path to default ui.json file. + """ + if issubclass(cls._params_class, Options): + return cls._params_class.default_ui_json + return None + class Options(BaseModel): """ diff --git a/geoapps_utils/run.py b/geoapps_utils/run.py index a428a3df..8e9eea92 100644 --- a/geoapps_utils/run.py +++ b/geoapps_utils/run.py @@ -50,11 +50,11 @@ def load_ui_json_as_dict(filepath: str | Path | dict) -> dict: return uijson -def fetch_driver_class(json_dict: str | Path | dict) -> type[Driver]: +def fetch_driver_class_from_string(module_path: str) -> type[Driver]: """ - Fetch the driver class from the ui.json 'run_command'. + Fetch the driver class from a module path string. - :param json_dict: Path to a ui.json file with a 'run_command' key. + :param module_path: Module path string. :return: Driver class. """ @@ -63,15 +63,7 @@ def fetch_driver_class(json_dict: str | Path | dict) -> type[Driver]: BaseDriver, ) - uijson = load_ui_json_as_dict(json_dict) - - if "run_command" not in uijson or not isinstance(uijson["run_command"], str): - raise KeyError( - "'run_command' in ui.json must be a string representing the module path." - f" Got {uijson.get('run_command', None)}." - ) - - module = import_module(uijson["run_command"]) + module = import_module(module_path) cls = None for _, cls in inspect.getmembers(module): try: @@ -86,13 +78,34 @@ def fetch_driver_class(json_dict: str | Path | dict) -> type[Driver]: else: logger.warning( "\n\nApplicationError: No valid driver class found in module %s\n\n", - uijson["run_command"], + module_path, ) sys.exit(1) return cls +def fetch_driver_class(json_dict: str | Path | dict) -> type[Driver]: + """ + Fetch the driver class from the ui.json 'run_command'. + + :param json_dict: Path to a ui.json file with a 'run_command' key. + + :return: Driver class. + """ + uijson = load_ui_json_as_dict(json_dict) + + if "run_command" not in uijson or not isinstance(uijson["run_command"], str): + raise KeyError( + "'run_command' in ui.json must be a string representing the module path." + f" Got {uijson.get('run_command', None)}." + ) + + cls = fetch_driver_class_from_string(uijson["run_command"]) + + return cls + + def run_uijson_group( out_group: UIJsonGroup, ) -> Driver: diff --git a/tests/driver_test.py b/tests/driver_test.py index dfc43531..f4aa9459 100644 --- a/tests/driver_test.py +++ b/tests/driver_test.py @@ -65,6 +65,9 @@ def test_base_driver(tmp_path): TestParamsDriver("not a params object") # type: ignore driver = TestParamsDriver(params) + + assert TestParamsDriver.get_default_ui_json() is None + driver.start(tmp_path / "test_ifile.ui.json") with pytest.raises(TypeError, match="Input file must be "): @@ -117,6 +120,8 @@ def test_base_options(tmp_path): driver = TestOptionsDriver(options) + assert TestOptionsDriver.get_default_ui_json().exists() # type: ignore + assert isinstance(driver.params, TestOptions) assert driver.params_class == TestOptions assert isinstance(driver.workspace, Workspace)