Skip to content

Commit 6845ad5

Browse files
authored
Wrap all sqlalchemy usages of session into a block (#1706)
1 parent 99d9c81 commit 6845ad5

File tree

1 file changed

+72
-68
lines changed
  • backend/src/serverless/microservices/python/crowd-backend/crowd/backend/repository

1 file changed

+72
-68
lines changed

backend/src/serverless/microservices/python/crowd-backend/crowd/backend/repository/repository.py

Lines changed: 72 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,7 @@ def __init__(self, tenant_id="", db_url=False, test=False, send=True):
7070
)
7171

7272
Base.metadata.create_all(self.engine, checkfirst=True)
73-
Session = sessionmaker(bind=self.engine)
74-
self.session = Session()
73+
self.Session = sessionmaker(bind=self.engine)
7574

7675
self.tenant_id = tenant_id
7776
self.send = send
@@ -106,25 +105,25 @@ def find_in_table(self, table, query, many=False):
106105
dict: document
107106
"""
108107

109-
search_query = self.session.query(table)
110-
for attr, value in query.items():
111-
112-
# Check if query is nested
113-
nested_count = attr.count(".")
114-
# If nested
115-
if nested_count > 0:
116-
attributes = attr.split(".")
117-
nested_attributes = tuple(attributes[1:])
118-
# Define nested expression
119-
expr = getattr(table, attributes[0])[nested_attributes]
120-
# Execute search_query
121-
search_query = search_query.filter(expr == json.dumps(value))
122-
else:
123-
search_query = search_query.filter(getattr(table, attr) == value)
108+
with self.Session() as session:
109+
search_query = session.query(table)
110+
for attr, value in query.items():
111+
# Check if query is nested
112+
nested_count = attr.count(".")
113+
# If nested
114+
if nested_count > 0:
115+
attributes = attr.split(".")
116+
nested_attributes = tuple(attributes[1:])
117+
# Define nested expression
118+
expr = getattr(table, attributes[0])[nested_attributes]
119+
# Execute search_query
120+
search_query = search_query.filter(expr == json.dumps(value))
121+
else:
122+
search_query = search_query.filter(getattr(table, attr) == value)
124123

125-
if many:
126-
return search_query.all()
127-
return search_query.first()
124+
if many:
125+
return search_query.all()
126+
return search_query.first()
128127

129128
def find_by_id(self, table, id):
130129
"""
@@ -138,15 +137,17 @@ def find_by_id(self, table, id):
138137
dict: the document
139138
"""
140139

141-
return self.session.query(table).get(id)
140+
with self.Session() as session:
141+
return session.query(table).get(id)
142142

143143
def find_all_usernames(self):
144144
with self.engine.connect() as con:
145145
return con.execute(
146146
f"""select m."id", mw."username", m."displayName", m."emails"
147147
from "members" m
148148
inner join "memberActivityAggregatesMVs" mw on m.id = mw.id
149-
where m."tenantId" = '{self.tenant_id}'""").fetchall()
149+
where m."tenantId" = '{self.tenant_id}'"""
150+
).fetchall()
150151

151152
def find_all(
152153
self, table, ignore_tenant: "bool" = False, query: "dict" = None, order: "dict" = None
@@ -173,29 +174,30 @@ def find_all(
173174
**{dbk.TENANT: uuid.UUID(self.tenant_id)},
174175
}
175176

176-
search_query = self.session.query(table)
177-
for attr, value in query.items():
178-
# Check if query is nested
179-
nested_count = attr.count(".")
180-
# If nested
181-
if nested_count > 0:
182-
attributes = attr.split(".")
183-
nested_attributes = tuple(attributes[1:])
184-
# Define nested expression
185-
expr = getattr(table, attributes[0])[nested_attributes]
186-
# Execute search_query
187-
search_query = search_query.filter(expr == json.dumps(value))
188-
else:
189-
search_query = search_query.filter(getattr(table, attr) == value)
190-
191-
if order:
192-
for key, value in order.items():
193-
if value:
194-
search_query = search_query.order_by(asc(key))
177+
with self.Session() as session:
178+
search_query = session.query(table)
179+
for attr, value in query.items():
180+
# Check if query is nested
181+
nested_count = attr.count(".")
182+
# If nested
183+
if nested_count > 0:
184+
attributes = attr.split(".")
185+
nested_attributes = tuple(attributes[1:])
186+
# Define nested expression
187+
expr = getattr(table, attributes[0])[nested_attributes]
188+
# Execute search_query
189+
search_query = search_query.filter(expr == json.dumps(value))
195190
else:
196-
search_query = search_query.order_by(desc(key))
191+
search_query = search_query.filter(getattr(table, attr) == value)
197192

198-
return search_query.all()
193+
if order:
194+
for key, value in order.items():
195+
if value:
196+
search_query = search_query.order_by(asc(key))
197+
else:
198+
search_query = search_query.order_by(desc(key))
199+
200+
return search_query.all()
199201

200202
def find_activities(self, search_filters=None):
201203
if not search_filters:
@@ -208,22 +210,23 @@ def count(self, table, search_filters=None):
208210

209211
search_filters[dbk.TENANT] = uuid.UUID(self.tenant_id)
210212

211-
search_query = self.session.query(table)
212-
for attr, value in search_filters.items():
213-
# Check if query is nested
214-
nested_count = attr.count(".")
215-
# If nested
216-
if nested_count > 0:
217-
attributes = attr.split(".")
218-
nested_attributes = tuple(attributes[1:])
219-
# Define nested expression
220-
expr = getattr(table, attributes[0])[nested_attributes]
221-
# Execute query
222-
search_query = search_query.filter(expr == json.dumps(value))
223-
else:
224-
search_query = search_query.filter(getattr(table, attr) == value)
213+
with self.Session() as session:
214+
search_query = session.query(table)
215+
for attr, value in search_filters.items():
216+
# Check if query is nested
217+
nested_count = attr.count(".")
218+
# If nested
219+
if nested_count > 0:
220+
attributes = attr.split(".")
221+
nested_attributes = tuple(attributes[1:])
222+
# Define nested expression
223+
expr = getattr(table, attributes[0])[nested_attributes]
224+
# Execute query
225+
search_query = search_query.filter(expr == json.dumps(value))
226+
else:
227+
search_query = search_query.filter(getattr(table, attr) == value)
225228

226-
return search_query.count()
229+
return search_query.count()
227230

228231
def find_available_microservices(self, service):
229232
"""
@@ -253,16 +256,17 @@ def find_new_members(self, microservice, query: "dict" = None) -> "list[dict]":
253256
**{dbk.TENANT: uuid.UUID(self.tenant_id)},
254257
}
255258

256-
search_query = self.session.query(Member)
259+
with self.Session() as session:
260+
search_query = session.query(Member)
257261

258-
# Filter with query
259-
for attr, value in query.items():
260-
search_query = search_query.filter(getattr(Member, attr) == value)
262+
# Filter with query
263+
for attr, value in query.items():
264+
search_query = search_query.filter(getattr(Member, attr) == value)
261265

262-
# Find members that are new
263-
# We use a security padding of 5 minutes
264-
search_query = search_query.filter(
265-
Member.createdAt >= (microservice.updatedAt - timedelta(minutes=5))
266-
).order_by(Member.createdAt.desc())
266+
# Find members that are new
267+
# We use a security padding of 5 minutes
268+
search_query = search_query.filter(
269+
Member.createdAt >= (microservice.updatedAt - timedelta(minutes=5))
270+
).order_by(Member.createdAt.desc())
267271

268-
return search_query.all()
272+
return search_query.all()

0 commit comments

Comments
 (0)