diff --git a/rfdiffusion/inference/model_runners.py b/rfdiffusion/inference/model_runners.py index 9a33d98d..74e40569 100644 --- a/rfdiffusion/inference/model_runners.py +++ b/rfdiffusion/inference/model_runners.py @@ -18,7 +18,8 @@ from rfdiffusion.model_input_logger import pickle_function_call import sys -SCRIPT_DIR=os.path.dirname(os.path.realpath(__file__)) +from rfdiffusion.util import USER_DIR + TOR_INDICES = util.torsion_indices TOR_CAN_FLIP = util.torsion_can_flip @@ -63,7 +64,8 @@ def initialize(self, conf: DictConfig) -> None: if conf.inference.model_directory_path is not None: model_directory = conf.inference.model_directory_path else: - model_directory = f"{SCRIPT_DIR}/../../models" + # set default model weigths directory under env var control. fallback to user $HOME/rfdiffusion/models + model_directory = os.environ.get('RFD_MODELS',f"{USER_DIR}/models") print(f"Reading models from {model_directory}") @@ -122,7 +124,7 @@ def initialize(self, conf: DictConfig) -> None: if conf.inference.schedule_directory_path is not None: schedule_directory = conf.inference.schedule_directory_path else: - schedule_directory = f"{SCRIPT_DIR}/../../schedules" + schedule_directory = f"{USER_DIR}/schedules" # Check for cache schedule if not os.path.exists(schedule_directory): diff --git a/rfdiffusion/util.py b/rfdiffusion/util.py index 19c30f5f..b273bcc2 100644 --- a/rfdiffusion/util.py +++ b/rfdiffusion/util.py @@ -2,6 +2,17 @@ from rfdiffusion.chemical import * from rfdiffusion.scoring import * +import sys +from pathlib import Path +# define RFdiffusion directory located in user HOME directory. +USER_HOME=str(Path.home()) +USER_DIR=f"{USER_HOME}/rfdiffusion" + +try: + Path(USER_DIR).mkdir(parents=True, exist_ok=True) +except FileExistsError as msg: + print(f'{USER_DIR} already exist and is a file.') + sys.exit(1) def generate_Cbeta(N, Ca, C): # recreate Cb given N,Ca,C diff --git a/scripts/run_inference.py b/scripts/run_inference.py index 2a3bf362..0ceffe7a 100755 --- a/scripts/run_inference.py +++ b/scripts/run_inference.py @@ -21,13 +21,16 @@ from omegaconf import OmegaConf import hydra import logging -from rfdiffusion.util import writepdb_multi, writepdb +from rfdiffusion.util import writepdb_multi, writepdb, USER_DIR from rfdiffusion.inference import utils as iu from hydra.core.hydra_config import HydraConfig import numpy as np import random import glob +# set hyfra inference config dir under environment variable control +hydra_cfg_dir=os.environ.get('RFD_HYDRA_CFG', f"{USER_DIR}/config/inference") + def make_deterministic(seed=0): torch.manual_seed(seed) @@ -35,7 +38,7 @@ def make_deterministic(seed=0): random.seed(seed) -@hydra.main(version_base=None, config_path="../config/inference", config_name="base") +@hydra.main(version_base=None, config_path=hydra_cfg_dir, config_name="base") def main(conf: HydraConfig) -> None: log = logging.getLogger(__name__) if conf.inference.deterministic: