diff --git a/CHANGELOG.md b/CHANGELOG.md index 855418f..d30ef03 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,13 @@ Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to _No notable unreleased changes_ +## [0.1.23] - 2025-11-18 + +### Fixed + +- Fixed a bug where `SaferAddFieldForeignKey` ignored the `ForeignKey` `to_field` + parameter, resulting in an incorrect column type and incorrect primary key reference. + ## [0.1.22] - 2025-08-07 ### Fixed @@ -228,7 +235,8 @@ _No notable unreleased changes_ - `SaferAddIndexConcurrently` migration operation to create new Postgres indexes in a safer, idempotent way. -[Unreleased]: https://github.com/octoenergy/django-migration-helpers/compare/v0.1.22...HEAD +[Unreleased]: https://github.com/octoenergy/django-migration-helpers/compare/v0.1.23...HEAD +[0.1.23]: https://github.com/octoenergy/django-migration-helpers/compare/v0.1.22...v0.1.23 [0.1.22]: https://github.com/octoenergy/django-migration-helpers/compare/v0.1.21...v0.1.22 [0.1.21]: https://github.com/octoenergy/django-migration-helpers/compare/v0.1.20...v0.1.21 [0.1.20]: https://github.com/kraken-tech/django-pg-migration-tools/compare/v0.1.19...v0.1.20 diff --git a/pyproject.toml b/pyproject.toml index dff9a71..013e9eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ where = ["src"] [project] name = "django_pg_migration_tools" -version = "0.1.22" +version = "0.1.23" description = "Tools for making Django migrations safer and more scalable." license.file = "LICENSE" readme = "README.md" @@ -175,7 +175,7 @@ exclude_also = [ [tool.bumpversion] # Do not manually edit the version, use `make version_{type}` instead. -current_version = "0.1.22" +current_version = "0.1.23" # Relabel the Unreleased section of the changelog and add a new unreleased section as a reminder to # add to it. diff --git a/src/django_pg_migration_tools/operations.py b/src/django_pg_migration_tools/operations.py index 269cc27..5834277 100644 --- a/src/django_pg_migration_tools/operations.py +++ b/src/django_pg_migration_tools/operations.py @@ -1306,8 +1306,26 @@ def _get_remote_pk_field(self) -> models.Field[Any, Any]: assert isinstance(pk_field, models.Field) return pk_field + def _get_remote_to_field(self) -> models.Field[Any, Any]: + to_field = self.field.to_fields[0] + remote_model = self._get_remote_model() + + remote_field = next( + field for field in remote_model._meta.get_fields() if field.name == to_field + ) + assert isinstance(remote_field, models.Field) + return remote_field + + def _get_target_field(self) -> models.Field[Any, Any]: + # If to_field is specified, we don't want to default to using the pk. + if self.field.to_fields and self.field.to_fields[0]: + target_field = self._get_remote_to_field() + else: + target_field = self._get_remote_pk_field() + return target_field + def _get_column_type(self) -> str: - remote_field = self._get_remote_pk_field() + remote_field = self._get_target_field() column_type: str | None = remote_field.db_type(self.schema_editor.connection) assert column_type is not None return column_type @@ -1384,8 +1402,8 @@ def _is_constraint_valid(self) -> bool: def _alter_table_add_not_valid_fk(self) -> None: remote_model = self._get_remote_model() - remote_pk_field = self._get_remote_pk_field() - referred_column_name = remote_pk_field.db_column or remote_pk_field.name + remote_target_field = self._get_target_field() + referred_column_name = remote_target_field.db_column or remote_target_field.name self.schema_editor.execute( psycopg_sql.SQL(ConstraintQueries.ALTER_TABLE_ADD_NOT_VALID_FK) .format( diff --git a/tests/django_pg_migration_tools/test_operations.py b/tests/django_pg_migration_tools/test_operations.py index a3463bf..ee9f83f 100644 --- a/tests/django_pg_migration_tools/test_operations.py +++ b/tests/django_pg_migration_tools/test_operations.py @@ -28,6 +28,7 @@ NullIntFieldModel, UniqueConditionCharModel, UniqueExpressionCharModel, + UUIDFieldModel, get_check_constraint, ) @@ -3789,6 +3790,77 @@ def test_operation_when_related_model_has_explicit_pk_field(self): AND attname = 'other_int_model_field_id'; """) + @pytest.mark.django_db(transaction=True) + def test_operation_when_has_explicit_non_pk_to_field(self): + project_state = ProjectState() + project_state.add_model(ModelState.from_model(IntModel)) + project_state.add_model(ModelState.from_model(UUIDFieldModel)) + new_state = project_state.clone() + + # Relate the IntModel -> UUIDFieldModel by UUIDFieldModel.uuid_field relationship. + operation = operations.SaferAddFieldForeignKey( + model_name="intmodel", + name="uuid_model_uuid_field", + field=models.ForeignKey( + "example_app.UUIDFieldModel", + null=True, + on_delete=models.CASCADE, + db_index=False, + to_field="uuid_field", # Do not default to the Primary Key field. + ), + ) + + operation.state_forwards(self.app_label, new_state) + with connection.schema_editor(atomic=False, collect_sql=False) as editor: + with utils.CaptureQueriesContext(connection) as queries: + operation.database_forwards( + self.app_label, editor, from_state=project_state, to_state=new_state + ) + + assert len(queries) == 4 + assert queries[0]["sql"] == dedent(""" + SELECT 1 + FROM pg_catalog.pg_attribute + WHERE + attrelid = 'example_app_intmodel'::regclass + AND attname = 'uuid_model_uuid_field_id'; + """) + assert queries[1]["sql"] == dedent(""" + ALTER TABLE "example_app_intmodel" + ADD COLUMN IF NOT EXISTS "uuid_model_uuid_field_id" + uuid NULL; + """) + assert queries[2]["sql"] == dedent(""" + ALTER TABLE "example_app_intmodel" + ADD CONSTRAINT "example_app_intmodel_uuid_model_uuid_field_id_fk" FOREIGN KEY ("uuid_model_uuid_field_id") + REFERENCES "example_app_uuidfieldmodel" ("uuid_field") + DEFERRABLE INITIALLY DEFERRED + NOT VALID; + """) + assert queries[3]["sql"] == dedent(""" + ALTER TABLE "example_app_intmodel" + VALIDATE CONSTRAINT "example_app_intmodel_uuid_model_uuid_field_id_fk"; + """) + + with connection.schema_editor(atomic=False, collect_sql=False) as editor: + with utils.CaptureQueriesContext(connection) as reverse_queries: + operation.database_backwards( + self.app_label, editor, from_state=new_state, to_state=project_state + ) + + assert len(reverse_queries) == 2 + assert reverse_queries[0]["sql"] == dedent(""" + SELECT 1 + FROM pg_catalog.pg_attribute + WHERE + attrelid = 'example_app_intmodel'::regclass + AND attname = 'uuid_model_uuid_field_id'; + """) + assert reverse_queries[1]["sql"] == dedent(""" + ALTER TABLE "example_app_intmodel" + DROP COLUMN "uuid_model_uuid_field_id"; + """) + class TestSaferAddCheckConstraint: app_label = "example_app" diff --git a/tests/example_app/models.py b/tests/example_app/models.py index 086e008..6b715f0 100644 --- a/tests/example_app/models.py +++ b/tests/example_app/models.py @@ -95,6 +95,18 @@ class CharIDModel(models.Model): id = models.CharField(max_length=42, primary_key=True) +class UUIDFieldModel(models.Model): + uuid_field = models.UUIDField() + + class Meta: + constraints = ( + models.UniqueConstraint( + fields=["uuid_field"], + name="unique_uuid_field", + ), + ) + + class ModelWithCheckConstraint(models.Model): class Meta: constraints = (