diff --git a/dargs/cli.py b/dargs/cli.py index b0df919..d2d6f08 100644 --- a/dargs/cli.py +++ b/dargs/cli.py @@ -3,7 +3,7 @@ import argparse import json import sys -from typing import IO +from typing import IO, Any from dargs._version import __version__ from dargs.check import check @@ -79,7 +79,7 @@ def main_parser() -> argparse.ArgumentParser: return parser -def main(): +def main() -> None: """Main entry point for the command line interface.""" parser = main_parser() args = parser.parse_args() @@ -92,7 +92,7 @@ def check_cli( func: str, jdata: list[IO], strict: bool, - **kwargs, + **kwargs: Any, ) -> None: """Normalize and check input data. @@ -131,7 +131,7 @@ def doc_cli( *, func: str, arg: str | None = None, - **kwargs, + **kwargs: Any, ) -> None: """Print documentation for an Argument. diff --git a/dargs/dargs.py b/dargs/dargs.py index 3d7734b..8a3d199 100644 --- a/dargs/dargs.py +++ b/dargs/dargs.py @@ -43,7 +43,7 @@ HookVrntType = Callable[["Variant", dict, List[str]], None] -def _DUMMYHOOK(a, x, p): +def _DUMMYHOOK(a: Argument | Variant, x: dict | Any, p: list[str]) -> None: # for doing nothing in traversing pass @@ -56,7 +56,9 @@ class _Flags(Enum): class ArgumentError(Exception): """Base error class for invalid argument values in argchecking.""" - def __init__(self, path: None | str | list[str] = None, message: str | None = None): + def __init__( + self, path: None | str | list[str] = None, message: str | None = None + ) -> None: super().__init__(message) if path is None: path = "" @@ -150,7 +152,7 @@ def __init__( doc: str = "", fold_subdoc: bool = False, extra_check_errmsg: str = "", - ): + ) -> None: self.name = name self.sub_fields: dict[str, Argument] = {} self.sub_variants: dict[str, Variant] = {} @@ -174,10 +176,10 @@ def __eq__(self, other: object) -> bool: # do not compare doc and default # since they do not enter to the type checking - def fkey(f): + def fkey(f: Argument) -> str: return f.name - def vkey(v): + def vkey(v: Variant) -> str: return v.flag_name return ( @@ -213,7 +215,7 @@ def __getitem__(self, key: str) -> Argument: return self[skey][rkey] @property - def I(self): # noqa:E743 + def I(self) -> Argument: # noqa:E743 # return a dummy argument that only has self as a sub field # can be used in indexing return Argument("_", dict, [self]) @@ -255,16 +257,16 @@ def _reorg_dtype( # and make it compatible with `isinstance` return tuple(dtype) - def set_dtype(self, dtype: None | type | Iterable[type]): + def set_dtype(self, dtype: None | type | Iterable[type]) -> None: """Change the dtype of the current Argument.""" self.dtype = self._reorg_dtype(dtype) - def set_repeat(self, repeat: bool = True): + def set_repeat(self, repeat: bool = True) -> None: """Change the repeat attribute of the current Argument.""" self.repeat = repeat self.dtype = self._reorg_dtype(self.dtype) - def extend_subfields(self, sub_fields: Iterable[Argument] | None): + def extend_subfields(self, sub_fields: Iterable[Argument] | None) -> None: """Add a list of sub fields to the current Argument.""" if sub_fields is None: return @@ -276,7 +278,7 @@ def extend_subfields(self, sub_fields: Iterable[Argument] | None): ) self.dtype = self._reorg_dtype(self.dtype) - def add_subfield(self, name: str | Argument, *args, **kwargs) -> Argument: + def add_subfield(self, name: str | Argument, *args: Any, **kwargs: Any) -> Argument: """Add a sub field to the current Argument.""" if isinstance(name, Argument): newarg = name @@ -285,7 +287,7 @@ def add_subfield(self, name: str | Argument, *args, **kwargs) -> Argument: self.extend_subfields([newarg]) return newarg - def extend_subvariants(self, sub_variants: Iterable[Variant] | None): + def extend_subvariants(self, sub_variants: Iterable[Variant] | None) -> None: """Add a list of sub variants to the current Argument.""" if sub_variants is None: return @@ -298,7 +300,9 @@ def extend_subvariants(self, sub_variants: Iterable[Variant] | None): ) self.dtype = self._reorg_dtype(self.dtype) - def add_subvariant(self, flag_name: str | Variant, *args, **kwargs) -> Variant: + def add_subvariant( + self, flag_name: str | Variant, *args: Any, **kwargs: Any + ) -> Variant: """Add a sub variant to the current Argument.""" if isinstance(flag_name, Variant): newvrnt = flag_name @@ -310,7 +314,9 @@ def add_subvariant(self, flag_name: str | Variant, *args, **kwargs) -> Variant: # above are creation part # below are general traverse part - def flatten_sub(self, value: dict, path=None) -> dict[str, Argument]: + def flatten_sub( + self, value: dict, path: list[str] | None = None + ) -> dict[str, Argument]: sub_dicts = [self.sub_fields] sub_dicts.extend( vrnt.flatten_sub(value, path) for vrnt in self.sub_variants.values() @@ -329,7 +335,7 @@ def traverse( sub_hook: HookArgKType = _DUMMYHOOK, variant_hook: HookVrntType = _DUMMYHOOK, path: list[str] | None = None, - ): + ) -> None: # first, do something with the key # then, take out the vaule and do something with it if path is None: @@ -352,7 +358,7 @@ def traverse_value( sub_hook: HookArgKType = _DUMMYHOOK, variant_hook: HookVrntType = _DUMMYHOOK, path: list[str] | None = None, - ): + ) -> None: # this is not private, and can be called directly # in the condition where there is no leading key if path is None: @@ -390,7 +396,7 @@ def _traverse_sub( sub_hook: HookArgKType = _DUMMYHOOK, variant_hook: HookVrntType = _DUMMYHOOK, path: list[str] | None = None, - ): + ) -> None: if path is None: path = [self.name] if not isinstance(value, dict): @@ -408,7 +414,7 @@ def _traverse_sub( # above are general traverse part # below are type checking part - def check(self, argdict: dict, strict: bool = False): + def check(self, argdict: dict, strict: bool = False) -> None: """Check whether `argdict` meets the structure defined in self. Will recursively check nested dicts according to @@ -435,7 +441,7 @@ def check(self, argdict: dict, strict: bool = False): sub_hook=Argument._check_strict if strict else _DUMMYHOOK, ) - def check_value(self, value: Any, strict: bool = False): + def check_value(self, value: Any, strict: bool = False) -> None: """Check the value without the leading key. Same as `check({self.name: value})`. @@ -455,7 +461,7 @@ def check_value(self, value: Any, strict: bool = False): sub_hook=Argument._check_strict if strict else _DUMMYHOOK, ) - def _check_exist(self, argdict: dict, path=None): + def _check_exist(self, argdict: dict, path: list[str] | None = None) -> None: if self.optional is True: return if self.name not in argdict: @@ -463,7 +469,7 @@ def _check_exist(self, argdict: dict, path=None): path, f"key `{self.name}` is required in arguments but not found" ) - def _check_data(self, value: Any, path=None): + def _check_data(self, value: Any, path: list[str] | None = None) -> None: try: typeguard.check_type( value, @@ -484,7 +490,7 @@ def _check_data(self, value: Any, path=None): "that fails to pass its extra checking. " + self.extra_check_errmsg, ) - def _check_strict(self, value: dict, path=None): + def _check_strict(self, value: dict, path: list[str] | None = None) -> None: allowed_keys = set(self.flatten_sub(value, path).keys()) # curpath = [*path, self.name] if not len(allowed_keys): @@ -512,7 +518,7 @@ def normalize( do_default: bool = True, do_alias: bool = True, trim_pattern: str | None = None, - ): + ) -> dict: """Modify `argdict` so that it meets the Argument structure. Normalization can add default values to optional args, @@ -565,7 +571,7 @@ def normalize_value( do_default: bool = True, do_alias: bool = True, trim_pattern: str | None = None, - ): + ) -> Any: """Modify the value so that it meets the Argument structure. Same as `normalize({self.name: value})[self.name]`. @@ -608,7 +614,7 @@ def normalize_value( ) return value - def _assign_default(self, argdict: dict, path=None): + def _assign_default(self, argdict: dict, path: list[str] | None = None) -> None: if ( self.name not in argdict and self.optional @@ -617,11 +623,11 @@ def _assign_default(self, argdict: dict, path=None): default = self.default if self.default != {} else _Flags.EMPTY_DICT argdict[self.name] = default - def _handle_empty_dict(self, argdict: dict, path=None): + def _handle_empty_dict(self, argdict: dict, path: list[str] | None = None) -> None: if argdict.get(self.name, None) is _Flags.EMPTY_DICT: argdict[self.name] = {} - def _convert_alias(self, argdict: dict, path=None): + def _convert_alias(self, argdict: dict, path: list[str] | None = None) -> None: if self.name not in argdict: for alias in self.alias: if alias in argdict: @@ -631,7 +637,7 @@ def _convert_alias(self, argdict: dict, path=None): # above are normalizing part # below are doc generation part - def gen_doc(self, path: list[str] | None = None, **kwargs) -> str: + def gen_doc(self, path: list[str] | None = None, **kwargs: Any) -> str: """Generate doc string for the current Argument.""" # the actual indentation is done here, and ONLY here if path is None: @@ -644,7 +650,7 @@ def gen_doc(self, path: list[str] | None = None, **kwargs) -> str: ] return "\n".join(filter(None, doc_list)) - def gen_doc_head(self, path: list[str] | None = None, **kwargs) -> str: + def gen_doc_head(self, path: list[str] | None = None, **kwargs: Any) -> str: typesig = "| type: " + " | ".join( [f"``{self._get_type_name(dt)}``" for dt in self.dtype] ) @@ -663,16 +669,17 @@ def gen_doc_head(self, path: list[str] | None = None, **kwargs) -> str: head = f".. dargs:argument:: {self.name}:\n :path: {'/'.join(path)}\n" head += f"\n{indent(typesig, INDENT)}" if kwargs.get("make_anchor"): - head = f"{make_rst_refid(path)}\n" + head + anchor_path = path if path is not None else [self.name] + head = f"{make_rst_refid(anchor_path)}\n" + head return head - def gen_doc_path(self, path: list[str] | None = None, **kwargs) -> str: + def gen_doc_path(self, path: list[str] | None = None, **kwargs: Any) -> str: if path is None: path = [self.name] pathdoc = f"| argument path: ``{'/'.join(path)}``\n" return pathdoc - def gen_doc_body(self, path: list[str] | None = None, **kwargs) -> str: + def gen_doc_body(self, path: list[str] | None = None, **kwargs: Any) -> str: body_list = [] if self.doc: body_list.append(self.doc + "\n") @@ -707,8 +714,10 @@ def gen_doc_body(self, path: list[str] | None = None, **kwargs) -> str: body = "\n".join(body_list) return body - def _get_type_name(self, dd) -> str: + def _get_type_name(self, dd: type | Any | None) -> str: """Get type name for doc/message generation.""" + if dd is None: + return "None" return str(dd) if isinstance(get_origin(dd), type) else dd.__name__ @@ -747,7 +756,7 @@ def __init__( optional: bool = False, default_tag: str = "", # this is indeed necessary in case of optional doc: str = "", - ): + ) -> None: self.flag_name = flag_name self.choice_dict: dict[str, Argument] = {} self.choice_alias: dict[str, str] = {} @@ -777,7 +786,7 @@ def __repr__(self) -> str: def __getitem__(self, key: str) -> Argument: return self.choice_dict[key] - def set_default(self, default_tag: bool | str): + def set_default(self, default_tag: bool | str) -> None: """Change the default tag of the current Variant.""" if not default_tag: self.optional = False @@ -788,7 +797,7 @@ def set_default(self, default_tag: bool | str): self.optional = True self.default_tag = default_tag - def extend_choices(self, choices: Iterable[Argument] | None): + def extend_choices(self, choices: Iterable[Argument] | None) -> None: """Add a list of choice Arguments to the current Variant.""" # choices is a list of arguments # whose name is treated as the switch tag @@ -813,8 +822,8 @@ def add_choice( self, tag: str | Argument, _dtype: None | type | Iterable[type] = dict, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ) -> Argument: """Add a choice Argument to the current Variant.""" if isinstance(tag, Argument): @@ -824,7 +833,7 @@ def add_choice( self.extend_choices([newarg]) return newarg - def dummy_argument(self): + def dummy_argument(self) -> Argument: return Argument( name=self.flag_name, dtype=str, @@ -841,7 +850,7 @@ def dummy_argument(self): # above are creation part # below are helpers for traversing - def get_choice(self, argdict: dict, path=None) -> Argument: + def get_choice(self, argdict: dict, path: list[str] | None = None) -> Argument: if self.flag_name in argdict: tag = argdict[self.flag_name] if tag in self.choice_dict: @@ -865,7 +874,9 @@ def get_choice(self, argdict: dict, path=None) -> Argument: f"key `{self.flag_name}` is required to choose variant but not found.", ) - def flatten_sub(self, argdict: dict, path=None) -> dict[str, Argument]: + def flatten_sub( + self, argdict: dict, path: list[str] | None = None + ) -> dict[str, Argument]: choice = self.get_choice(argdict, path) fields = { self.flag_name: self.dummy_argument(), # as a placeholder @@ -873,7 +884,9 @@ def flatten_sub(self, argdict: dict, path=None) -> dict[str, Argument]: } return fields - def _convert_choice_alias(self, argdict: dict, path=None): + def _convert_choice_alias( + self, argdict: dict, path: list[str] | None = None + ) -> None: if self.flag_name in argdict: tag = argdict[self.flag_name] if tag not in self.choice_dict and tag in self.choice_alias: @@ -883,7 +896,7 @@ def _convert_choice_alias(self, argdict: dict, path=None): # below are doc generation part def gen_doc( - self, path: list[str] | None = None, showflag: bool = False, **kwargs + self, path: list[str] | None = None, showflag: bool = False, **kwargs: Any ) -> str: body_list = [""] body_list.append( @@ -919,7 +932,7 @@ def gen_doc( body = "\n".join(body_list) return body - def gen_doc_flag(self, path: list[str] | None = None, **kwargs) -> str: + def gen_doc_flag(self, path: list[str] | None = None, **kwargs: Any) -> str: headdoc = f"{self.flag_name}:" typedoc = "| type: ``str`` (flag key)" if self.optional: @@ -978,14 +991,14 @@ def gen_doc_flag(self, path: list[str] | None = None, **kwargs) -> str: def _make_cpath( self, cname: str, path: list[str] | None = None, showflag: bool = False - ): + ) -> list[str]: f_str = f"{self.flag_name}=" if showflag else "" c_str = f"[{f_str}{cname}]" cpath = [*path[:-1], path[-1] + c_str] if path else [c_str] return cpath -def make_rst_refid(name): +def make_rst_refid(name: str | list[str]) -> str: if not isinstance(name, str): name = "/".join(name) return ( @@ -995,7 +1008,11 @@ def make_rst_refid(name): ) -def make_ref_pair(path, text=None, prefix=None): +def make_ref_pair( + path: str | list[str], + text: str | None = None, + prefix: str | None = None, +) -> tuple[str, str]: if not isinstance(path, str): path = "/".join(path) tgt = f"`{path}`_" if not RAW_ANCHOR else f"#{path}" @@ -1012,7 +1029,7 @@ def update_nodup( *others: dict | Iterable[tuple], exclude: Iterable | None = None, err_msg: str | None = None, -): +) -> dict: for pair in others: if isinstance(pair, dict): pair = pair.items() @@ -1031,7 +1048,7 @@ def trim_by_pattern( pattern: str, reserved: Iterable[str] | None = None, use_regex: bool = False, -): +) -> None: rep = fnmatch.translate(pattern) if not use_regex else pattern rem = re.compile(rep) if reserved: @@ -1046,7 +1063,7 @@ def trim_by_pattern( argdict.pop(key) -def isinstance_annotation(value, dtype) -> bool: +def isinstance_annotation(value: Any, dtype: type | Any) -> bool: """Same as isinstance(), but supports arbitrary type annotations.""" try: typeguard.check_type( @@ -1067,7 +1084,7 @@ class ArgumentEncoder(json.JSONEncoder): >>> json.dumps(some_arg, cls=ArgumentEncoder) """ - def default(self, o) -> Any: + def default(self, o: Any) -> Any: """Generate a dict containing argument information, making it ready to be encoded to JSON string. diff --git a/dargs/notebook.py b/dargs/notebook.py index 9d1849f..96404b6 100644 --- a/dargs/notebook.py +++ b/dargs/notebook.py @@ -90,7 +90,7 @@ """ -def JSON(data: dict | str, arg: Argument | list[Argument]): +def JSON(data: dict | str, arg: Argument | list[Argument]) -> None: """Display JSON data with Argument in the Jupyter Notebook. Parameters @@ -151,14 +151,16 @@ class ArgumentData: The argument is repeat """ - def __init__(self, data: dict, arg: Argument | Variant, repeat: bool = False): + def __init__( + self, data: dict, arg: Argument | Variant, repeat: bool = False + ) -> None: self.data = data self.arg = arg self.repeat = repeat self.subdata = [] self._init_subdata() - def _init_subdata(self): + def _init_subdata(self) -> None: """Initialize sub ArgumentData.""" if ( isinstance(self.data, dict) @@ -198,7 +200,7 @@ def _init_subdata(self): for dd in self.data.values(): self.subdata.append(ArgumentData(dd, self.arg, repeat=True)) - def print_html(self, _level=0, _last_one=True): + def print_html(self, _level: int = 0, _last_one: bool = True) -> str: """Print the data with Argument in HTML format. Parameters diff --git a/dargs/sphinx.py b/dargs/sphinx.py index 8eb0b9e..9832fff 100644 --- a/dargs/sphinx.py +++ b/dargs/sphinx.py @@ -48,7 +48,7 @@ class DargsDirective(Directive): "func": unchanged, } - def run(self): + def run(self) -> list: if "module" in self.options and "func" in self.options: module_name = self.options["module"] attr_name = self.options["func"] @@ -95,11 +95,11 @@ class DargsObject(ObjectDescription): "path": unchanged, } - def handle_signature(self, sig, signode): + def handle_signature(self, sig: str, signode: Any) -> str: signode += addnodes.desc_name(sig, sig) return sig - def add_target_and_index(self, name, sig, signode): + def add_target_and_index(self, name: str, sig: str, signode: Any) -> None: path = self.options["path"] targetid = f"{self.objtype}:{path}" if targetid not in self.state.document.ids: @@ -151,7 +151,16 @@ class DargsDomain(Domain): "arguments": {}, # fullname -> docname, objtype } - def resolve_xref(self, env, fromdocname, builder, typ, target, node, contnode): + def resolve_xref( + self, + env: Any, + fromdocname: str, + builder: Any, + typ: str, + target: str, + node: Any, + contnode: Any, + ) -> Any: """Resolve cross-references.""" targetid = f"{typ}:{target}" obj = self.data["arguments"].get(targetid) @@ -160,7 +169,7 @@ def resolve_xref(self, env, fromdocname, builder, typ, target, node, contnode): return make_refnode(builder, fromdocname, obj[0], targetid, contnode, target) -def setup(app): +def setup(app: Any) -> dict[str, bool]: """Setup sphinx app.""" app.add_directive("dargs", DargsDirective) app.add_domain(DargsDomain) diff --git a/docs/conf.py b/docs/conf.py index a274020..181b36b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -16,6 +16,7 @@ import os import sys from datetime import date +from typing import Any sys.path.insert(0, os.path.abspath("..")) @@ -166,7 +167,7 @@ # -- Extension configuration ------------------------------------------------- -def run_apidoc(_): +def run_apidoc(_: object) -> None: from sphinx.ext.apidoc import main sys.path.append(os.path.join(os.path.dirname(__file__), "..")) @@ -187,7 +188,7 @@ def run_apidoc(_): ) -def setup(app): +def setup(app: Any) -> None: app.connect("builder-inited", run_apidoc) diff --git a/pyproject.toml b/pyproject.toml index 1370c44..f0ae4ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ select = [ "RUF", # ruff "I", # isort "TCH", # flake8-type-checking + "ANN", # flake8-annotations "B904", # raise-without-from-inside-except ] @@ -74,9 +75,13 @@ ignore = [ "D205", # 1 blank line required between summary line and description "D401", # TODO: first line should be in imperative mood "D404", # TODO: first word of the docstring should not be This + "ANN401", # Allow typing.Any - necessary for a library handling arbitrary types ] ignore-init-module-imports = true +[tool.ruff.lint.flake8-annotations] +allow-star-arg-any = true + [tool.ruff.lint.pydocstyle] convention = "numpy" diff --git a/tests/dpmdargs.py b/tests/dpmdargs.py index 72b2519..f40d2fc 100644 --- a/tests/dpmdargs.py +++ b/tests/dpmdargs.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Any + from dargs import Argument, Variant, dargs ACTIVATION_FN_DICT = { @@ -19,7 +21,7 @@ } -def list_to_doc(xx): +def list_to_doc(xx: list) -> str: items = [] for ii in xx: if len(items) == 0: @@ -30,7 +32,7 @@ def list_to_doc(xx): return "".join(items) -def make_link(content, ref_key): +def make_link(content: str, ref_key: str) -> str: return ( f"`{content} <{ref_key}_>`_" if not dargs.RAW_ANCHOR @@ -38,7 +40,7 @@ def make_link(content, ref_key): ) -def descrpt_local_frame_args(): +def descrpt_local_frame_args() -> list[Argument]: doc_sel_a = "A list of integers. The length of the list should be the same as the number of atom types in the system. `sel_a[i]` gives the selected number of type-i neighbors. The full relative coordinates of the neighbors are used by the descriptor." doc_sel_r = "A list of integers. The length of the list should be the same as the number of atom types in the system. `sel_r[i]` gives the selected number of type-i neighbors. Only relative distance of the neighbors are used by the descriptor. sel_a[i] + sel_r[i] is recommended to be larger than the maximally possible number of type-i neighbors in the cut-off radius." doc_rcut = "The cut-off radius. The default value is 6.0" @@ -58,7 +60,7 @@ def descrpt_local_frame_args(): ] -def descrpt_se_a_args(): +def descrpt_se_a_args() -> list[Argument]: doc_sel = "A list of integers. The length of the list should be the same as the number of atom types in the system. `sel[i]` gives the selected number of type-i neighbors. `sel[i]` is recommended to be larger than the maximally possible number of type-i neighbors in the cut-off radius." doc_rcut = "The cut-off radius." doc_rcut_smth = "Where to start smoothing. For example the 1/r term is smoothed from `rcut` to `rcut_smth`" @@ -102,7 +104,7 @@ def descrpt_se_a_args(): ] -def descrpt_se_a_3be_args(): +def descrpt_se_a_3be_args() -> list[Argument]: doc_sel = "A list of integers. The length of the list should be the same as the number of atom types in the system. `sel[i]` gives the selected number of type-i neighbors. `sel[i]` is recommended to be larger than the maximally possible number of type-i neighbors in the cut-off radius." doc_rcut = "The cut-off radius." doc_rcut_smth = "Where to start smoothing. For example the 1/r term is smoothed from `rcut` to `rcut_smth`" @@ -136,7 +138,7 @@ def descrpt_se_a_3be_args(): ] -def descrpt_se_a_tpe_args(): +def descrpt_se_a_tpe_args() -> list[Argument]: doc_type_nchanl = "number of channels for type embedding" doc_type_nlayer = "number of hidden layers of type embedding net" doc_numb_aparam = "dimension of atomic parameter. if set to a value > 0, the atomic parameters are embedded." @@ -149,7 +151,7 @@ def descrpt_se_a_tpe_args(): ] -def descrpt_se_r_args(): +def descrpt_se_r_args() -> list[Argument]: doc_sel = "A list of integers. The length of the list should be the same as the number of atom types in the system. `sel[i]` gives the selected number of type-i neighbors. `sel[i]` is recommended to be larger than the maximally possible number of type-i neighbors in the cut-off radius." doc_rcut = "The cut-off radius." doc_rcut_smth = "Where to start smoothing. For example the 1/r term is smoothed from `rcut` to `rcut_smth`" @@ -191,7 +193,7 @@ def descrpt_se_r_args(): ] -def descrpt_se_ar_args(): +def descrpt_se_ar_args() -> list[Argument]: link = make_link("se_a", "model/descriptor[se_a]") doc_a = f"The parameters of descriptor {link}" link = make_link("se_r", "model/descriptor[se_r]") @@ -203,7 +205,7 @@ def descrpt_se_ar_args(): ] -def descrpt_hybrid_args(): +def descrpt_hybrid_args() -> list[Argument]: doc_list = "A list of descriptor definitions" return [ @@ -238,7 +240,7 @@ def descrpt_hybrid_args(): ] -def descrpt_variant_type_args(): +def descrpt_variant_type_args() -> Variant: link_lf = make_link("loc_frame", "model/descriptor[loc_frame]") link_se_a = make_link("se_a", "model/descriptor[se_a]") link_se_r = make_link("se_r", "model/descriptor[se_r]") @@ -267,7 +269,7 @@ def descrpt_variant_type_args(): ) -def fitting_ener(): +def fitting_ener() -> list[Argument]: doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams." doc_numb_aparam = "The dimension of the atomic parameter. If set to >0, file `aparam.npy` should be included to provided the input aparams." doc_neuron = "The number of neurons in each hidden layers of the fitting net. When two hidden layers are of the same size, a skip connection is built." @@ -305,7 +307,7 @@ def fitting_ener(): ] -def fitting_polar(): +def fitting_polar() -> list[Argument]: doc_neuron = "The number of neurons in each hidden layers of the fitting net. When two hidden layers are of the same size, a skip connection is built." doc_activation_function = f"The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())}" doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection' @@ -339,11 +341,11 @@ def fitting_polar(): ] -def fitting_global_polar(): +def fitting_global_polar() -> list[Argument]: return fitting_polar() -def fitting_dipole(): +def fitting_dipole() -> list[Argument]: doc_neuron = "The number of neurons in each hidden layers of the fitting net. When two hidden layers are of the same size, a skip connection is built." doc_activation_function = f"The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())}" doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection' @@ -368,7 +370,7 @@ def fitting_dipole(): ] -def fitting_variant_type_args(): +def fitting_variant_type_args() -> Variant: doc_descrpt_type = "The type of the fitting. See explanation below. \n\n\ - `ener`: Fit an energy model (potential energy surface).\n\n\ - `dipole`: Fit an atomic dipole model. Atomic dipole labels for all the selected atoms (see `sel_type`) should be provided by `dipole.npy` in each data system. The file has number of frames lines and 3 times of number of selected atoms columns.\n\n\ @@ -389,7 +391,7 @@ def fitting_variant_type_args(): ) -def modifier_dipole_charge(): +def modifier_dipole_charge() -> list[Argument]: doc_model_name = "The name of the frozen dipole model file." doc_model_charge_map = f"The charge of the WFCC. The list length should be the same as the {make_link('sel_type', 'model/fitting_net[dipole]/sel_type')}. " doc_sys_charge_map = f"The charge of real atoms. The list length should be the same as the {make_link('type_map', 'model/type_map')}" @@ -405,7 +407,7 @@ def modifier_dipole_charge(): ] -def modifier_variant_type_args(): +def modifier_variant_type_args() -> Variant: doc_modifier_type = "The type of modifier. See explanation below.\n\n\ -`dipole_charge`: Use WFCC to model the electronic structure of the system. Correct the long-range interaction" return Variant( @@ -418,7 +420,7 @@ def modifier_variant_type_args(): ) -def model_args(): +def model_args() -> Argument: doc_type_map = "A list of strings. Give the name to each type of atoms." doc_data_stat_nbatch = "The model determines the normalization from the statistics of the data. This key specifies the number of `frames` in each `system` used for statistics." doc_data_stat_protect = "Protect parameter for atomic energy regression." @@ -473,7 +475,7 @@ def model_args(): return ca -def learning_rate_exp(): +def learning_rate_exp() -> list[Argument]: doc_start_lr = "The learning rate the start of the training." doc_stop_lr = "The desired learning rate at the end of the training." doc_decay_steps = ( @@ -488,7 +490,7 @@ def learning_rate_exp(): return args -def learning_rate_variant_type_args(): +def learning_rate_variant_type_args() -> Variant: doc_lr = "The type of the learning rate." return Variant( @@ -500,22 +502,22 @@ def learning_rate_variant_type_args(): ) -def learning_rate_args(): +def learning_rate_args() -> Argument: doc_lr = "The definitio of learning rate" return Argument( "learning_rate", dict, [], [learning_rate_variant_type_args()], doc=doc_lr ) -def start_pref(item): +def start_pref(item: str) -> str: return f"The prefactor of {item} loss at the start of the training. Should be larger than or equal to 0. If set to none-zero value, the {item} label should be provided by file {item}.npy in each data system. If both start_pref_{item} and limit_pref_{item} are set to 0, then the {item} will be ignored." -def limit_pref(item): +def limit_pref(item: str) -> str: return f"The prefactor of {item} loss at the limit of the training, Should be larger than or equal to 0. i.e. the training step goes to infinity." -def loss_ener(): +def loss_ener() -> list[Argument]: doc_start_pref_e = start_pref("energy") doc_limit_pref_e = limit_pref("energy") doc_start_pref_f = start_pref("force") @@ -586,7 +588,7 @@ def loss_ener(): ] -def loss_variant_type_args(): +def loss_variant_type_args() -> Variant: doc_loss = "The type of the loss. \n" return Variant( @@ -598,7 +600,7 @@ def loss_variant_type_args(): ) -def loss_args(): +def loss_args() -> Argument: doc_loss = "The definition of loss function. The type of the loss depends on the type of the fitting. For fitting type `ener`, the prefactors before energy, force, virial and atomic energy losses may be provided. For fitting type `dipole`, `polar` and `global_polar`, the loss may be an empty `dict` or unset." ca = Argument( "loss", dict, [], [loss_variant_type_args()], optional=True, doc=doc_loss @@ -606,7 +608,7 @@ def loss_args(): return ca -def training_args(): +def training_args() -> Argument: link_sys = make_link("systems", "training/systems") doc_systems = "The data systems. This key can be provided with a listthat specifies the systems, or be provided with a string by which the prefix of all systems are given and the list of the systems is automatically generated." doc_set_prefix = f"The prefix of the sets in the {link_sys}." @@ -712,14 +714,14 @@ def training_args(): return Argument("training", dict, args, [], doc=doc_training) -def make_index(keys): +def make_index(keys: list[str]) -> str: ret = [] for ii in keys: ret.append(make_link(ii, ii)) return ", ".join(ret) -def gen_doc(*, make_anchor=True, make_link=True, **kwargs): +def gen_doc(*, make_anchor: bool = True, make_link: bool = True, **kwargs: Any) -> str: if make_link: make_anchor = True ma = model_args() @@ -741,7 +743,7 @@ def gen_doc(*, make_anchor=True, make_link=True, **kwargs): return "\n\n".join(ptr) -def check(data): +def check(data: dict) -> None: ma = model_args() lra = learning_rate_args() la = loss_args() @@ -751,7 +753,7 @@ def check(data): base.check_value(data) -def normalize(data): +def normalize(data: dict) -> dict: ma = model_args() lra = learning_rate_args() la = loss_args() diff --git a/tests/test_checker.py b/tests/test_checker.py index f3800da..4486c7a 100644 --- a/tests/test_checker.py +++ b/tests/test_checker.py @@ -8,7 +8,7 @@ class TestChecker(unittest.TestCase): - def test_name_type(self): + def test_name_type(self) -> None: # naive ca = Argument("key1", int) ca.check({"key1": 10}) @@ -52,7 +52,7 @@ def test_name_type(self): ca.check({"kwargs": anydict}, strict=True) ca.check_value(anydict) - def test_sub_fields(self): + def test_sub_fields(self) -> None: ca = Argument( "base", dict, @@ -100,7 +100,7 @@ def test_sub_fields(self): with self.assertRaises(ValueError): Argument("base", dict, [Argument("sub1", int), Argument("sub1", int)]) - def test_sub_repeat_list(self): + def test_sub_repeat_list(self) -> None: ca = Argument( "base", list, [Argument("sub1", int), Argument("sub2", str)], repeat=True ) @@ -124,7 +124,7 @@ def test_sub_repeat_list(self): with self.assertRaises(ArgumentTypeError): ca.check(err_dict2) - def test_sub_repeat_dict(self): + def test_sub_repeat_dict(self) -> None: ca = Argument( "base", dict, [Argument("sub1", int), Argument("sub2", str)], repeat=True ) @@ -165,7 +165,7 @@ def test_sub_repeat_dict(self): with self.assertRaises(ArgumentTypeError): ca.check(err_dict3) - def test_sub_variants(self): + def test_sub_variants(self) -> None: ca = Argument( "base", dict, diff --git a/tests/test_cli.py b/tests/test_cli.py index 3cf4318..25ec787 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -9,7 +9,7 @@ class TestCli(unittest.TestCase): - def test_check(self): + def test_check(self) -> None: subprocess.check_call( [ "dargs", @@ -43,7 +43,7 @@ def test_check(self): stdin=f, ) - def test_doc_all_arguments(self): + def test_doc_all_arguments(self) -> None: """Test printing documentation for all arguments.""" result = subprocess.run( [ @@ -64,7 +64,7 @@ def test_doc_all_arguments(self): self.assertIn("Argument 2", result.stdout) self.assertIn("Argument 3", result.stdout) - def test_doc_specific_argument(self): + def test_doc_specific_argument(self) -> None: """Test printing documentation for a specific argument.""" result = subprocess.run( [ @@ -84,7 +84,7 @@ def test_doc_specific_argument(self): self.assertNotIn("Argument 2", result.stdout) self.assertNotIn("Argument 3", result.stdout) - def test_doc_nested_arguments(self): + def test_doc_nested_arguments(self) -> None: """Test printing documentation for nested arguments.""" # Test top-level base argument result = subprocess.run( @@ -139,7 +139,7 @@ def test_doc_nested_arguments(self): # Check that the full path is in the output self.assertIn("base/sub2/subsub1", result.stdout) - def test_doc_invalid_path(self): + def test_doc_invalid_path(self) -> None: """Test error handling for invalid argument path.""" result = subprocess.run( [ @@ -154,7 +154,7 @@ def test_doc_invalid_path(self): self.assertNotEqual(result.returncode, 0) self.assertIn("not found", result.stderr) - def test_doc_invalid_nested_path(self): + def test_doc_invalid_nested_path(self) -> None: """Test error handling for invalid nested argument path.""" result = subprocess.run( [ @@ -169,7 +169,7 @@ def test_doc_invalid_nested_path(self): self.assertNotEqual(result.returncode, 0) self.assertIn("not found", result.stderr) - def test_doc_with_python_module(self): + def test_doc_with_python_module(self) -> None: """Test doc command using python -m.""" result = subprocess.run( [ @@ -187,7 +187,7 @@ def test_doc_with_python_module(self): self.assertIn("test1:", result.stdout) self.assertIn("Argument 1", result.stdout) - def test_doc_invalid_function_format(self): + def test_doc_invalid_function_format(self) -> None: """Test error handling for invalid function format.""" result = subprocess.run( [ diff --git a/tests/test_creation.py b/tests/test_creation.py index c75a8ec..e63b1da 100644 --- a/tests/test_creation.py +++ b/tests/test_creation.py @@ -6,13 +6,13 @@ class TestCreation(unittest.TestCase): - def test_dtype(self): + def test_dtype(self) -> None: ref = Argument("key1", [bool, str, dict]) ca = Argument("key1", int) ca.set_dtype([bool, str, dict]) self.assertTrue(ca == ref) - def test_sub_fields(self): + def test_sub_fields(self) -> None: ref = Argument( "base", dict, @@ -39,7 +39,7 @@ def test_sub_fields(self): ca.set_repeat(True) self.assertTrue(ca == ref) - def test_idx_fields(self): + def test_idx_fields(self) -> None: s1 = Argument("sub1", int) vt1 = Argument( "type1", @@ -73,7 +73,7 @@ def test_idx_fields(self): self.assertTrue(ca.I["base[type1]"] is vt1) self.assertTrue(ca.I["base[type2]//shared"] == Argument("shared", int)) - def test_sub_variants(self): + def test_sub_variants(self) -> None: ref = Argument( "base", dict, @@ -153,7 +153,7 @@ def test_sub_variants(self): v1.set_default(False) self.assertTrue(ca == ref) - def test_idx_variants(self): + def test_idx_variants(self) -> None: vt1 = Argument( "type1", dict, @@ -177,7 +177,7 @@ def test_idx_variants(self): with self.assertRaises(KeyError): vnt["type3"] - def test_complicated(self): + def test_complicated(self) -> None: ref = Argument( "base", dict, diff --git a/tests/test_docgen.py b/tests/test_docgen.py index 0e448e3..e004779 100644 --- a/tests/test_docgen.py +++ b/tests/test_docgen.py @@ -9,7 +9,7 @@ class TestDocgen(unittest.TestCase): - def test_sub_fields(self): + def test_sub_fields(self) -> None: ca = Argument( "base", dict, @@ -41,7 +41,7 @@ def test_sub_fields(self): jsonstr = json.dumps(ca, cls=ArgumentEncoder) # print("\n\n"+docstr) - def test_sub_repeat_list(self): + def test_sub_repeat_list(self) -> None: ca = Argument( "base", list, @@ -70,7 +70,7 @@ def test_sub_repeat_list(self): jsonstr = json.dumps(ca, cls=ArgumentEncoder) # print("\n\n"+docstr) - def test_sub_repeat_dict(self): + def test_sub_repeat_dict(self) -> None: ca = Argument( "base", dict, @@ -98,7 +98,7 @@ def test_sub_repeat_dict(self): docstr = ca.gen_doc() jsonstr = json.dumps(ca, cls=ArgumentEncoder) - def test_sub_variants(self): + def test_sub_variants(self) -> None: ca = Argument( "base", dict, @@ -174,7 +174,7 @@ def test_sub_variants(self): jsonstr = json.dumps(ca, cls=ArgumentEncoder) # print("\n\n"+docstr) - def test_multi_variants(self): + def test_multi_variants(self) -> None: ca = Argument( "base", dict, @@ -247,7 +247,7 @@ def test_multi_variants(self): jsonstr = json.dumps(ca, cls=ArgumentEncoder) # print("\n\n"+docstr) - def test_dpmd(self): + def test_dpmd(self) -> None: from .dpmdargs import gen_doc dargs.RAW_ANCHOR = False diff --git a/tests/test_json_schema.py b/tests/test_json_schema.py index 48dba41..2da52e8 100644 --- a/tests/test_json_schema.py +++ b/tests/test_json_schema.py @@ -11,13 +11,13 @@ class TestJsonSchema(unittest.TestCase): - def test_json_schema(self): + def test_json_schema(self) -> None: args = gen_args() schema = generate_json_schema(args) data = json.loads(example_json_str) validate(data, schema) - def test_convert_types(self): + def test_convert_types(self) -> None: self.assertEqual(_convert_types(int), "number") self.assertEqual(_convert_types(str), "string") self.assertEqual(_convert_types(float), "number") diff --git a/tests/test_normalizer.py b/tests/test_normalizer.py index 8a03571..9f9d995 100644 --- a/tests/test_normalizer.py +++ b/tests/test_normalizer.py @@ -6,7 +6,7 @@ class TestNormalizer(unittest.TestCase): - def test_default(self): + def test_default(self) -> None: # naive ca = Argument("Key1", int, optional=True, default=1) beg = {} @@ -20,8 +20,8 @@ def test_default(self): self.assertDictEqual(end1, ref) self.assertTrue(end1 is beg) - def test_default_dict(self): - def make_arguments(): + def test_default_dict(self) -> None: + def make_arguments() -> list[Argument]: arg_foo = Argument("foo", int, optional=True, default=1) arg_bar = Argument("bar", dict, [arg_foo], optional=True, default={}) return [arg_bar] @@ -30,7 +30,7 @@ def make_arguments(): data = base.normalize_value({}) self.assertDictEqual(data, {"bar": {}}) - def test_alias(self): + def test_alias(self) -> None: ca = Argument("Key1", int, alias=["Old1", "Old2"]) beg = {"Old1": 1} end = ca.normalize(beg) @@ -44,7 +44,7 @@ def test_alias(self): self.assertDictEqual(end1, ref) self.assertTrue(end1 is beg1) - def test_trim(self): + def test_trim(self) -> None: ca = Argument("Key1", int) beg = {"Key1": 1, "_comment": 123} end = ca.normalize(beg, trim_pattern="_*") @@ -60,7 +60,7 @@ def test_trim(self): self.assertDictEqual(end1, ref) self.assertTrue(end1 is beg) - def test_combined(self): + def test_combined(self) -> None: ca = Argument( "base", dict, @@ -80,7 +80,7 @@ def test_combined(self): ca.normalize_value(beg2["base"], trim_pattern="_*"), ref2["base"] ) - def test_complicated(self): + def test_complicated(self) -> None: ca = Argument( "base", dict, @@ -196,7 +196,7 @@ def test_complicated(self): with self.assertRaises(ValueError): ca.normalize(beg2, trim_pattern="vnt*") - def test_dpmd(self): + def test_dpmd(self) -> None: import json from .dpmdargs import example_json_str, normalize diff --git a/tests/test_notebook.py b/tests/test_notebook.py index bd26f50..6184197 100644 --- a/tests/test_notebook.py +++ b/tests/test_notebook.py @@ -15,7 +15,7 @@ @unittest.skipUnless(ipython_installed, "IPython not installed") class TestNotebook(unittest.TestCase): - def test_html_validation(self): + def test_html_validation(self) -> None: from dargs.notebook import print_html doc_test = "Test doc."