Skip to content

Commit 982dec9

Browse files
committed
Cleanup: replace db.cursor() with get_db_cursor()
1 parent d4fbedf commit 982dec9

File tree

3 files changed

+21
-28
lines changed

3 files changed

+21
-28
lines changed

dbutils.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,13 @@ def get_db_connection():
4444
@contextmanager
4545
def get_db_cursor(commit=False):
4646
with get_db_connection() as connection:
47-
cursor = connection.cursor()
48-
try:
49-
yield cursor
50-
if commit:
51-
connection.commit()
52-
finally:
53-
cursor.close()
47+
cursor = connection.cursor()
48+
try:
49+
yield cursor
50+
if commit:
51+
connection.commit()
52+
finally:
53+
cursor.close()
5454

5555

5656
def db_connect():
@@ -75,21 +75,14 @@ def db_connect():
7575
db_pool = None
7676
log.error("DB connection failed")
7777

78-
# This class is only needed until we replace all db.cursor() calls with get_db_cursor()
79-
class ThinDBWrapper(object):
80-
@staticmethod
81-
def cursor():
82-
return get_db_cursor()
83-
db = ThinDBWrapper
84-
8578

8679
###########################
8780
# DB schema migration #
8881
###########################
8982

9083
def get_existing_schema_version():
9184
existing_schema_version = 0
92-
with db.cursor() as c:
85+
with get_db_cursor() as c:
9386
try:
9487
c.execute(f'SELECT schema_version FROM {DB_PREFIX}runtime_data;')
9588
res = c.fetchone()
@@ -120,7 +113,7 @@ def migrate_if_needed():
120113
method_to_call = getattr(sys.modules[__name__], method_name)
121114
method_to_call()
122115
# automatically upgrade schema version if there is no exception:
123-
with db.cursor() as c:
116+
with get_db_cursor() as c:
124117
c.execute(f'UPDATE {DB_PREFIX}runtime_data SET schema_version = %s;', (try_migrating_to,))
125118
try_migrating_to += 1
126119
if try_migrating_to == existing_schema_version + 1:
@@ -130,12 +123,12 @@ def migrate_if_needed():
130123

131124

132125
def migration_step_1():
133-
with db.cursor() as c:
126+
with get_db_cursor() as c:
134127
c.execute(f'CREATE TABLE {DB_PREFIX}runtime_data (schema_version SMALLSERIAL NOT NULL);')
135128
c.execute(f'INSERT INTO {DB_PREFIX}runtime_data (schema_version) VALUES (1);')
136129

137130
def migration_step_2():
138-
with db.cursor() as c:
131+
with get_db_cursor() as c:
139132
# UNLOGGED: Disabling WAL avoids high I/O load. Since NetFlow data is of temporary nature, this still
140133
# allows us to perform queries, but if the database crashes it is acceptable to lose all of the records.
141134
c.execute(f'CREATE UNLOGGED TABLE {DB_PREFIX}records (seq BIGSERIAL NOT NULL PRIMARY KEY, ts NUMERIC(16,6) NOT NULL, client_ip TEXT);')

netflowbot.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import requests
1414

1515
from grafoleancollector import Collector, send_results_to_grafolean
16-
from dbutils import db, DB_PREFIX
16+
from dbutils import get_db_cursor, DB_PREFIX
1717
from lookup import PROTOCOLS, DIRECTION_INGRESS, DIRECTION_EGRESS
1818

1919
logging.basicConfig(format='%(asctime)s.%(msecs)03d | %(levelname)s | %(message)s',
@@ -26,7 +26,7 @@
2626

2727

2828
def _get_last_used_seq(job_id):
29-
with db.cursor() as c:
29+
with get_db_cursor() as c:
3030
c.execute(f'SELECT j.last_used_seq, r.ts FROM {DB_PREFIX}bot_jobs j, {DB_PREFIX}records r WHERE j.job_id = %s AND j.last_used_seq = r.seq;', (job_id,))
3131
rec = c.fetchone()
3232
if rec is None:
@@ -35,7 +35,7 @@ def _get_last_used_seq(job_id):
3535
return last_used_seq, ts
3636

3737
def _get_current_max_seq():
38-
with db.cursor() as c:
38+
with get_db_cursor() as c:
3939
c.execute(f"SELECT seq, ts FROM {DB_PREFIX}records WHERE seq = (SELECT MAX(seq) FROM {DB_PREFIX}records);")
4040
rec = c.fetchone()
4141
if rec is None:
@@ -44,7 +44,7 @@ def _get_current_max_seq():
4444
return max_seq, now_ts
4545

4646
def _save_current_max_seq(job_id, seq):
47-
with db.cursor() as c:
47+
with get_db_cursor() as c:
4848
c.execute(f"INSERT INTO {DB_PREFIX}bot_jobs (job_id, last_used_seq) VALUES (%s, %s) ON CONFLICT (job_id) DO UPDATE SET last_used_seq = %s;", (job_id, seq, seq))
4949

5050

@@ -178,7 +178,7 @@ def construct_output_path_prefix(interval_label, direction, entity_id, interface
178178
@staticmethod
179179
def get_traffic_for_entity(interval_label, last_seq, max_seq, time_between, direction, entity_id, entity_ip):
180180
# returns cumulative traffic for the whole entity, and traffic per interface for this entity
181-
with db.cursor() as c:
181+
with get_db_cursor() as c:
182182

183183
c.execute(f"""
184184
SELECT
@@ -218,7 +218,7 @@ def get_traffic_for_entity(interval_label, last_seq, max_seq, time_between, dire
218218
# @staticmethod
219219
# def get_traffic_all_entities(interval_label, last_seq, max_seq, direction):
220220
# output_path = NetFlowBot.construct_output_path_prefix(interval_label, direction, entity_id=None, interface=None)
221-
# with db.cursor() as c:
221+
# with get_db_cursor() as c:
222222
# c.execute(f"""
223223
# SELECT
224224
# sum(f.in_bytes)
@@ -242,7 +242,7 @@ def get_traffic_for_entity(interval_label, last_seq, max_seq, time_between, dire
242242

243243
# @staticmethod
244244
# def get_top_N_IPs(output_path_prefix, from_time, to_time, interface_index, is_direction_in=True):
245-
# with db.cursor() as c:
245+
# with get_db_cursor() as c:
246246
# # TODO: missing check for IP: r.client_ip = %s AND
247247
# c.execute(f"""
248248
# SELECT
@@ -275,7 +275,7 @@ def get_traffic_for_entity(interval_label, last_seq, max_seq, time_between, dire
275275

276276
# @staticmethod
277277
# def get_top_N_protocols(output_path_prefix, from_time, to_time, interface_index, is_direction_in=True):
278-
# with db.cursor() as c:
278+
# with get_db_cursor() as c:
279279
# # TODO: missing check for IP: r.client_ip = %s AND
280280
# c.execute(f"""
281281
# SELECT

netflowwriter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from colors import color
1414

1515
from lookup import PROTOCOLS
16-
from dbutils import migrate_if_needed, db, DB_PREFIX
16+
from dbutils import migrate_if_needed, get_db_cursor, DB_PREFIX
1717

1818

1919
logging.basicConfig(format='%(asctime)s.%(msecs)03d | %(levelname)s | %(message)s',
@@ -69,7 +69,7 @@ def write_record(j):
6969
# }
7070
# https://www.cisco.com/en/US/technologies/tk648/tk362/technologies_white_paper09186a00800a3db9.html#wp9001622
7171

72-
with db.cursor() as c:
72+
with get_db_cursor() as c:
7373
# first save the flow record:
7474
ts = j['ts']
7575
log.info(f"Received record [{j['seq']}]: {datetime.utcfromtimestamp(ts)} from {j['client']}")

0 commit comments

Comments
 (0)