Skip to content

Commit 21d21df

Browse files
committed
Fix bug with custom column names
It would fail :-)
1 parent 0c6e0ab commit 21d21df

File tree

2 files changed

+43
-22
lines changed

2 files changed

+43
-22
lines changed

psqlextra/compiler.py

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,8 @@ def _rewrite_insert_nothing(self, sql, params, returning):
124124
# for conflicts
125125
conflict_target = self._build_conflict_target()
126126

127-
def format_field_name(field_name):
128-
if isinstance(field_name, tuple):
129-
return self.qn(field_name[0])
130-
return self.qn(field_name)
131-
132127
where_clause = ', '.join([
133-
'{0} = %s'.format(format_field_name(field_name))
128+
'{0} = %s'.format(self._format_field_name(field_name))
134129
for field_name in self.query.conflict_target
135130
])
136131

@@ -188,10 +183,20 @@ def _assert_valid_field(field_name):
188183
'names and hstore key.'
189184
) % str(field_name))
190185

191-
for field in self.query.conflict_target:
192-
_assert_valid_field(field)
193-
conflict_target.append(
194-
self._format_field_name(field))
186+
for field_name in self.query.conflict_target:
187+
_assert_valid_field(field_name)
188+
189+
# special handling for hstore keys
190+
if isinstance(field_name, tuple):
191+
conflict_target.append(
192+
'(%s->\'%s\')' % (
193+
self._format_field_name(field_name),
194+
field_name[1]
195+
)
196+
)
197+
else:
198+
conflict_target.append(
199+
self._format_field_name(field_name))
195200

196201
return '(%s)' % ','.join(conflict_target)
197202

@@ -207,28 +212,30 @@ def _get_model_field(self, name: str):
207212
no such field exists.
208213
"""
209214

215+
field_name = name
216+
if isinstance(field_name, tuple):
217+
field_name = field_name[0]
218+
210219
for field in self.query.model._meta.local_concrete_fields:
211-
if field.name == name or field.column == name:
220+
if field.name == field_name or field.column == field_name:
212221
return field
213222

214223
return None
215224

216-
def _format_field_name(self, name) -> str:
225+
def _format_field_name(self, field_name) -> str:
217226
"""Formats a field's name for usage in SQL.
218227
219228
Arguments:
220-
name:
229+
field_name:
221230
The field name to format.
222231
223232
Returns:
224233
The specified field name formatted for
225234
usage in SQL.
226235
"""
227236

228-
if isinstance(name, tuple):
229-
return '(%s->\'%s\')' % name
230-
231-
return self.qn(name)
237+
field = self._get_model_field(field_name)
238+
return self.qn(field.column)
232239

233240
def _format_field_value(self, field_name) -> str:
234241
"""Formats a field's value for usage in SQL.
@@ -243,11 +250,9 @@ def _format_field_value(self, field_name) -> str:
243250
in SQL.
244251
"""
245252

246-
if isinstance(field_name, tuple):
247-
field_name, _ = field_name
248-
253+
field = self._get_model_field(field_name)
249254
return SQLInsertCompiler.prepare_value(
250255
self,
251-
self._get_model_field(field_name),
252-
getattr(self.query.objs[0], field_name)
256+
field,
257+
getattr(self.query.objs[0], field.name)
253258
)

tests/test_on_conflict.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,19 @@ def test_on_conflict_outdated_model(conflict_action):
214214
.on_conflict(['title'], conflict_action)
215215
.insert_and_get(title='beer')
216216
)
217+
218+
@pytest.mark.parametrize("conflict_action", CONFLICT_ACTIONS)
219+
def test_on_conflict_custom_column_names(conflict_action):
220+
"""Asserts that models with custom column names (models
221+
where the column and field name are different) work properly."""
222+
223+
model = get_fake_model({
224+
'title': models.CharField(max_length=140, unique=True, db_column='beer'),
225+
'description': models.CharField(max_length=255, db_column='desc')
226+
})
227+
228+
id = (
229+
model.objects
230+
.on_conflict(['title'], conflict_action)
231+
.insert(title='yeey', description='great thing')
232+
)

0 commit comments

Comments
 (0)