Skip to content

Commit e5bb75d

Browse files
AnthchirpRichard Gildea
andauthored
Enable mypy checking for dlstbx (#130)
Happy 2022 Co-authored-by: Richard Gildea <richard.gildea@diamond.ac.uk>
1 parent 0493ce8 commit e5bb75d

14 files changed

Lines changed: 94 additions & 76 deletions

File tree

.pre-commit-config.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,11 @@ repos:
3232
- id: check-merge-conflict
3333
- id: check-added-large-files
3434
args: ['--maxkb=200']
35+
36+
# Type checking
37+
- repo: https://github.com/pre-commit/mirrors-mypy
38+
rev: v0.910
39+
hooks:
40+
- id: mypy
41+
files: 'src/.*\.py$'
42+
additional_dependencies: ['types-pytz==2021.3.2', 'types-python-dateutil==2.8.3', 'types-setuptools==57.4.4', 'types-PyYAML==6.0.1', 'types-pycurl==7.44.1', 'types-requests==2.26.1', 'types-certifi==2021.10.8.0', 'types-pyOpenSSL==21.0.1']

src/dlstbx/cli/fix_cluster_jobs.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,9 @@ def get_resubmission_id_for_job(cluster, jobid):
103103
if resubmission_id:
104104
error_db[j] = resubmission_id
105105
else:
106-
error_db[j] = str(uuid.uuid4())
107-
cluster.qalter(j, ["-ac", "resubmission_id=" + error_db[j]])
106+
unique_id = str(uuid.uuid4())
107+
cluster.qalter(j, ["-ac", f"resubmission_id={unique_id}"])
108+
error_db[j] = unique_id
108109
bar.update(1)
109110

110111
removable_jobs = [j for j in error_db if error_db[j] in resubmission_db]

src/dlstbx/cli/status_monitor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# Starts a status monitor (what do you expect?)
44
#
55

6+
from __future__ import annotations
7+
68
import argparse
79
import curses
810
import os
@@ -11,6 +13,7 @@
1113
import threading
1214
import time
1315
from pprint import pprint
16+
from typing import Any, Tuple
1417

1518
import workflows
1619
import workflows.services
@@ -37,15 +40,15 @@ class Monitor:
3740
shutdown = False
3841
"""Set to true to end the main loop and shut down the service monitor."""
3942

40-
cards = {}
43+
cards: dict[int, Any] = {}
4144
"""Register card shown for seen services"""
4245

4346
border_chars = ()
4447
"""Characters used for frame borders."""
4548
border_chars_text = ("|", "|", "=", "=", "/", "\\", "\\", "/")
4649
"""Example alternative set of frame border characters."""
4750

48-
most_recent_version = {}
51+
most_recent_version: dict[Any, Tuple[Any, Any]] = {}
4952
"""Dictionary to hold software version information, so old versions can be highlighted."""
5053

5154
def __init__(self, filters=None, transport=None, version=None, test=False):

src/dlstbx/health_checks/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class CheckFunctionInterface(NamedTuple):
2626
name: str
2727

2828

29-
class Status(_Base):
29+
class Status(_Base): # type: ignore
3030
__tablename__ = "infrastructure_status"
3131
__table_args__ = {"comment": "reports of DLS infrastructure"}
3232

src/dlstbx/health_checks/quota.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def _check_quota_for_user(
3636
entries = [entry for entry in _quotafile() if username in entry]
3737
limits = {"notice": 0.925, "warning": 0.95, "error": 0.975}
3838

39-
def report_level(value: float, limit: float) -> REPORT:
39+
def report_level(value: float, limit: float) -> int:
4040
if value >= limits["error"] * limit:
4141
return REPORT.ERROR
4242
if value >= limits["warning"] * limit:
@@ -52,7 +52,7 @@ def report_level(value: float, limit: float) -> REPORT:
5252
Level=REPORT.ERROR,
5353
Message=f"Only {len(entries)} entries found in quota file for {username}",
5454
)
55-
notices: list[tuple[REPORT, str]] = []
55+
notices: list[tuple[int, str]] = []
5656

5757
for entry in entries:
5858
tokens = entry.split()

src/dlstbx/services/images.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
import re
55
import time
6-
from typing import Any, Callable, Dict, NamedTuple
6+
from typing import Any, Callable, Dict, NamedTuple, Protocol
77

88
import PIL.Image
99
import pkg_resources
@@ -14,9 +14,14 @@
1414
logger = logging.getLogger("dlstbx.services.images")
1515

1616

17+
class _CallableParameter(Protocol):
18+
def __call__(self, key: str, default: Any = ...) -> Any:
19+
...
20+
21+
1722
class PluginInterface(NamedTuple):
1823
rw: workflows.recipe.wrapper.RecipeWrapper
19-
parameters: Callable[[str], Any]
24+
parameters: _CallableParameter
2025
message: Dict[str, Any]
2126

2227

@@ -61,7 +66,7 @@ def image_call(self, rw, header, message):
6166
"""Pass incoming message to the relevant plugin function."""
6267
command = rw.recipe_step.get("parameters", {}).get("image_command")
6368

64-
def parameters(key, default=None):
69+
def parameters(key: str, default=None):
6570
if isinstance(message, dict) and message.get(key):
6671
return message[key]
6772
return rw.recipe_step.get("parameters", {}).get(key, default)

src/dlstbx/services/trigger.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from __future__ import annotations
2+
13
import hashlib
24
import logging
35
import os
46
import pathlib
57
from dataclasses import dataclass
68
from datetime import datetime
7-
from typing import Dict, List, Literal, Mapping, Optional
9+
from typing import Any, Dict, List, Literal, Mapping, Optional
810

911
import ispyb
1012
import pydantic
@@ -351,28 +353,28 @@ def trigger_dimple(
351353

352354
dcid = parameters.dcid
353355

354-
pdb_files = self.get_linked_pdb_files_for_dcid(
356+
pdb_files_or_codes = self.get_linked_pdb_files_for_dcid(
355357
session,
356358
dcid,
357359
parameters.pdb_tmpdir,
358360
user_pdb_dir=parameters.user_pdb_directory,
359361
)
360362

361-
if not pdb_files:
363+
if not pdb_files_or_codes:
362364
self.log.info(
363365
"Skipping dimple trigger: DCID %s has no associated PDB information"
364366
% dcid
365367
)
366368
return {"success": True}
367-
pdb_files = [str(p) for p in pdb_files]
369+
pdb_files = [str(p) for p in pdb_files_or_codes]
368370
self.log.info("PDB files: %s", ", ".join(pdb_files))
369371

370372
dc = (
371373
session.query(DataCollection)
372374
.filter(DataCollection.dataCollectionId == dcid)
373375
.one()
374376
)
375-
dimple_parameters = {
377+
dimple_parameters: dict[str, list[Any]] = {
376378
"data": [os.fspath(parameters.mtz)],
377379
"scaling_id": [parameters.scaling_id],
378380
"pdb": pdb_files,
@@ -766,17 +768,19 @@ def trigger_mrbump(
766768
self.log.info("Skipping mrbump trigger: sequence information not available")
767769
return {"success": True}
768770

769-
pdb_files = self.get_linked_pdb_files_for_dcid(
770-
session,
771-
dcid,
772-
parameters.pdb_tmpdir,
773-
user_pdb_dir=parameters.user_pdb_directory,
774-
ignore_pdb_codes=True,
771+
pdb_files = tuple(
772+
self.get_linked_pdb_files_for_dcid(
773+
session,
774+
dcid,
775+
parameters.pdb_tmpdir,
776+
user_pdb_dir=parameters.user_pdb_directory,
777+
ignore_pdb_codes=True,
778+
)
775779
)
776780

777781
jobids = []
778782

779-
for pdb_files in {(), tuple(pdb_files)}:
783+
for pdb_files in {(), pdb_files}:
780784
jp = self.ispyb.mx_processing.get_job_params()
781785
jp["automatic"] = parameters.automatic
782786
jp["comments"] = parameters.comment
@@ -806,6 +810,7 @@ def trigger_mrbump(
806810
self.log.debug(f"mrbump trigger: generated JobParameterID {jppid}")
807811

808812
for pdb_file in pdb_files:
813+
assert pdb_file.filepath is not None
809814
filepath = pdb_file.filepath
810815
if pdb_file.source == "AlphaFold":
811816
trimmed = filepath.with_name(
@@ -1130,7 +1135,7 @@ def trigger_big_ep(
11301135
class BigEPParams(pydantic.BaseModel):
11311136
data: pathlib.Path
11321137
scaled_unmerged_mtz: pathlib.Path
1133-
path_ext: Optional[str] = pydantic.Field(
1138+
path_ext: str = pydantic.Field(
11341139
default_factory=lambda: datetime.now().strftime("%Y%m%d_%H%M%S")
11351140
)
11361141

@@ -1294,7 +1299,7 @@ def trigger_multiplex(
12941299
status["ntry"] += 1
12951300
self.log.debug(f"dcid={dcid}\nmessage_delay={message_delay}\n{status}")
12961301

1297-
multiplex_job_dcids = []
1302+
multiplex_job_dcids: list[set[int]] = []
12981303
jobids = []
12991304

13001305
for group in related_dcids:

src/dlstbx/test/services/test_trigger.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -765,7 +765,8 @@ def test_alphafold_not_triggered_for_industry_proposal(
765765
with caplog.at_level(logging.DEBUG):
766766
trigger.trigger(rw, {"some": "header"}, message)
767767
assert "Not triggering AlphaFold for protein_id" in caplog.text
768-
send.assert_not_called()
768+
send.assert_called_once_with({"result": None}, transaction=mocker.ANY)
769+
t.send.assert_not_called()
769770

770771

771772
def test_invalid_params(db_session_factory, caplog):

src/dlstbx/util/certificate.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,6 @@ def problems_with_certificate(hostname: str) -> Optional[str]:
1515
with closing(socket.socket()) as sock:
1616
sock.connect((hostname, 443))
1717
ctx = SSL.Context(SSL.SSLv23_METHOD) # most compatible
18-
ctx.check_hostname = False
19-
ctx.verify_mode = SSL.VERIFY_NONE
20-
2118
with closing(SSL.Connection(ctx, sock)) as sock_ssl:
2219
sock_ssl.set_connect_state()
2320
sock_ssl.set_tlsext_host_name(idna.encode(hostname))

src/dlstbx/util/colorstreamhandler.py

Lines changed: 29 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -50,28 +50,23 @@
5050
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
5151
# THE SOFTWARE.
5252

53+
from __future__ import annotations
5354

5455
import logging
5556
import platform
57+
from typing import Union
5658

5759

58-
class _AnsiColorStreamHandler(logging.StreamHandler):
59-
DEFAULT = "\x1b[0m"
60-
RED = "\x1b[31m"
61-
GREEN = "\x1b[32m"
62-
YELLOW = "\x1b[33m"
63-
BLUE = "\x1b[34m"
64-
CYAN = "\x1b[36m"
65-
BOLD = "\x1b[1m"
66-
67-
CRITICAL = RED + BOLD
68-
ERROR = RED
69-
WARNING = YELLOW
70-
INFO = GREEN
71-
DEBUG = BLUE
60+
class _GenericColorStreamHandler(logging.StreamHandler):
61+
DEFAULT: Union[int, str] = ""
62+
CRITICAL: Union[int, str] = ""
63+
ERROR: Union[int, str] = ""
64+
WARNING: Union[int, str] = ""
65+
INFO: Union[int, str] = ""
66+
DEBUG: Union[int, str] = ""
7267

7368
@classmethod
74-
def _get_color(cls, level):
69+
def _get_color(cls, level: int) -> Union[int, str]:
7570
if level >= logging.CRITICAL:
7671
return cls.CRITICAL
7772
elif level >= logging.ERROR:
@@ -85,16 +80,29 @@ def _get_color(cls, level):
8580
else:
8681
return cls.DEFAULT
8782

88-
def __init__(self, stream=None):
89-
logging.StreamHandler.__init__(self, stream)
83+
84+
class _AnsiColorStreamHandler(_GenericColorStreamHandler):
85+
DEFAULT = "\x1b[0m"
86+
RED = "\x1b[31m"
87+
GREEN = "\x1b[32m"
88+
YELLOW = "\x1b[33m"
89+
BLUE = "\x1b[34m"
90+
CYAN = "\x1b[36m"
91+
BOLD = "\x1b[1m"
92+
93+
CRITICAL = RED + BOLD
94+
ERROR = RED
95+
WARNING = YELLOW
96+
INFO = GREEN
97+
DEBUG = BLUE
9098

9199
def format(self, record):
92100
text = logging.StreamHandler.format(self, record)
93101
color = self._get_color(record.levelno)
94102
return color + text + self.DEFAULT
95103

96104

97-
class _WinColorStreamHandler(logging.StreamHandler):
105+
class _WinColorStreamHandler(_GenericColorStreamHandler):
98106
# wincon.h
99107
FOREGROUND_BLACK = 0x0000
100108
FOREGROUND_BLUE = 0x0001
@@ -126,28 +134,14 @@ class _WinColorStreamHandler(logging.StreamHandler):
126134
INFO = FOREGROUND_GREEN
127135
DEBUG = FOREGROUND_CYAN
128136

129-
@classmethod
130-
def _get_color(cls, level):
131-
if level >= logging.CRITICAL:
132-
return cls.CRITICAL
133-
elif level >= logging.ERROR:
134-
return cls.ERROR
135-
elif level >= logging.WARNING:
136-
return cls.WARNING
137-
elif level >= logging.INFO:
138-
return cls.INFO
139-
elif level >= logging.DEBUG:
140-
return cls.DEBUG
141-
else:
142-
return cls.DEFAULT
143-
144137
def _set_color(self, code):
145138
import ctypes
146139

147140
ctypes.windll.kernel32.SetConsoleTextAttribute(self._outhdl, code)
148141

149142
def __init__(self, stream=None):
150-
logging.StreamHandler.__init__(self, stream)
143+
super().__init__(stream)
144+
151145
# get file handle for the stream
152146
import ctypes
153147
import ctypes.util
@@ -167,6 +161,7 @@ def emit(self, record):
167161

168162

169163
# select ColorStreamHandler based on platform
164+
ColorStreamHandler: type[_GenericColorStreamHandler]
170165
if platform.system() == "Windows":
171166
ColorStreamHandler = _WinColorStreamHandler
172167
else:

0 commit comments

Comments
 (0)