Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions geoapps_utils/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,17 @@ def update_monitoring_directory(
copy_children=copy_children,
)

@classmethod
def get_default_ui_json(cls) -> Path | None:
"""
Get the default ui.json file path for the application.

:return: Path to default ui.json file.
"""
if issubclass(cls._params_class, Options):
return cls._params_class.default_ui_json
return None


class Options(BaseModel):
"""
Expand Down
39 changes: 26 additions & 13 deletions geoapps_utils/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ def load_ui_json_as_dict(filepath: str | Path | dict) -> dict:
return uijson


def fetch_driver_class(json_dict: str | Path | dict) -> type[Driver]:
def fetch_driver_class_from_string(module_path: str) -> type[Driver]:
"""
Fetch the driver class from the ui.json 'run_command'.
Fetch the driver class from a module path string.

:param json_dict: Path to a ui.json file with a 'run_command' key.
:param module_path: Module path string.

:return: Driver class.
"""
Expand All @@ -63,15 +63,7 @@ def fetch_driver_class(json_dict: str | Path | dict) -> type[Driver]:
BaseDriver,
)

uijson = load_ui_json_as_dict(json_dict)

if "run_command" not in uijson or not isinstance(uijson["run_command"], str):
raise KeyError(
"'run_command' in ui.json must be a string representing the module path."
f" Got {uijson.get('run_command', None)}."
)

module = import_module(uijson["run_command"])
module = import_module(module_path)
cls = None
for _, cls in inspect.getmembers(module):
try:
Expand All @@ -86,13 +78,34 @@ def fetch_driver_class(json_dict: str | Path | dict) -> type[Driver]:
else:
logger.warning(
"\n\nApplicationError: No valid driver class found in module %s\n\n",
uijson["run_command"],
module_path,
)
sys.exit(1)

return cls


def fetch_driver_class(json_dict: str | Path | dict) -> type[Driver]:
"""
Fetch the driver class from the ui.json 'run_command'.

:param json_dict: Path to a ui.json file with a 'run_command' key.

:return: Driver class.
"""
uijson = load_ui_json_as_dict(json_dict)

if "run_command" not in uijson or not isinstance(uijson["run_command"], str):
raise KeyError(
"'run_command' in ui.json must be a string representing the module path."
f" Got {uijson.get('run_command', None)}."
)

cls = fetch_driver_class_from_string(uijson["run_command"])

return cls


def run_uijson_group(
out_group: UIJsonGroup,
) -> Driver:
Expand Down
5 changes: 5 additions & 0 deletions tests/driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def test_base_driver(tmp_path):
TestParamsDriver("not a params object") # type: ignore

driver = TestParamsDriver(params)

assert TestParamsDriver.get_default_ui_json() is None

driver.start(tmp_path / "test_ifile.ui.json")

with pytest.raises(TypeError, match="Input file must be "):
Expand Down Expand Up @@ -117,6 +120,8 @@ def test_base_options(tmp_path):

driver = TestOptionsDriver(options)

assert TestOptionsDriver.get_default_ui_json().exists() # type: ignore

assert isinstance(driver.params, TestOptions)
assert driver.params_class == TestOptions
assert isinstance(driver.workspace, Workspace)
Expand Down
Loading