diff --git a/README.md b/README.md
index 4c74eaa..6518804 100644
--- a/README.md
+++ b/README.md
@@ -228,7 +228,7 @@ or `JSON` for machine-parseable output to plug in your GHAS enablement automatio
 
 ```text
 $ python3 main.py --help
-usage: main.py [-h] [--ac-report AC_REPORT] [--enterprise ENTERPRISE] [--organization ORGANIZATION] [--output OUTPUT] [--output-format OUTPUT_FORMAT] [--token TOKEN] [--licenses LICENSES]
+usage: main.py [-h] [--ac-report AC_REPORT] [--enterprise ENTERPRISE] [--organization ORGANIZATION] [--output OUTPUT] [--output-format OUTPUT_FORMAT] [--token TOKEN] [--licenses LICENSES] [--server-url SERVER_URL]
 
 GHAS activation and coverage activation
 
@@ -245,6 +245,8 @@ options:
                         Output format - text or json (default: 'text')
   --token TOKEN         GitHub Personal Access Token (if not set in GITHUB_TOKEN environment variable)
   --licenses LICENSES   Number of (still) available GHAS licenses (default: 0)
+  --server-url SERVER_URL
+                        GitHub Server URL (default: https://api.github.com for GitHub.com, or https://HOSTNAME/api/v3 for GHES)
 ```
 
 You must provide:
@@ -254,6 +256,7 @@ You must provide:
 Other parameters are optional, but note:
 
 - `--ac-report` with the path to the Max Active Committers report. If left empty, the script will gather the data from the GraphQL API.
+- `--server-url` to specify a custom GitHub server URL. By default, it uses `https://api.github.com` for GitHub.com. For GitHub Enterprise Server (GHES), use `https://HOSTNAME/api/v3` (replace `HOSTNAME` with your GHES hostname).
 
 ### Prerequisites
 
@@ -303,6 +306,22 @@ Other parameters are optional, but note:
   python3 main.py --org thez-org --licenses 600 --output-format json --output report.json
   ```
 
+### GitHub Enterprise Server (GHES)
+
+The script supports GitHub Enterprise Server (GHES) by using the `--server-url` parameter. For GHES installations, the API URL format is `https://HOSTNAME/api/v3`, where `HOSTNAME` is your GHES server hostname.
+
+**Example usage with GHES:**
+
+```shell
+# For an organization on GHES
+python3 main.py --server-url https://github.example.com/api/v3 --org my-org --output report.md
+
+# For an enterprise on GHES with active committers report
+python3 main.py --server-url https://github.example.com/api/v3 --ac-report report.csv --enterprise my-enterprise --licenses 100 --output report.md
+```
+
+**Note:** Make sure your Personal Access Token (PAT) has the appropriate permissions for your GHES instance.
+
 ## License
 
 This project is licensed under the terms of the MIT open source license. Please refer to the [LICENSE](LICENSE) for the full terms.
diff --git a/github.py b/github.py
index b72b631..0a378a5 100644
--- a/github.py
+++ b/github.py
@@ -3,6 +3,7 @@
 import time
 import threading
 import concurrent.futures
+import re
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime, timedelta
 from models import Repository
@@ -15,7 +16,7 @@
 MAX_WORKERS = 5
 
 
-def add_active_committers(report, repositories, token):
+def add_active_committers(report, repositories, token, server_url="https://api.github.com"):
     try:
         with open(report, "r") as file:
             reader = csv.reader(file)
@@ -31,7 +32,7 @@ def add_active_committers(report, repositories, token):
         with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
             futures = [
                 executor.submit(
-                    get_active_committers_in_last_90_days, repo.org, repo.name, token
+                    get_active_committers_in_last_90_days, repo.org, repo.name, token, server_url
                 )
                 for repo in repositories
             ]
@@ -50,20 +51,20 @@ def get_organizations(args, token):
     if not args.enterprise:
         orgs_in_ent.append(args.organization)
     else:
-        orgs_in_ent = get_orgs_in_ent(args.enterprise, token)
+        orgs_in_ent = get_orgs_in_ent(args.enterprise, token, args.server_url)
     return orgs_in_ent
 
 
-def process_organizations(orgs_in_ent, token):
+def process_organizations(orgs_in_ent, token, server_url="https://api.github.com"):
     total_repositories = []
     for i, org in enumerate(orgs_in_ent, start=1):
         logger.info(f"[{i}/{len(orgs_in_ent)}] Processing organization: {org}")
-        total_repositories.extend(get_ghas_status_for_repos(org, token))
+        total_repositories.extend(get_ghas_status_for_repos(org, token, server_url))
     return total_repositories
 
 
-def get_ghas_status_for_repos(org, token):
-    url = f"https://api.github.com/orgs/{org}/repos?per_page=100"
+def get_ghas_status_for_repos(org, token, server_url="https://api.github.com"):
+    url = f"{server_url}/orgs/{org}/repos?per_page=100"
     headers = {"Authorization": f"token {token}"}
     page = 1
     repos = []
@@ -73,7 +74,24 @@ def get_ghas_status_for_repos(org, token):
 
         handle_rate_limit(response)
 
+        # Check for HTTP errors
+        if response.status_code != 200:
+            logger.error(f"HTTP {response.status_code} error for URL {url}")
+            logger.error(f"Response: {response.text}")
+            if response.status_code == 401:
+                raise Exception(f"Authentication failed for organization {org}. Please check your GitHub token is valid for the server {server_url}")
+            elif response.status_code == 404:
+                raise Exception(f"Organization '{org}' not found or not accessible. Please check the organization name and your token permissions.")
+            else:
+                raise Exception(f"Failed to fetch repositories for organization {org}. Status: {response.status_code}")
+
         data = response.json()
+
+        # Check if data is a list (successful response) or dict (error response)
+        if not isinstance(data, list):
+            logger.error(f"Unexpected response format: {data}")
+            raise Exception(f"API returned an error: {data.get('message', 'Unknown error')}")
+
         for repo_data in data:
             owner, name = repo_data["full_name"].split("/")
             ghas_status = (
@@ -93,8 +111,14 @@ def get_ghas_status_for_repos(org, token):
     return repos
 
 
-def get_active_committers_in_last_90_days(org, repo, token):
-    url = "https://api.github.com/graphql"
+def get_active_committers_in_last_90_days(org, repo, token, server_url="https://api.github.com"):
+    # For GHES, the GraphQL endpoint is at /api/graphql, for GitHub.com it's at /graphql
+    if server_url != "https://api.github.com":
+        # For GHES: Remove any /api/vX suffix to get base URL (supports v3, v4, v44, etc.)
+        base_url = re.sub(r'/api/v\d+$', '', server_url)
+        graphql_url = f"{base_url}/api/graphql"
+    else:
+        graphql_url = f"{server_url}/graphql"
     headers = {"Authorization": f"token {token}"}
     active_committers = set()
 
@@ -137,20 +161,39 @@ def get_active_committers_in_last_90_days(org, repo, token):
         since = (datetime.now() - timedelta(days=90)).isoformat()
         variables = {"org": org, "repo": repo, "since": since, "after": end_cursor}
         payload = {"query": query, "variables": variables}
-        response = requests.post(url, headers=headers, json=payload)
+        response = requests.post(graphql_url, headers=headers, json=payload)
 
         handle_rate_limit(response)
 
         if response.status_code != 200:
             logger.info(f"Response: {response.json()}")
-            next
+            # Stop instead of retrying: `continue` would resend the same page
+            # forever. A 401 falls through to the dedicated check below.
+            if response.status_code != 401:
+                break
 
         if response.status_code == 401:
             logger.info(f"Insufficient permissions token provided.")
             break
 
         data = response.json()
+
+        # Check for GraphQL errors
+        if "errors" in data:
+            logger.error(f"GraphQL errors for {org}/{repo}: {data['errors']}")
+            break
+
+        # Check if data structure is valid
+        if "data" not in data or not data["data"] or "repository" not in data["data"]:
+            logger.error(f"Invalid GraphQL response structure for {org}/{repo}: {data}")
+            break
+
         repository = data["data"]["repository"]
+
+        # Handle case where repository is None (e.g., doesn't exist or no access)
+        if repository is None:
+            logger.warning(f"Repository {org}/{repo} not found or not accessible")
+            break
         refs = repository.get("refs")
         visibility = repository.get("visibility")
 
@@ -182,8 +225,14 @@ def get_active_committers_in_last_90_days(org, repo, token):
     return list(active_committers)
 
 
-def get_orgs_in_ent(enterprise_name, token):
-    url = "https://api.github.com/graphql"
+def get_orgs_in_ent(enterprise_name, token, server_url="https://api.github.com"):
+    # For GHES, the GraphQL endpoint is at /api/graphql, for GitHub.com it's at /graphql
+    if server_url != "https://api.github.com":
+        # For GHES: Remove any /api/vX suffix to get base URL (supports v3, v4, v44, etc.)
+        base_url = re.sub(r'/api/v\d+$', '', server_url)
+        graphql_url = f"{base_url}/api/graphql"
+    else:
+        graphql_url = f"{server_url}/graphql"
     headers = {"Authorization": f"token {token}", "X-Github-Next-Global-ID": "true"}
     orgs = []
     end_cursor = None
@@ -205,7 +254,7 @@ def get_orgs_in_ent(enterprise_name, token):
         """
         variables = {"enterprise": enterprise_name, "after": end_cursor}
         payload = {"query": query, "variables": variables}
-        response = requests.post(url, headers=headers, json=payload)
+        response = requests.post(graphql_url, headers=headers, json=payload)
 
         handle_rate_limit(response)
 
diff --git a/helpers.py b/helpers.py
index f41140d..58c9728 100644
--- a/helpers.py
+++ b/helpers.py
@@ -55,12 +55,20 @@ def parse_arguments():
         required=False,
         default=0,
     )
+    parser.add_argument(
+        "--server-url",
+        type=str,
+        help="GitHub Server URL (default: https://api.github.com for GitHub.com, or https://HOSTNAME/api/v3 for GHES)",
+        required=False,
+        default="https://api.github.com",
+    )
 
     args = parser.parse_args()
     if args.enterprise is None and args.organization is None:
         parser.error("Either --enterprise or --organization must be provided.")
 
-    token = os.getenv("GITHUB_TOKEN") or args.token
+    # Prioritize command line token over environment variable
+    token = args.token or os.getenv("GITHUB_TOKEN")
     if token is None:
         parser.error(
             "Either GITHUB_TOKEN environment variable or --token must be provided."
diff --git a/main.py b/main.py
index 21e3b3b..4b556f5 100644
--- a/main.py
+++ b/main.py
@@ -14,10 +14,10 @@ def main():
     # Gather all data needed for the report - all orgs in the enterprise, repositories in orgs and active committers in repositories
     orgs_in_ent = get_organizations(args, token)
     logger.info(f"Number of organizations to process: {len(orgs_in_ent)}")
-    total_repositories = process_organizations(orgs_in_ent, token)
+    total_repositories = process_organizations(orgs_in_ent, token, args.server_url)
 
     logger.info(f"Adding active committers to {len(total_repositories)} repositories")
-    add_active_committers(args.ac_report, total_repositories, token)
+    add_active_committers(args.ac_report, total_repositories, token, args.server_url)
 
     # Generate report and print report
     logger.info(f"Generating report...")
diff --git a/tests/test_ghes_support.py b/tests/test_ghes_support.py
new file mode 100644
index 0000000..652ee68
--- /dev/null
+++ b/tests/test_ghes_support.py
@@ -0,0 +1,106 @@
+import sys
+import os
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+
+import unittest
+from unittest.mock import patch, MagicMock
+from github import get_ghas_status_for_repos, get_orgs_in_ent, get_active_committers_in_last_90_days
+
+
+class TestGHESSupport(unittest.TestCase):
+    """Test that GHES server URLs are properly handled"""
+
+    def setUp(self):
+        self.token = "test_token"
+        self.github_com_url = "https://api.github.com"
+        self.ghes_url = "https://github.example.com/api/v3"
+
+    @patch("github.requests.get")
+    def test_get_ghas_status_uses_custom_server_url(self, mock_get):
+        """Test that get_ghas_status_for_repos uses custom server URL"""
+        mock_response = MagicMock()
+        mock_response.json.return_value = []
+        mock_response.links = {}
+        mock_response.headers = {"X-RateLimit-Remaining": "100"}
+        mock_response.status_code = 200
+        mock_get.return_value = mock_response
+
+        # Test with GHES URL
+        get_ghas_status_for_repos("test-org", self.token, self.ghes_url)
+
+        # Verify the URL was constructed correctly
+        call_args = mock_get.call_args
+        self.assertTrue(call_args[0][0].startswith(self.ghes_url))
+        self.assertIn("test-org", call_args[0][0])
+
+    @patch("github.requests.get")
+    def test_get_ghas_status_uses_default_github_url(self, mock_get):
+        """Test that get_ghas_status_for_repos uses default GitHub.com URL"""
+        mock_response = MagicMock()
+        mock_response.json.return_value = []
+        mock_response.links = {}
+        mock_response.headers = {"X-RateLimit-Remaining": "100"}
+        mock_response.status_code = 200
+        mock_get.return_value = mock_response
+
+        # Test with default URL (not passing server_url)
+        get_ghas_status_for_repos("test-org", self.token)
+
+        # Verify the URL was constructed correctly
+        call_args = mock_get.call_args
+        self.assertTrue(call_args[0][0].startswith(self.github_com_url))
+
+    @patch("github.requests.post")
+    def test_get_orgs_in_ent_uses_custom_server_url(self, mock_post):
+        """Test that get_orgs_in_ent uses custom GraphQL endpoint for GHES"""
+        mock_response = MagicMock()
+        mock_response.json.return_value = {
+            "data": {
+                "enterprise": {
+                    "organizations": {
+                        "nodes": [],
+                        "pageInfo": {"hasNextPage": False, "endCursor": None}
+                    }
+                }
+            }
+        }
+        mock_response.headers = {"X-RateLimit-Remaining": "100"}
+        mock_response.status_code = 200
+        mock_post.return_value = mock_response
+
+        # Test with GHES URL
+        get_orgs_in_ent("test-enterprise", self.token, self.ghes_url)
+
+        # Verify the GraphQL endpoint was constructed correctly
+        call_args = mock_post.call_args
+        # GHES serves GraphQL at /api/graphql, outside the /api/v3 REST prefix
+        self.assertEqual(call_args[0][0], "https://github.example.com/api/graphql")
+
+    @patch("github.requests.post")
+    def test_get_active_committers_uses_custom_server_url(self, mock_post):
+        """Test that get_active_committers_in_last_90_days uses custom GraphQL endpoint"""
+        mock_response = MagicMock()
+        mock_response.json.return_value = {
+            "data": {
+                "repository": {
+                    "visibility": "PRIVATE",
+                    "refs": None
+                }
+            }
+        }
+        mock_response.headers = {"X-RateLimit-Remaining": "100"}
+        mock_response.status_code = 200
+        mock_post.return_value = mock_response
+
+        # Test with GHES URL
+        get_active_committers_in_last_90_days("test-org", "test-repo", self.token, self.ghes_url)
+
+        # Verify the GraphQL endpoint was constructed correctly
+        call_args = mock_post.call_args
+        # GHES serves GraphQL at /api/graphql, outside the /api/v3 REST prefix
+        self.assertEqual(call_args[0][0], "https://github.example.com/api/graphql")
+
+
+if __name__ == "__main__":
+    unittest.main()