Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ This project uses **Covasim** to simulate the spread of COVID-19 within a define
Install Covasim via pip:

```bash
pip install covasim
pip install -r requirements.txt
```

---
Expand Down
430 changes: 430 additions & 0 deletions assets/NC_045512_Hu-1.fasta

Large diffs are not rendered by default.

68 changes: 68 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
appnope==0.1.4
asttokens==3.0.1
biopython==1.87
comm==0.2.3
contourpy==1.3.2
covasim==3.1.7
cycler==0.12.1
debugpy==1.8.20
decorator==5.2.1
dill==0.4.1
epyestim==0.1
et_xmlfile==2.0.0
exceptiongroup==1.3.1
executing==2.2.1
fonttools==4.62.1
gitdb==4.0.12
GitPython==3.1.49
ipykernel==7.2.0
ipython==8.39.0
jedi==0.19.2
jellyfish==1.2.1
jsonpickle==4.1.1
jupyter_client==8.8.0
jupyter_core==5.9.1
kiwisolver==1.5.0
line_profiler==5.0.2
llvmlite==0.47.0
matplotlib==3.10.9
matplotlib-inline==0.2.1
memory-profiler==0.61.0
multiprocess==0.70.19
nest-asyncio==1.6.0
networkx==3.4.2
numba==0.65.1
numpy==2.2.6
openpyxl==3.1.5
packaging==26.2
pandas==2.3.3
parso==0.8.6
patsy==1.0.2
pexpect==4.9.0
pillow==12.2.0
platformdirs==4.9.6
prompt_toolkit==3.0.52
psutil==7.2.2
ptyprocess==0.7.0
pure_eval==0.2.3
Pygments==2.20.0
pyparsing==3.3.2
python-dateutil==2.9.0.post0
pytz==2026.1.post1
PyYAML==6.0.3
pyzmq==27.1.0
scipy==1.15.3
sciris==3.2.9
six==1.17.0
smmap==5.0.3
stack-data==0.6.3
statsmodels==0.14.6
tomli==2.4.1
tornado==6.5.5
tqdm==4.67.3
traitlets==5.14.3
typing_extensions==4.15.0
tzdata==2026.2
wcwidth==0.6.0
xlsxwriter==3.2.9
zstandard==0.25.0
Empty file added src/__init__.py
Empty file.
1 change: 1 addition & 0 deletions src/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# DNA alphabet shared by the substitution models. The order is significant:
# SubstitutionModel looks up a base's matrix row via its position in this tuple
# and samples the replacement base from this tuple using that row as weights.
NUCLEOTIDES = ('A', 'C', 'G', 'T')
42 changes: 42 additions & 0 deletions src/molecular_clock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import math
import random

from substitution_model import SubstitutionModel


def poisson(rng: random.Random, lam: float) -> int:
    """
    Draw one Poisson(``lam``) variate using Knuth's multiplication method.

    ``rng`` supplies the uniform random numbers, so results are reproducible for a
    seeded generator. A non-positive ``lam`` yields 0 without consuming any random
    numbers.

    Large means are sampled as a sum of independent Poisson draws of at most 500
    each (a sum of independent Poisson variates is Poisson with the summed mean).
    Without this, ``math.exp(-lam)`` underflows to 0.0 once ``lam`` exceeds roughly
    745 and the loop would only stop when the running product itself underflows,
    returning a value unrelated to ``lam``. For ``lam`` <= 500 the sequence of
    ``rng.random()`` calls is exactly that of the plain single-pass algorithm.

    Running time is proportional to ``lam``.

    Raises:
        ValueError: if ``lam`` is NaN or positive infinity.
    """
    if lam <= 0.0:
        return 0
    if not math.isfinite(lam):
        raise ValueError(f"Poisson mean must be finite, got {lam}")

    max_chunk = 500.0  # exp(-500) is still comfortably representable as a float
    total = 0
    remaining = lam
    while remaining > 0.0:
        chunk = min(remaining, max_chunk)
        remaining -= chunk
        threshold = math.exp(-chunk)
        # Knuth: multiply uniforms until the product drops to exp(-chunk) or below;
        # the number of multiplications minus one is Poisson(chunk) distributed.
        k = 0
        p = 1.0
        while p > threshold:
            k += 1
            p *= rng.random()
        total += k - 1
    return total


def molecular_clock_evolve(
    seq: str,
    branch_time: float,
    rate: float,
    model: SubstitutionModel,
    rng: random.Random,
) -> str:
    """
    Return ``seq`` after simulated evolution over a branch of ``branch_time``.

    Every site independently receives Poisson(rate * branch_time) substitution
    events. Each event redraws the site's nucleotide through
    ``model.substitute_nucleotide``, starting from whatever base the site holds at
    that moment. ``rate`` and ``branch_time`` must share a time unit (e.g. per site
    per simulation day). When either value is non-positive the input sequence is
    returned untouched.
    """
    if branch_time <= 0.0 or rate <= 0.0:
        return seq

    expected_hits = rate * branch_time

    def evolve_site(base: str) -> str:
        # The event count is drawn before any substitution for this site, so the
        # random stream is consumed site by site, left to right.
        for _ in range(poisson(rng, expected_hits)):
            base = model.substitute_nucleotide(ref_nt=base, rng=rng)
        return base

    return "".join(evolve_site(base) for base in seq)
82 changes: 77 additions & 5 deletions src/run_sim.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from os import sep
from pathlib import Path
import random
from Bio import SeqIO
import covasim as cv
import covasim.data as cvdata
import numpy as np
Expand All @@ -9,6 +13,19 @@
from networkx.drawing.nx_pydot import graphviz_layout
import argparse

from molecular_clock import molecular_clock_evolve
from substitution_model import JukesCantor

# Directory two levels above this file's resolved location (the parent of the
# directory that contains this script).
_REPO_ROOT = Path(__file__).resolve().parent.parent
# Reference genome FASTA used when --reference is not given on the command line.
_DEFAULT_REFERENCE = _REPO_ROOT / "assets" / "NC_045512_Hu-1.fasta"


def load_fasta_sequence(path: Path) -> str:
    """
    Load the first FASTA record as a single uppercase nucleotide string.

    Only the first record in the file is read; any further records are ignored.

    Raises:
        ValueError: if the file contains no FASTA records.
    """
    # A default for next() avoids leaking a bare StopIteration to the caller when
    # the file is empty or has no FASTA header.
    record = next(SeqIO.parse(path, "fasta"), None)
    if record is None:
        raise ValueError(f"No FASTA records found in {path}")
    return str(record.seq).upper()


def define_sim_parameters(pop_size, pop_type,
n_days, location,
pop_infected, n_imports):
Expand Down Expand Up @@ -122,6 +139,52 @@ def viral_shedding_covasim(sim,start,end):
regional_viral_load[t, r] += viral_load_matrix[t, p]
return regional_viral_load

def assign_sequences(sim, reference, rate=1e-7, seed=42,
                     infection_log_path='infection_log.tsv'):
    """
    Attach a simulated viral sequence to every row of the infection log.

    ``sim.people.infection_log`` is loaded into a DataFrame and its rows are
    processed in ascending ``date`` order (stable sort, so ties keep log order).
    A row whose ``source`` is missing receives ``reference`` unchanged. Any other
    row receives the sequence currently carried by its ``source``, evolved under a
    Jukes-Cantor model for the time elapsed since the source acquired that
    sequence. A source with no recorded sequence is treated as carrying
    ``reference`` with zero elapsed time. A later infection of the same ``target``
    replaces the sequence that person carries.

    Args:
        sim: object exposing ``people.infection_log`` with ``source``, ``target``
            and ``date`` fields per entry (presumably a Covasim ``Sim`` that has
            already been run -- confirm against caller).
        reference: nucleotide string assigned to infections without a source.
        rate: substitutions per site per day passed to the molecular clock.
        seed: seed for the sequence-evolution random generator; ``None`` seeds
            from system entropy.
        infection_log_path: where the raw infection log is written as TSV;
            ``None`` skips writing it.

    Returns:
        The infection-log DataFrame with an added ``sequence`` column.
    """
    sub_model = JukesCantor()
    seq_rng = random.Random(seed)

    df = pd.DataFrame(sim.people.infection_log)
    if infection_log_path is not None:
        df.to_csv(infection_log_path, sep='\t')

    # No infections recorded: there is no "date" column to sort on, so return early
    # with an empty sequence column instead of failing in sort_values.
    if df.empty:
        df["sequence"] = pd.Series(dtype=object)
        return df

    # Track each person's currently carried lineage so reinfections can replace it.
    person_current_seq = {}
    person_current_start_date = {}

    # Store sequence by transmission event (source -> target), keyed by row index.
    event_seq_by_row = {}

    # Process in temporal order so parent lineage state is up to date.
    for idx, row in df.sort_values("date", kind="stable").iterrows():
        target = int(row["target"])
        event_date = float(row["date"])

        if pd.isna(row["source"]):
            child_seq = reference
        else:
            source = int(row["source"])
            parent_seq = person_current_seq.get(source, reference)
            parent_start_date = person_current_start_date.get(source, event_date)
            # Clamp so an out-of-order date can never yield a negative branch.
            branch_time = max(0.0, event_date - parent_start_date)
            child_seq = molecular_clock_evolve(
                parent_seq,
                branch_time,
                rate=rate,  # mutations per site per day
                model=sub_model,
                rng=seq_rng,
            )

        event_seq_by_row[idx] = child_seq
        person_current_seq[target] = child_seq
        person_current_start_date[target] = event_date

    # Series construction from the dict aligns sequences to rows by index label.
    df["sequence"] = pd.Series(event_seq_by_row)

    return df



def plot_shedding(data,name):
fig, ax = plt.subplots(figsize=(6,10))
Expand Down Expand Up @@ -170,6 +233,8 @@ def main():
help="Row index for where the infection initiated (default: 0)")
parser.add_argument("--c_init_inf", type=int, default=0,
help="column index for where the infection initiated (default: 0)")
parser.add_argument("--reference", type=Path, default=_DEFAULT_REFERENCE,
help="Path to reference genome FASTA (default: NC_045512_Hu-1.fasta)")

args = parser.parse_args()
### STEP 1 ####
Expand Down Expand Up @@ -199,13 +264,20 @@ def main():
new_cases = calculate_new_infections(sim, n_regions)
# calculate wastewater shedding per region per time point using
# basic shedding model
shedding_simple = viral_shedding_simple(new_cases)
# covasim viral load model (adjust zero if start date is not day 0)
shedding_covasim = viral_shedding_covasim(sim,0,args.n_days)
# shedding_simple = viral_shedding_simple(new_cases)
# # covasim viral load model (adjust zero if start date is not day 0)
# shedding_covasim = viral_shedding_covasim(sim,0,args.n_days)
# plot values
plot_shedding(shedding_simple,"simple")
plot_shedding(shedding_covasim,"covasim")
# plot_shedding(shedding_simple,"simple")
# plot_shedding(shedding_covasim,"covasim")

reference_seq = load_fasta_sequence(args.reference)

transmission_df = assign_sequences(sim, reference_seq)

print(transmission_df['sequence'].nunique())

transmission_df.to_csv("transmission.tsv", sep='\t')

if __name__ == "__main__":
main()
45 changes: 45 additions & 0 deletions src/substitution_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import random

import numpy as np

from constants import NUCLEOTIDES


class SubstitutionModel:
    """
    Nucleotide substitution model backed by a 4x4 row-stochastic matrix.

    Row ``i`` holds the probabilities of replacing ``NUCLEOTIDES[i]`` with each
    nucleotide, in ``NUCLEOTIDES`` order. The matrix is validated on construction.
    """

    name: str
    matrix: np.ndarray

    def __init__(self, name, matrix):
        """Store ``name`` and ``matrix`` (converted to a float array) and validate it."""
        self.name = name
        self.matrix = np.asarray(matrix, dtype=float)
        self.validate_matrix()

    def validate_matrix(self) -> None:
        """
        Check that ``self.matrix`` is a usable 4x4 probability matrix.

        Raises:
            ValueError: if the matrix is not 2-dimensional, is not 4x4, contains
                NaN/infinite or negative entries, or has a row that does not sum
                to 1.0 within 1e-8.
        """
        if self.matrix.ndim != 2:
            raise ValueError("Transition matrix must be 2-dimensional")
        if len(self.matrix) != 4:
            raise ValueError("Transition matrix must have 4 rows")
        for row in self.matrix:
            if len(row) != 4:
                raise ValueError("Transition matrix must have 4 columns per row")
            # NaN would slip through the sum check below because every comparison
            # with NaN is False, so reject non-finite entries explicitly.
            if not np.all(np.isfinite(row)):
                raise ValueError(f"Transition matrix entries must be finite, got row {row}")
            # Rows are used as sampling weights, which must not be negative even
            # if the row still happens to sum to 1.
            if np.any(row < 0.0):
                raise ValueError(f"Transition matrix entries must be non-negative, got row {row}")
            row_sum = sum(row)
            if abs(row_sum - 1.0) > 1e-8:
                raise ValueError(f"Each row of the transition matrix must sum to 1.0, got {row_sum}")

    def substitute_nucleotide(
        self,
        ref_nt: str,
        rng: random.Random,
    ) -> str:
        """
        Sample the child nucleotide from this model's row for ``ref_nt``.

        Raises:
            ValueError: if ``ref_nt`` is not one of ``NUCLEOTIDES``.
        """
        try:
            idx = NUCLEOTIDES.index(ref_nt)
        except ValueError:
            raise ValueError(
                f"Unsupported nucleotide {ref_nt!r}; expected one of {NUCLEOTIDES}"
            ) from None
        row = self.matrix[idx]
        return rng.choices(NUCLEOTIDES, weights=row, k=1)[0]


class JukesCantor(SubstitutionModel):
    """
    Jukes-Cantor substitution model.

    A substitution never keeps the current nucleotide (zero diagonal) and picks
    each of the three other nucleotides with equal probability 1/3.
    """

    def __init__(self):
        off_diagonal = 1.0 / 3.0
        # Ones everywhere except the diagonal, scaled to the off-diagonal probability.
        matrix = (np.ones((4, 4)) - np.eye(4)) * off_diagonal
        super().__init__(name="Jukes-Cantor", matrix=matrix)