Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to

_No notable unreleased changes_

## [0.1.23] - 2025-11-18

### Fixed

- Fixed a bug where `SaferAddFieldForeignKey` ignored the `ForeignKey` `to_field`
parameter, resulting in an incorrect column type and incorrect primary key reference.

## [0.1.22] - 2025-08-07

### Fixed
Expand Down Expand Up @@ -228,7 +235,8 @@ _No notable unreleased changes_
- `SaferAddIndexConcurrently` migration operation to create new Postgres
indexes in a safer, idempotent way.

[Unreleased]: https://github.com/octoenergy/django-migration-helpers/compare/v0.1.22...HEAD
[Unreleased]: https://github.com/octoenergy/django-migration-helpers/compare/v0.1.23...HEAD
[0.1.23]: https://github.com/octoenergy/django-migration-helpers/compare/v0.1.22...v0.1.23
[0.1.22]: https://github.com/octoenergy/django-migration-helpers/compare/v0.1.21...v0.1.22
[0.1.21]: https://github.com/octoenergy/django-migration-helpers/compare/v0.1.20...v0.1.21
[0.1.20]: https://github.com/kraken-tech/django-pg-migration-tools/compare/v0.1.19...v0.1.20
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ where = ["src"]

[project]
name = "django_pg_migration_tools"
version = "0.1.22"
version = "0.1.23"
description = "Tools for making Django migrations safer and more scalable."
license.file = "LICENSE"
readme = "README.md"
Expand Down Expand Up @@ -175,7 +175,7 @@ exclude_also = [

[tool.bumpversion]
# Do not manually edit the version, use `make version_{type}` instead.
current_version = "0.1.22"
current_version = "0.1.23"

# Relabel the Unreleased section of the changelog and add a new unreleased section as a reminder to
# add to it.
Expand Down
24 changes: 21 additions & 3 deletions src/django_pg_migration_tools/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1306,8 +1306,26 @@ def _get_remote_pk_field(self) -> models.Field[Any, Any]:
assert isinstance(pk_field, models.Field)
return pk_field

def _get_remote_to_field(self) -> models.Field[Any, Any]:
to_field = self.field.to_fields[0]
remote_model = self._get_remote_model()

remote_field = next(
field for field in remote_model._meta.get_fields() if field.name == to_field
)
assert isinstance(remote_field, models.Field)
return remote_field

def _get_target_field(self) -> models.Field[Any, Any]:
# If to_field is specified, we don't want to default to using the pk.
if self.field.to_fields and self.field.to_fields[0]:
target_field = self._get_remote_to_field()
else:
target_field = self._get_remote_pk_field()
return target_field

def _get_column_type(self) -> str:
remote_field = self._get_remote_pk_field()
remote_field = self._get_target_field()
column_type: str | None = remote_field.db_type(self.schema_editor.connection)
assert column_type is not None
return column_type
Expand Down Expand Up @@ -1384,8 +1402,8 @@ def _is_constraint_valid(self) -> bool:

def _alter_table_add_not_valid_fk(self) -> None:
remote_model = self._get_remote_model()
remote_pk_field = self._get_remote_pk_field()
referred_column_name = remote_pk_field.db_column or remote_pk_field.name
remote_target_field = self._get_target_field()
referred_column_name = remote_target_field.db_column or remote_target_field.name
self.schema_editor.execute(
psycopg_sql.SQL(ConstraintQueries.ALTER_TABLE_ADD_NOT_VALID_FK)
.format(
Expand Down
72 changes: 72 additions & 0 deletions tests/django_pg_migration_tools/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
NullIntFieldModel,
UniqueConditionCharModel,
UniqueExpressionCharModel,
UUIDFieldModel,
get_check_constraint,
)

Expand Down Expand Up @@ -3789,6 +3790,77 @@ def test_operation_when_related_model_has_explicit_pk_field(self):
AND attname = 'other_int_model_field_id';
""")

@pytest.mark.django_db(transaction=True)
def test_operation_when_has_explicit_non_pk_to_field(self):
project_state = ProjectState()
project_state.add_model(ModelState.from_model(IntModel))
project_state.add_model(ModelState.from_model(UUIDFieldModel))
new_state = project_state.clone()

# Relate the IntModel -> UUIDFieldModel by UUIDFieldModel.uuid_field relationship.
operation = operations.SaferAddFieldForeignKey(
model_name="intmodel",
name="uuid_model_uuid_field",
field=models.ForeignKey(
"example_app.UUIDFieldModel",
null=True,
on_delete=models.CASCADE,
db_index=False,
to_field="uuid_field", # Do not default to the Primary Key field.
),
)

operation.state_forwards(self.app_label, new_state)
with connection.schema_editor(atomic=False, collect_sql=False) as editor:
with utils.CaptureQueriesContext(connection) as queries:
operation.database_forwards(
self.app_label, editor, from_state=project_state, to_state=new_state
)

assert len(queries) == 4
assert queries[0]["sql"] == dedent("""
SELECT 1
FROM pg_catalog.pg_attribute
WHERE
attrelid = 'example_app_intmodel'::regclass
AND attname = 'uuid_model_uuid_field_id';
""")
assert queries[1]["sql"] == dedent("""
ALTER TABLE "example_app_intmodel"
ADD COLUMN IF NOT EXISTS "uuid_model_uuid_field_id"
uuid NULL;
""")
assert queries[2]["sql"] == dedent("""
ALTER TABLE "example_app_intmodel"
ADD CONSTRAINT "example_app_intmodel_uuid_model_uuid_field_id_fk" FOREIGN KEY ("uuid_model_uuid_field_id")
REFERENCES "example_app_uuidfieldmodel" ("uuid_field")
DEFERRABLE INITIALLY DEFERRED
NOT VALID;
""")
assert queries[3]["sql"] == dedent("""
ALTER TABLE "example_app_intmodel"
VALIDATE CONSTRAINT "example_app_intmodel_uuid_model_uuid_field_id_fk";
""")

with connection.schema_editor(atomic=False, collect_sql=False) as editor:
with utils.CaptureQueriesContext(connection) as reverse_queries:
operation.database_backwards(
self.app_label, editor, from_state=new_state, to_state=project_state
)

assert len(reverse_queries) == 2
assert reverse_queries[0]["sql"] == dedent("""
SELECT 1
FROM pg_catalog.pg_attribute
WHERE
attrelid = 'example_app_intmodel'::regclass
AND attname = 'uuid_model_uuid_field_id';
""")
assert reverse_queries[1]["sql"] == dedent("""
ALTER TABLE "example_app_intmodel"
DROP COLUMN "uuid_model_uuid_field_id";
""")


class TestSaferAddCheckConstraint:
app_label = "example_app"
Expand Down
12 changes: 12 additions & 0 deletions tests/example_app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,18 @@ class CharIDModel(models.Model):
id = models.CharField(max_length=42, primary_key=True)


class UUIDFieldModel(models.Model):
uuid_field = models.UUIDField()

class Meta:
constraints = (
models.UniqueConstraint(
fields=["uuid_field"],
name="unique_uuid_field",
),
)


class ModelWithCheckConstraint(models.Model):
class Meta:
constraints = (
Expand Down
Loading