-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_init.py
More file actions
162 lines (131 loc) · 5.98 KB
/
generate_init.py
File metadata and controls
162 lines (131 loc) · 5.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import ast
import os
import sys
import textwrap
from collections import defaultdict
# Regenerates pat2vec/__init__.py from the public names defined in the pat2vec package.
# The script writes the file directly. Just run: `python generate_init.py`
def generate_init_file_content(package_path="pat2vec"):
    """Build the source text of the package ``__init__.py``.

    Walks ``package_path`` recursively, parses every ``.py`` file other than
    ``__init__.py`` with ``ast`` and collects each module's public top-level
    names (see ``_exported_names``). The generated text contains a module
    docstring, ``__version__`` (taken from ``get_version_from_pyproject()``),
    one relative ``from .module import (...)`` statement per module, and an
    ``__all__`` list of every collected name.

    Files that fail to parse are reported with ``print`` and skipped. When the
    same name is defined in several modules, only the first module in sorted
    order imports it.

    Args:
        package_path (str): The path to the package root. Defaults to "pat2vec".

    Returns:
        str: The generated file content, lines joined with newlines.
    """
    # This will store all importable names (functions, classes, constants) per module.
    module_to_imports = defaultdict(list)
    all_import_names = set()

    for root, _dirs, files in os.walk(package_path):
        for file_name in files:
            if not file_name.endswith(".py") or file_name == "__init__.py":
                continue

            file_path = os.path.join(root, file_name)
            relative_path = os.path.relpath(file_path, package_path)
            module_path = os.path.splitext(relative_path)[0]
            import_path = "." + module_path.replace(os.path.sep, ".")

            with open(file_path, "r", encoding="utf-8") as f:
                source = f.read()
            try:
                tree = ast.parse(source, filename=file_path)
            except SyntaxError as e:
                print(f"Skipping {file_path} due to syntax error: {e}")
                continue

            names = _exported_names(tree)
            if names:
                module_to_imports[import_path].extend(names)
                all_import_names.update(names)

    version = get_version_from_pyproject()

    # --- Build the file content ---
    output_lines = [
        '"""',
        "pat2vec: A package for processing patient data.",
        "",
        "This file is auto-generated by `generate_init.py`.",
        "",
        "It exposes the main functions and methods of the pat2vec library for easy access.",
        "",
        '"""',
        "",
        f'__version__ = "{version}"',
        "",
    ]

    # break_long_words=False: an identifier (or quoted name) longer than the
    # wrap width must overflow the line rather than be split mid-token, which
    # would produce invalid Python.
    wrapper = textwrap.TextWrapper(
        width=88,
        initial_indent=" ",
        subsequent_indent=" ",
        break_long_words=False,
    )

    imported_names = set()  # Keep track of names we've already imported.
    for module, imports in sorted(module_to_imports.items()):
        unique_imports_for_module = []
        for name in sorted(imports):
            if name not in imported_names:
                unique_imports_for_module.append(name)
                imported_names.add(name)
        if unique_imports_for_module:
            output_lines.append(f"from {module} import (")
            output_lines.append(wrapper.fill(", ".join(unique_imports_for_module)))
            output_lines.append(")")

    output_lines.append("\n")
    output_lines.append("# Define the public API of the package")
    output_lines.append("__all__ = [")
    # With nothing to export, leave the list empty instead of emitting [""],
    # which would make `from package import *` fail on the empty name.
    if all_import_names:
        quoted_names = ", ".join(f'"{name}"' for name in sorted(all_import_names))
        output_lines.append(wrapper.fill(quoted_names))
    output_lines.append("]")
    return "\n".join(output_lines)


def _exported_names(tree):
    """Return the public top-level names defined in a parsed module.

    A name is exported when it does not start with an underscore and is either
    a function (``def`` or ``async def``) or class definition, or the ALL-CAPS
    target of a plain or annotated assignment (heuristically a constant).
    Names are returned in source order and may contain duplicates.
    """
    names = []
    # Iterate over top-level nodes only; nested definitions are not exported.
    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            if not node.name.startswith("_"):
                names.append(node.name)
        elif isinstance(node, ast.Assign):
            names.extend(
                target.id
                for target in node.targets
                if isinstance(target, ast.Name) and _is_public_constant(target.id)
            )
        elif isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
            if _is_public_constant(node.target.id):
                names.append(node.target.id)
    return names


def _is_public_constant(name):
    """Return True for names that look like public constants (ALL-CAPS, no leading underscore)."""
    return not name.startswith("_") and name.isupper()
def format_with_black(content: str) -> str:
    """Format source text with black, falling back to the input on failure.

    The text is piped to ``python -m black -`` using the current interpreter.

    Args:
        content: Python source code to format.

    Returns:
        Black's stdout when it exits with status 0; otherwise ``content``
        unchanged, after a warning containing black's stderr has been printed
        to ``sys.stderr``.
    """
    import subprocess

    black_command = [sys.executable, "-m", "black", "-"]
    completed = subprocess.run(
        black_command,
        input=content,
        capture_output=True,
        text=True,
    )
    if completed.returncode == 0:
        return completed.stdout

    print(f"Warning: black formatting failed:\n{completed.stderr}", file=sys.stderr)
    return content  # Fall back to unformatted content
def get_version_from_pyproject(pyproject_path="pyproject.toml"):
    """Read the project version from a pyproject.toml file using a regex.

    The file is not parsed as TOML: the first line that starts with
    ``version = "<value>"`` (single or double quotes) supplies the version.

    Args:
        pyproject_path (str): Path of the file to read. Defaults to
            "pyproject.toml".

    Returns:
        str: The matched version string, or "0.0.0" when the file does not
        exist or contains no matching line.
    """
    import re

    fallback_version = "0.0.0"

    try:
        with open(pyproject_path, "r", encoding="utf-8") as handle:
            toml_text = handle.read()
    except FileNotFoundError:
        return fallback_version

    version_line = re.compile(r'^version\s*=\s*["\']([^"\']+)["\']', re.MULTILINE)
    found = version_line.search(toml_text)
    if found is None:
        return fallback_version
    return found.group(1)
if __name__ == "__main__":
    # Generate the __init__ source and run it through black; format_with_black
    # returns its input unchanged when black fails.
    init_content = format_with_black(generate_init_file_content())

    output_path = os.path.join("pat2vec", "__init__.py")
    print(f"Writing __init__.py to {output_path}...")
    try:
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(init_content)
    except OSError as e:
        print(f"Error writing to file: {e}", file=sys.stderr)
        # Exit non-zero so callers (CI jobs, hooks) can detect that the file
        # was not written instead of seeing a successful run.
        sys.exit(1)
    print("Done.")