diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 6b23c5dc..b2dd2e3d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -69,6 +69,13 @@ jobs: uv pip install --no-cache-dir -e . --no-deps rm -rf ~/.cache # /app/RFdiffusion/tests + - name: Preseed DGL backend + shell: bash + run: | + mkdir -p "$HOME/.dgl" + printf '{"backend": "pytorch"}' > "$HOME/.dgl/config.conf" + echo "DGLBACKEND=pytorch" >> "$GITHUB_ENV" + - name: Download weights run: | mkdir models @@ -87,8 +94,29 @@ jobs: - name: Setup and Run ppi_scaffolds tests run: | tar -xvf examples/ppi_scaffolds_subset.tar.gz -C examples - cd tests && uv run python test_diffusion.py + total_chunks=$(nproc) + cd tests + + #launch all chunks in background and record PIDs + labels + pids="" + for chunk_index in $(seq 1 $total_chunks); do + echo "Running chunk $chunk_index of $total_chunks" + uv run python test_diffusion.py --total_chunks $total_chunks --chunk_index $chunk_index & + pids="$pids $!" + done + + # wait for each and track failures + fail=0 + for pid in $pids; do + if ! wait "$pid"; then + echo "A chunk (PID $pid) failed" + fail=1 + else + echo "A chunk (PID $pid) passed" + fi + done + exit "$fail" # - name: Test with pytest # run: | diff --git a/examples/design_macrocyclic_binder.sh b/examples/design_macrocyclic_binder.sh index c0c69f0b..1067a01e 100755 --- a/examples/design_macrocyclic_binder.sh +++ b/examples/design_macrocyclic_binder.sh @@ -1,18 +1,15 @@ #!/bin/bash -prefix=./outputs/diffused_binder_cyclic2 +# Note that in the example below the indices in the +# input_pdbs/7zkr_GABARAP.pdb file have been shifted +# by +2 in chain A relative to pdbID 7zkr. -# Note that the indices in this pdb file have been -# shifted by +2 in chain A relative to pdbID 7zkr. 
-pdb='./input_pdbs/7zkr_GABARAP.pdb' - -num_designs=10 -script="../scripts/run_inference.py" -$script --config-name base \ -inference.output_prefix=$prefix \ -inference.num_designs=$num_designs \ +../scripts/run_inference.py \ +--config-name base \ +inference.output_prefix=example_outputs/diffused_binder_cyclic2 \ +inference.num_designs=10 \ 'contigmap.contigs=[12-18 A3-117/0]' \ -inference.input_pdb=$pdb \ +inference.input_pdb=./input_pdbs/7zkr_GABARAP.pdb \ inference.cyclic=True \ diffuser.T=50 \ inference.cyc_chains='a' \ diff --git a/examples/design_macrocyclic_monomer.sh b/examples/design_macrocyclic_monomer.sh index 96eda600..3aa1ac3d 100755 --- a/examples/design_macrocyclic_monomer.sh +++ b/examples/design_macrocyclic_monomer.sh @@ -1,17 +1,15 @@ #!/bin/bash -prefix=./outputs/uncond_cycpep -# Note that the indices in this pdb file have been -# shifted by +2 in chain A relative to pdbID 7zkr. -pdb='./input_pdbs/7zkr_GABARAP.pdb' +# Note that in the example below the indices in the +# input_pdbs/7zkr_GABARAP.pdb file have been shifted +# by +2 in chain A relative to pdbID 7zkr. 
-num_designs=10 -script="../scripts/run_inference.py" -$script --config-name base \ -inference.output_prefix=$prefix \ -inference.num_designs=$num_designs \ +../scripts/run_inference.py \ +--config-name base \ +inference.output_prefix=example_outputs/uncond_cycpep \ +inference.num_designs=10 \ 'contigmap.contigs=[12-18]' \ -inference.input_pdb=$pdb \ +inference.input_pdb=input_pdbs/7zkr_GABARAP.pdb \ inference.cyclic=True \ diffuser.T=50 \ inference.cyc_chains='a' diff --git a/examples/design_tetrahedral_oligos.sh b/examples/design_tetrahedral_oligos.sh index 231e14a9..5aab7887 100755 --- a/examples/design_tetrahedral_oligos.sh +++ b/examples/design_tetrahedral_oligos.sh @@ -5,6 +5,6 @@ # This external potential promotes contacts both within (with a relative weight of 1) and between chains (relative weight 0.1) # We specify that we want to apply these potentials to all chains, with a guide scale of 2.0 (a sensible starting point) # We decay this potential with quadratic form, so that it is applied more strongly initially -# We specify a total length of 1200aa, so each chain is 100 residues long +# We specify a total length of 1200aa, so each chain is 100 residues long - length updated to 600aa, so each chain is 50 residues long for testing to run faster -python ../scripts/run_inference.py --config-name=symmetry inference.symmetry="tetrahedral" inference.num_designs=10 inference.output_prefix="example_outputs/tetrahedral_oligo" 'potentials.guiding_potentials=["type:olig_contacts,weight_intra:1,weight_inter:0.1"]' potentials.olig_intra_all=True potentials.olig_inter_all=True potentials.guide_scale=2.0 potentials.guide_decay="quadratic" 'contigmap.contigs=[1200-1200]' +python ../scripts/run_inference.py --config-name=symmetry inference.symmetry="tetrahedral" inference.num_designs=10 inference.output_prefix="example_outputs/tetrahedral_oligo" 'potentials.guiding_potentials=["type:olig_contacts,weight_intra:1,weight_inter:0.1"]' potentials.olig_intra_all=True 
potentials.olig_inter_all=True potentials.guide_scale=2.0 potentials.guide_decay="quadratic" 'contigmap.contigs=[600-600]' diff --git a/tests/test_diffusion.py b/tests/test_diffusion.py index f3eaeccd..8de5c03e 100644 --- a/tests/test_diffusion.py +++ b/tests/test_diffusion.py @@ -11,6 +11,7 @@ script_dir = os.path.dirname(os.path.abspath(__file__)) + class TestSubmissionCommands(unittest.TestCase): """ Test harness for checking that commands in the examples folder, @@ -25,88 +26,163 @@ class TestSubmissionCommands(unittest.TestCase): outputs are the same as the reference outputs. """ - def setUp(self): + failed_tests = [] + + # number of chunks to split examples into + total_chunks = 1 + # which chunk to run + chunk_index = 1 + + out_f = None + results = {} + exec_status = {} + + @classmethod + def setUpClass(cls): """ - Grabs files from the examples folder + Class-level setup: Grabs files from the examples folder, discover & rewrite example commands, then execute them once. """ submissions = glob.glob(f"{script_dir}/../examples/*.sh") # get datetime for output folder, in YYYY_MM_DD_HH_MM_SS format + chunks = cls.total_chunks + idx = cls.chunk_index + if chunks < 1: + raise ValueError("total_chunks must be at least 1") + if idx < 1 or idx > chunks: + raise ValueError( + "chunk_index must be between 1 and total_chunks (inclusive)" + ) + if chunks > 1: + submissions = [ + submissions[i] + for i in range(len(submissions)) + if i % chunks == (idx - 1) + ] + print( + f"Running chunk {idx}/{chunks}, {len(submissions)} submissions to run" + ) + if not submissions: + raise ValueError(f"No submissions selected for chunk {idx} of {chunks}") + now = datetime.datetime.now() now = now.strftime("%Y_%m_%d_%H_%M_%S") - self.out_f = f"{script_dir}/tests_{now}" - os.mkdir(self.out_f) + cls.out_f = f"{script_dir}/tests_{now}_{idx}" + os.mkdir(cls.out_f) # Make sure we have access to all the relevant files exclude_dirs = ["outputs", "example_outputs"] for filename in
os.listdir(f"{script_dir}/../examples"): - if filename not in exclude_dirs and not os.path.islink(os.path.join(script_dir, filename)) and os.path.isdir(os.path.join(f'{script_dir}/../examples', filename)): - os.symlink(os.path.join(f'{script_dir}/../examples', filename), os.path.join(script_dir, filename)) + if ( + filename not in exclude_dirs + and not os.path.exists(os.path.join(script_dir, filename)) + and os.path.isdir(os.path.join(f"{script_dir}/../examples", filename)) + ): + try: + os.symlink( + os.path.join(f"{script_dir}/../examples", filename), + os.path.join(script_dir, filename), + ) + except FileExistsError: + pass for submission in submissions: - self._write_command(submission, self.out_f) - - print(f"Running commands in {self.out_f}, two steps of diffusion, deterministic=True") - - self.results = {} - - for bash_file in sorted( glob.glob(f"{self.out_f}/*.sh"), reverse=False): - test_name = os.path.basename(bash_file)[:-len('.sh')] - res, output = execute(f"Running {test_name}", f'bash {bash_file}', return_='tuple', add_message_and_command_line_to_output=True) - - self.results[test_name] = dict( - state = 'failed' if res else 'passed', - log = output, + cls._write_command(submission, cls.out_f) + + print( + f"Running commands in {cls.out_f}, two steps of diffusion, deterministic=True" + ) + + cls.results = {} + cls.exec_status = {} + + for bash_file in sorted(glob.glob(f"{cls.out_f}/*.sh"), reverse=False): + test_name = os.path.basename(bash_file)[: -len(".sh")] + res, output = execute( + f"Running {test_name}", + f"bash {bash_file}", + return_="tuple", + add_message_and_command_line_to_output=True, ) + cls.exec_status[test_name] = (res, output) - #subprocess.run(["bash", bash_file], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - #subprocess.run(["bash", bash_file]) + cls.results[test_name] = dict( + state="failed" if res else "passed", + log=output, + ) + # subprocess.run(["bash", bash_file], stdout=subprocess.DEVNULL, 
stderr=subprocess.DEVNULL) + # subprocess.run(["bash", bash_file]) + + def test_examples_run_without_errors(self): + for name, (exit_code, output) in sorted(self.__class__.exec_status.items()): + with self.subTest(example=name): + if exit_code != 0: + self.__class__.failed_tests.append(f"{name}") + self.assertEqual( + exit_code, + 0, + msg=f"Example '{name}' exited with {exit_code}\n{output}", + ) + + sys.stderr.write("\n==== EXAMPLE FAILURE SUMMARY ====\n") + for line in self.__class__.failed_tests: + sys.stderr.write(f" - {line}\n") + sys.stderr.write("=========================\n\n") + sys.stderr.flush() def test_commands(self): """ Runs all the commands in the test_f folder """ - reference=f'{script_dir}/reference_outputs' + reference = f"{script_dir}/reference_outputs" os.makedirs(reference, exist_ok=True) - test_files=glob.glob(f"{self.out_f}/example_outputs/*pdb") - print(f'{self.out_f=} {test_files=}') + test_files = glob.glob(f"{self.__class__.out_f}/example_outputs/*pdb") + print(f"{self.__class__.out_f=} {test_files=}") # first check that we have the right number of outputs - #self.assertEqual(len(test_files), len(glob.glob(f"{self.out_f}/*.sh"))), "One or more of the example commands didn't produce an output (check the example command is formatted correctly)" + # self.assertEqual(len(test_files), len(glob.glob(f"{self.out_f}/*.sh"))), "One or more of the example commands didn't produce an output (check the example command is formatted correctly)" result = self.defaultTestResult() for test_file in test_files: with self.subTest(test_file=test_file): - test_pdb=iu.parse_pdb(test_file) + test_pdb = iu.parse_pdb(test_file) if not os.path.exists(f"{reference}/{os.path.basename(test_file)}"): copyfile(test_file, f"{reference}/{os.path.basename(test_file)}") - print(f"Created reference file {reference}/{os.path.basename(test_file)}") + print( + f"Created reference file {reference}/{os.path.basename(test_file)}" + ) else: - 
ref_pdb=iu.parse_pdb(f"{reference}/{os.path.basename(test_file)}") - rmsd=calc_rmsd(test_pdb['xyz'][:,:3].reshape(-1,3), ref_pdb['xyz'][:,:3].reshape(-1,3))[0] + ref_pdb = iu.parse_pdb(f"{reference}/{os.path.basename(test_file)}") + rmsd = calc_rmsd( + test_pdb["xyz"][:, :3].reshape(-1, 3), + ref_pdb["xyz"][:, :3].reshape(-1, 3), + )[0] try: self.assertAlmostEqual(rmsd, 0, 2) result.addSuccess(self) print(f"Subtest {test_file} passed") - state = 'passed' - log = f'Subtest {test_file} passed' + state = "passed" + log = f"Subtest {test_file} passed" except AssertionError as e: result.addFailure(self, e) print(f"Subtest {test_file} failed") - state = 'failed' - log = f'Subtest {test_file} failed:\n{e!r}' + state = "failed" + log = f"Subtest {test_file} failed:\n{e!r}" - self.results[ 'pdb-diff.' + test_file.rpartition('/')[-1] ] = dict(state = state, log = log) + self.results["pdb-diff." + test_file.rpartition("/")[-1]] = dict( + state=state, log=log + ) - with open('.results.json', 'w') as f: json.dump(self.results, f, sort_keys=True, indent=2) + with open(".results.json", "w") as f: + json.dump(self.results, f, sort_keys=True, indent=2) self.assertTrue(result.wasSuccessful(), "One or more subtests failed") - - def _write_command(self, bash_file, test_f) -> None: + @classmethod + def _write_command(cls, bash_file, test_f) -> None: """ Takes a bash file from the examples folder, and writes a version of it to the test_f folder. 
@@ -117,31 +193,42 @@ def _write_command(self, bash_file, test_f) -> None: else: inference.final_step=48 """ - out_lines=[] + out_lines = [] + command_lines = [] + in_command = False with open(bash_file, "r") as f: - lines = f.readlines() - for line in lines: - if not (line.startswith("python") or line.startswith("../")): - out_lines.append(line) + for line in f: + stripped = line.strip() + if stripped.startswith("python") or stripped.startswith("../"): + in_command = True + if in_command: + # Remove trailing line continuation slashes + if stripped.endswith("\\"): + command_lines.append(stripped[:-1].strip()) + else: + command_lines.append(stripped) + in_command = False # End of command else: - command = line.strip() - if not command.startswith("python"): - command = f'python {command}' + out_lines.append(line) + if not command_lines: + raise ValueError(f"No valid python command found in {bash_file}") + command = " ".join(command_lines) # get the partial_T if "partial_T" in command: final_step = int(command.split("partial_T=")[1].split(" ")[0]) - 2 else: final_step = 48 - output_command = f"{command} inference.deterministic=True inference.final_step={final_step}" + output_command = ( + f"{command} inference.deterministic=True inference.final_step={final_step}" + ) # replace inference.num_designs with 1 if "inference.num_designs=" in output_command: output_command = f'{output_command.split("inference.num_designs=")[0]}inference.num_designs=1 {" ".join(output_command.split("inference.num_designs=")[1].split(" ")[1:])}' else: - output_command = f'{output_command} inference.num_designs=1' + output_command = f"{output_command} inference.num_designs=1" # replace 'example_outputs' with f'{self.out_f}/example_outputs' - output_command = f'{output_command.split("example_outputs")[0]}{self.out_f}/example_outputs{output_command.split("example_outputs")[1]}' - + output_command = 
f'{output_command.split("example_outputs")[0]}{cls.out_f}/example_outputs{output_command.split("example_outputs")[1]}' # write the new command with open(f"{test_f}/{os.path.basename(bash_file)}", "w") as f: @@ -150,28 +237,38 @@ def _write_command(self, bash_file, test_f) -> None: f.write(output_command) - def execute_through_pty(command_line): import pty, select if sys.platform == "darwin": master, slave = pty.openpty() - p = subprocess.Popen(command_line, shell=True, stdout=slave, stdin=slave, - stderr=subprocess.STDOUT, close_fds=True) + p = subprocess.Popen( + command_line, + shell=True, + stdout=slave, + stdin=slave, + stderr=subprocess.STDOUT, + close_fds=True, + ) buffer = [] while True: try: if select.select([master], [], [], 0.2)[0]: # has something to read data = os.read(master, 1 << 22) - if data: buffer.append(data) + if data: + buffer.append(data) - elif (p.poll() is not None) and (not select.select([master], [], [], 0.2)[0] ): break # process is finished and output buffer if fully read + elif (p.poll() is not None) and ( + not select.select([master], [], [], 0.2)[0] + ): + break # process is finished and output buffer if fully read - except OSError: break # OSError will be raised when child process close PTY descriptior + except OSError: + break # OSError will be raised when child process close PTY descriptior - output = b''.join(buffer).decode(encoding='utf-8', errors='backslashreplace') + output = b"".join(buffer).decode(encoding="utf-8", errors="backslashreplace") os.close(master) os.close(slave) @@ -179,7 +276,7 @@ def execute_through_pty(command_line): p.wait() exit_code = p.returncode - ''' + """ buffer = [] while True: if select.select([master], [], [], 0.2)[0]: # has something to read @@ -200,13 +297,19 @@ def execute_through_pty(command_line): output = b''.join(buffer).decode(encoding='utf-8', errors='backslashreplace') exit_code = p.returncode - ''' + """ else: master, slave = pty.openpty() - p = subprocess.Popen(command_line, shell=True, 
stdout=slave, stdin=slave, - stderr=subprocess.STDOUT, close_fds=True) + p = subprocess.Popen( + command_line, + shell=True, + stdout=slave, + stdin=slave, + stderr=subprocess.STDOUT, + close_fds=True, + ) os.close(slave) @@ -214,10 +317,12 @@ def execute_through_pty(command_line): while True: try: data = os.read(master, 1 << 22) - if data: buffer.append(data) - except OSError: break # OSError will be raised when child process close PTY descriptior + if data: + buffer.append(data) + except OSError: + break # OSError will be raised when child process close PTY descriptior - output = b''.join(buffer).decode(encoding='utf-8', errors='backslashreplace') + output = b"".join(buffer).decode(encoding="utf-8", errors="backslashreplace") os.close(master) @@ -227,39 +332,84 @@ def execute_through_pty(command_line): return exit_code, output - -def execute(message, command_line, return_='status', until_successes=False, terminate_on_failure=True, silent=False, silence_output=False, silence_output_on_errors=False, add_message_and_command_line_to_output=False): - if not silent: print(message); print(command_line); sys.stdout.flush(); +def execute( + message, + command_line, + return_="status", + until_successes=False, + terminate_on_failure=True, + silent=False, + silence_output=False, + silence_output_on_errors=False, + add_message_and_command_line_to_output=False, +): + if not silent: + print(message) + print(command_line) + sys.stdout.flush() while True: - #exit_code, output = execute_through_subprocess(command_line) - #exit_code, output = execute_through_pexpect(command_line) + # exit_code, output = execute_through_subprocess(command_line) + # exit_code, output = execute_through_pexpect(command_line) exit_code, output = execute_through_pty(command_line) - if (exit_code and not silence_output_on_errors) or not (silent or silence_output): print(output); sys.stdout.flush(); + if (exit_code and not silence_output_on_errors) or not ( + silent or silence_output + ): + print(output) 
+ sys.stdout.flush() - if exit_code and until_successes: pass # Thats right - redability COUNT! - else: break + if exit_code and until_successes: + pass # Thats right - redability COUNT! + else: + break - print( "Error while executing {}: {}\n".format(message, output) ) + print("Error while executing {}: {}\n".format(message, output)) print("Sleeping 60s... then I will retry...") - sys.stdout.flush(); + sys.stdout.flush() time.sleep(60) - if add_message_and_command_line_to_output: output = message + '\nCommand line: ' + command_line + '\n' + output + if add_message_and_command_line_to_output: + output = message + "\nCommand line: " + command_line + "\n" + output - if return_ == 'tuple' or return_ == tuple: return(exit_code, output) + if return_ == "tuple" or return_ == tuple: + return (exit_code, output) if exit_code and terminate_on_failure: print("\nEncounter error while executing: " + command_line) - if return_==True: return True + if return_ == True: + return True else: - print('\nEncounter error while executing: ' + command_line + '\n' + output); - raise BenchmarkError('\nEncounter error while executing: ' + command_line + '\n' + output) + print("\nEncounter error while executing: " + command_line + "\n" + output) + raise BenchmarkError( + "\nEncounter error while executing: " + command_line + "\n" + output + ) - if return_ == 'output': return output - else: return exit_code + if return_ == "output": + return output + else: + return exit_code if __name__ == "__main__": - unittest.main() + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--total_chunks", + type=int, + default=1, + help="total number of chunks to split the examples into (default: 1)", + ) + parser.add_argument( + "--chunk_index", + type=int, + default=1, + help="Which chunk to run (1-based index, default:1)", + ) + args, remaining = parser.parse_known_args() + + TestSubmissionCommands.total_chunks = args.total_chunks + TestSubmissionCommands.chunk_index = 
args.chunk_index + + unittest.main(argv=[sys.argv[0]] + remaining)