diff --git a/import-automation/workflow/ingestion-helper/README.md b/import-automation/workflow/ingestion-helper/README.md index 7de6d49a91..a608388d09 100644 --- a/import-automation/workflow/ingestion-helper/README.md +++ b/import-automation/workflow/ingestion-helper/README.md @@ -64,6 +64,11 @@ Initializes the Spanner database by creating all necessary tables and uploading * `enableEmbeddings` (Optional): Boolean to enable creation of embedding tables and models. * **Note on Protos**: The `storage.pb` file is generated during the Docker build process. The `Dockerfile` fetches `storage.proto` from the `datacommonsorg/import` GitHub repository and compiles it into `storage.pb`. +#### `seed_database` +Seeds the Spanner database with base empty nodes required by the Data Commons schema (`StatisticalVariable`, `StatVarGroup`, `StatVarObservation`, `Topic`, and `c/g/Root`). + +* This action requires no payload parameters. + #### `embedding_ingestion` Triggers the generation of embeddings for updated nodes in Spanner. It fetches nodes of specific types (e.g., `StatisticalVariable`, `Topic`) that have been updated, generates embeddings using a remote ML model in Spanner, and stores the results in the `NodeEmbedding` table. diff --git a/import-automation/workflow/ingestion-helper/main.py b/import-automation/workflow/ingestion-helper/main.py index 9921f0fd34..ce2e57accb 100644 --- a/import-automation/workflow/ingestion-helper/main.py +++ b/import-automation/workflow/ingestion-helper/main.py @@ -219,7 +219,13 @@ def ingestion_helper(request): FLAGS.enable_embeddings) spanner.initialize_database(enable_embeddings=enable_embeddings) return ('OK', 200) + elif action_type == 'seed_database': + # Seeds the database with base empty nodes. + logging.info("Action: seed_database") + spanner.seed_database() + return ('OK', 200) elif action_type == 'embedding_ingestion': + logging.info("Action: embedding_ingestion") enable_embeddings = request_json.get('enableEmbeddings', FLAGS.enable_embeddings) diff --git a/import-automation/workflow/ingestion-helper/main_test.py b/import-automation/workflow/ingestion-helper/main_test.py index 1e6d71869a..69c87ebfa3 100644 --- a/import-automation/workflow/ingestion-helper/main_test.py +++ b/import-automation/workflow/ingestion-helper/main_test.py @@ -120,5 +120,27 @@ def run_in_transaction_side_effect(func): self.assertEqual(batch[0][0], "dc/1") self.assertEqual(batch[0][1], "Node 1") + @patch.dict(os.environ, { + "SPANNER_INSTANCE_ID": "test-instance", + "SPANNER_DATABASE_ID": "test-db", + "SPANNER_PROJECT_ID": "test-proj" + }) + @patch('main.SpannerClient') + def test_seed_database_success(self, mock_spanner_client_class): + mock_spanner_client = MagicMock() + mock_spanner_client_class.return_value = mock_spanner_client + + mock_request = MagicMock() + mock_request.get_json.return_value = { + "actionType": "seed_database" + } + + response, status_code = main.ingestion_helper(mock_request) + + self.assertEqual(status_code, 200) + self.assertIn("OK", response) + mock_spanner_client.seed_database.assert_called_once() + if __name__ == '__main__': unittest.main() + diff --git a/import-automation/workflow/ingestion-helper/spanner_client.py b/import-automation/workflow/ingestion-helper/spanner_client.py index 27b30a087f..2255cf1420 100644 --- a/import-automation/workflow/ingestion-helper/spanner_client.py +++ b/import-automation/workflow/ingestion-helper/spanner_client.py @@ -555,3 +555,39 @@ def initialize_database(self, enable_embeddings=False): except Exception as e: logging.error(f"Failed to update DDL with protos: {e}") raise + + def seed_database(self): + """Seeds the database with base empty nodes.""" + logging.info("Seeding database with base nodes...") + + def _seed(transaction: Transaction): + candidates = { + "StatisticalVariable": ["StatisticalVariable", ["Class"], spanner.COMMIT_TIMESTAMP], + "StatVarGroup": ["StatVarGroup", ["Class"], spanner.COMMIT_TIMESTAMP], + "StatVarObservation": ["StatVarObservation", ["Class"], spanner.COMMIT_TIMESTAMP], + "Topic": ["Topic", ["Class"], spanner.COMMIT_TIMESTAMP], + "c/g/Root": ["c/g/Root", ["StatVarGroup"], spanner.COMMIT_TIMESTAMP], + } + subjects = list(candidates.keys()) + sql = "SELECT subject_id FROM Node WHERE subject_id IN UNNEST(@subjects)" + params = {"subjects": subjects} + param_types = {"subjects": Array(STRING)} + existing = set() + for row in transaction.execute_sql(sql, params, param_types): + existing.add(row[0]) + + values = [candidates[subj] for subj in subjects if subj not in existing] + + if values: + columns = ["subject_id", "types", "last_update_timestamp"] + transaction.insert(table="Node", columns=columns, values=values) + + try: + self.database.run_in_transaction(_seed) + if self.graph_database and self.graph_database.name != self.database.name: + self.graph_database.run_in_transaction(_seed) + logging.info("Database seeded successfully.") + except Exception as e: + logging.error(f"Error seeding database: {e}") + raise + diff --git a/import-automation/workflow/ingestion-helper/spanner_client_test.py b/import-automation/workflow/ingestion-helper/spanner_client_test.py index 3a961db48f..1b66f29a14 100644 --- a/import-automation/workflow/ingestion-helper/spanner_client_test.py +++ b/import-automation/workflow/ingestion-helper/spanner_client_test.py @@ -38,7 +38,8 @@ def test_initialize_database_all_exist(self, mock_spanner_client): ["table", "Node"], ["table", "Edge"], ["table", "Observation"], ["table", "NodeEmbedding"], ["table", "ImportStatus"], ["table", "IngestionHistory"], ["table", "ImportVersionHistory"], - ["table", "IngestionLock"], + ["table", "IngestionLock"], ["table", "Cache"], + ["table", "VariableMetadata"], ["index", "NodeEmbeddingIndex"], ["model", "NodeEmbeddingModel"] ] @@ -220,5 +221,56 @@ def run_in_transaction_side_effect(callback, *args, **kwargs): args, _ = mock_transaction.execute_update.call_args self.assertIn("UPDATE IngestionLock", args[0]) + @patch('google.cloud.spanner.Client') + def test_seed_database(self, mock_spanner_client): + # Setup mock + mock_instance = MagicMock() + mock_db = MagicMock() + mock_spanner_client.return_value.instance.return_value = mock_instance + mock_instance.database.return_value = mock_db + + mock_transaction = MagicMock() + mock_transaction.execute_sql.return_value = [] + def run_in_transaction_side_effect(callback, *args, **kwargs): + return callback(mock_transaction, *args, **kwargs) + mock_db.run_in_transaction.side_effect = run_in_transaction_side_effect + + client = SpannerClient("project", "instance", "database") + + # Run method + client.seed_database() + + # Verify + mock_transaction.insert.assert_called_once() + args, kwargs = mock_transaction.insert.call_args + self.assertEqual(kwargs['table'], 'Node') + self.assertEqual(len(kwargs['values']), 5) + expected_subjects = ["StatisticalVariable", "StatVarGroup", "StatVarObservation", "Topic", "c/g/Root"] + actual_subjects = [val[0] for val in kwargs['values']] + self.assertEqual(actual_subjects, expected_subjects) + + @patch('google.cloud.spanner.Client') + def test_seed_database_already_exists(self, mock_spanner_client): + # Setup mock + mock_instance = MagicMock() + mock_db = MagicMock() + mock_spanner_client.return_value.instance.return_value = mock_instance + mock_instance.database.return_value = mock_db + + mock_transaction = MagicMock() + mock_transaction.execute_sql.return_value = [["StatisticalVariable"], ["StatVarGroup"], ["StatVarObservation"], ["Topic"], ["c/g/Root"]] + def run_in_transaction_side_effect(callback, *args, **kwargs): + return callback(mock_transaction, *args, **kwargs) + mock_db.run_in_transaction.side_effect = run_in_transaction_side_effect + + client = SpannerClient("project", "instance", "database") + + # Run method + client.seed_database() + + # Verify + mock_transaction.insert.assert_not_called() + if __name__ == '__main__': unittest.main() +