Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions import-automation/workflow/ingestion-helper/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ Initializes the Spanner database by creating all necessary tables and uploading
* `enableEmbeddings` (Optional): Boolean to enable creation of embedding tables and models.
* **Note on Protos**: The `storage.pb` file is generated during the Docker build process. The `Dockerfile` fetches `storage.proto` from the `datacommonsorg/import` GitHub repository and compiles it into `storage.pb`.

#### `seed_database`
Seeds the Spanner database with base empty nodes required by the Data Commons schema (`StatisticalVariable`, `StatVarGroup`, `StatVarObservation`, `Topic`, and `c/g/Root`).

* This action requires no payload parameters.

#### `embedding_ingestion`
Triggers the generation of embeddings for updated nodes in Spanner. It fetches nodes of specific types (e.g., `StatisticalVariable`, `Topic`) that have been updated, generates embeddings using a remote ML model in Spanner, and stores the results in the `NodeEmbedding` table.

Expand Down
6 changes: 6 additions & 0 deletions import-automation/workflow/ingestion-helper/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,13 @@ def ingestion_helper(request):
FLAGS.enable_embeddings)
spanner.initialize_database(enable_embeddings=enable_embeddings)
return ('OK', 200)
elif action_type == 'seed_database':
# Seeds the database with base empty nodes.
logging.info("Action: seed_database")
spanner.seed_database()
return ('OK', 200)
elif action_type == 'embedding_ingestion':

logging.info("Action: embedding_ingestion")
enable_embeddings = request_json.get('enableEmbeddings',
FLAGS.enable_embeddings)
Expand Down
22 changes: 22 additions & 0 deletions import-automation/workflow/ingestion-helper/main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,5 +120,27 @@ def run_in_transaction_side_effect(func):
self.assertEqual(batch[0][0], "dc/1")
self.assertEqual(batch[0][1], "Node 1")

@patch.dict(os.environ, {
"SPANNER_INSTANCE_ID": "test-instance",
"SPANNER_DATABASE_ID": "test-db",
"SPANNER_PROJECT_ID": "test-proj"
})
@patch('main.SpannerClient')
def test_seed_database_success(self, mock_spanner_client_class):
mock_spanner_client = MagicMock()
mock_spanner_client_class.return_value = mock_spanner_client

mock_request = MagicMock()
mock_request.get_json.return_value = {
"actionType": "seed_database"
}

response, status_code = main.ingestion_helper(mock_request)

self.assertEqual(status_code, 200)
self.assertIn("OK", response)
mock_spanner_client.seed_database.assert_called_once()

if __name__ == '__main__':
unittest.main()

36 changes: 36 additions & 0 deletions import-automation/workflow/ingestion-helper/spanner_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,3 +555,39 @@ def initialize_database(self, enable_embeddings=False):
except Exception as e:
logging.error(f"Failed to update DDL with protos: {e}")
raise

def seed_database(self):
"""Seeds the database with base empty nodes."""
logging.info("Seeding database with base nodes...")

def _seed(transaction: Transaction):
candidates = {
"StatisticalVariable": ["StatisticalVariable", ["Class"], spanner.COMMIT_TIMESTAMP],
"StatVarGroup": ["StatVarGroup", ["Class"], spanner.COMMIT_TIMESTAMP],
"StatVarObservation": ["StatVarObservation", ["Class"], spanner.COMMIT_TIMESTAMP],
"Topic": ["Topic", ["Class"], spanner.COMMIT_TIMESTAMP],
"c/g/Root": ["c/g/Root", ["StatVarGroup"], spanner.COMMIT_TIMESTAMP],
}
subjects = list(candidates.keys())
sql = "SELECT subject_id FROM Node WHERE subject_id IN UNNEST(@subjects)"
params = {"subjects": subjects}
param_types = {"subjects": Array(STRING)}
existing = set()
for row in transaction.execute_sql(sql, params, param_types):
existing.add(row[0])

values = [candidates[subj] for subj in subjects if subj not in existing]

if values:
columns = ["subject_id", "types", "last_update_timestamp"]
transaction.insert(table="Node", columns=columns, values=values)

try:
self.database.run_in_transaction(_seed)
if self.graph_database and self.graph_database.name != self.database.name:
self.graph_database.run_in_transaction(_seed)
logging.info("Database seeded successfully.")
except Exception as e:
logging.error(f"Error seeding database: {e}")
raise

Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def test_initialize_database_all_exist(self, mock_spanner_client):
["table", "Node"], ["table", "Edge"], ["table", "Observation"],
["table", "NodeEmbedding"], ["table", "ImportStatus"],
["table", "IngestionHistory"], ["table", "ImportVersionHistory"],
["table", "IngestionLock"],
["table", "IngestionLock"], ["table", "Cache"],
["table", "VariableMetadata"],
["index", "NodeEmbeddingIndex"],
["model", "NodeEmbeddingModel"]
]
Expand Down Expand Up @@ -220,5 +221,56 @@ def run_in_transaction_side_effect(callback, *args, **kwargs):
args, _ = mock_transaction.execute_update.call_args
self.assertIn("UPDATE IngestionLock", args[0])

@patch('google.cloud.spanner.Client')
def test_seed_database(self, mock_spanner_client):
# Setup mock
mock_instance = MagicMock()
mock_db = MagicMock()
mock_spanner_client.return_value.instance.return_value = mock_instance
mock_instance.database.return_value = mock_db

mock_transaction = MagicMock()
mock_transaction.execute_sql.return_value = []
def run_in_transaction_side_effect(callback, *args, **kwargs):
return callback(mock_transaction, *args, **kwargs)
mock_db.run_in_transaction.side_effect = run_in_transaction_side_effect

client = SpannerClient("project", "instance", "database")

# Run method
client.seed_database()

# Verify
mock_transaction.insert.assert_called_once()
args, kwargs = mock_transaction.insert.call_args
self.assertEqual(kwargs['table'], 'Node')
self.assertEqual(len(kwargs['values']), 5)
expected_subjects = ["StatisticalVariable", "StatVarGroup", "StatVarObservation", "Topic", "c/g/Root"]
actual_subjects = [val[0] for val in kwargs['values']]
self.assertEqual(actual_subjects, expected_subjects)

@patch('google.cloud.spanner.Client')
def test_seed_database_already_exists(self, mock_spanner_client):
# Setup mock
mock_instance = MagicMock()
mock_db = MagicMock()
mock_spanner_client.return_value.instance.return_value = mock_instance
mock_instance.database.return_value = mock_db

mock_transaction = MagicMock()
mock_transaction.execute_sql.return_value = [["StatisticalVariable"], ["StatVarGroup"], ["StatVarObservation"], ["Topic"], ["c/g/Root"]]
def run_in_transaction_side_effect(callback, *args, **kwargs):
return callback(mock_transaction, *args, **kwargs)
mock_db.run_in_transaction.side_effect = run_in_transaction_side_effect

client = SpannerClient("project", "instance", "database")

# Run method
client.seed_database()

# Verify
mock_transaction.insert.assert_not_called()

if __name__ == '__main__':
unittest.main()

Loading