Skip to content

Custom through models are re-created on every model save #219

@petrklus

Description

@petrklus

I am using a custom through model, example below:

    features = SortedManyToManyField(
        ...
        sort_value_field_name="sort_number",
        through="ProductDetailFeatureThrough",
        through_fields=("productdetailv3", "productfeaturev3"),
    )

This is working fine - reordering and assigning works.

However, I have noticed that all the existing through models are wiped and re-created anew on every model save. This, of course, wipes any custom parameters that have been stored in them.

I have solved by creating a fork of sortedm2m and modifying it in the following places:

First of all, I've added a class attribute named `_through_extra_params` to my custom through class, listing the extra fields to preserve:

class ProductDetailFeatureThrough(models.Model):
    _through_extra_params = ["param_1", "param_2", "param_3", "param_4", "param_5"]
    ...

Then, in the sortedm2m/fields.py, I have modified the set method to temporarily store the existing data to be then re-added to the newly created objects.

...
        def set(self, objs, **kwargs):  # pylint: disable=arguments-differ
            """Replace the related objects, preserving extra through-model fields.

            Forcing ``clear=True`` ensures the sort order is rebuilt from
            scratch. Before the clear, any extra fields declared on the
            through model via ``_through_extra_params`` are snapshotted per
            target PK and stashed on ``self`` so that ``_add_items`` can
            re-apply them to the newly created through rows.
            """
            # Choosing to clear first will ensure the order is maintained.
            kwargs["clear"] = True

            through_extra_params = getattr(self.through, "_through_extra_params", [])
            if through_extra_params:
                # Fetch the additional fields from the existing through rows,
                # keyed by target PK, for later use in _add_items.
                existing_through_objs = self.through.objects.filter(
                    **{
                        self.source_field_name: self.related_val[0],
                    }
                ).values(self.target_field_name, *through_extra_params)

                def process_item(item):
                    # Split each row into (target PK, extra-field payload).
                    key = item.pop(self.target_field_name)
                    return (key, item)

                existing_through_objs_by_target = dict(
                    map(process_item, existing_through_objs)
                )

                # Stash for further re-use in _add_items; it removes the
                # attribute again after consuming it.
                setattr(
                    self, self._extra_data_param_name(), existing_through_objs_by_target
                )

            super().set(objs, **kwargs)


        def _add_items(self, source_field_name, target_field_name, *objs, **kwargs):
            """Create through rows for *objs*, assigning ascending sort values.

            source_field_name: FK field name in the join table for the source object.
            target_field_name: FK field name in the join table for the target object.
            *objs: objects to add — model instances or their primary keys.
            **kwargs: in Django >= 2.2 contains a ``through_defaults`` key.

            Any extra through-model field values snapshotted by ``set``
            (keyed by target PK) are re-applied to the rows created here,
            taking precedence over ``through_defaults``.
            """
            through_defaults = kwargs.get("through_defaults") or {}

            # Pick up the extra-field snapshot left behind by set(), if any.
            existing_through_objs_by_target = getattr(
                self, self._extra_data_param_name(), {}
            )
            if existing_through_objs_by_target:
                # Consume it once; the attribute must not leak into later calls.
                delattr(self, self._extra_data_param_name())

            # If there aren't any objects, there is nothing to do.
            if objs:
                # Django uses a set here, we need to use a list to keep the
                # correct ordering.
                new_ids = []
                for obj in objs:
                    if isinstance(obj, self.model):
                        if not router.allow_relation(obj, self.instance):
                            raise ValueError(
                                'Cannot add "%r": instance is on database "%s", value is on database "%s"'
                                % (
                                    obj,
                                    self.instance._state.db,
                                    obj._state.db,
                                )  # pylint: disable=protected-access
                            )

                        fk_val = self.through._meta.get_field(
                            target_field_name
                        ).get_foreign_related_value(obj)[0]

                        if fk_val is None:
                            raise ValueError(
                                'Cannot add "%r": the value for field "%s" is None'
                                % (obj, target_field_name)
                            )

                        new_ids.append(fk_val)
                    elif isinstance(obj, Model):
                        raise TypeError(
                            "'%s' instance expected, got %r"
                            % (self.model._meta.object_name, obj)
                        )
                    else:
                        new_ids.append(obj)

                db = router.db_for_write(self.through, instance=self.instance)
                manager = self.through._default_manager.using(
                    db
                )  # pylint: disable=protected-access
                # Target PKs already present in the join table for this source.
                vals = manager.values_list(target_field_name, flat=True).filter(
                    **{
                        source_field_name: self.related_val[0],
                        "%s__in" % target_field_name: new_ids,
                    }
                )

                # make set.difference_update() keeping ordering
                new_ids_set = set(new_ids)
                new_ids_set.difference_update(vals)

                new_ids = list(filter(lambda _id: _id in new_ids_set, new_ids))

                # Add the ones that aren't there already
                with transaction.atomic(using=db, savepoint=False):
                    if self.reverse or source_field_name == self.source_field_name:
                        # Don't send the signal when we are inserting the
                        # duplicate data row for symmetrical reverse entries.
                        signals.m2m_changed.send(
                            sender=self.through,
                            action="pre_add",
                            instance=self.instance,
                            reverse=self.reverse,
                            model=self.model,
                            pk_set=new_ids_set,
                            using=db,
                        )

                    rel_source_fk = self.related_val[0]
                    rel_through = self.through
                    sort_field_name = (
                        rel_through._sort_field_name
                    )  # pylint: disable=protected-access

                    # Use the max of all indices as start index...
                    # maybe an autoincrement field should do the job more efficiently ?
                    source_queryset = manager.filter(
                        **{"%s_id" % source_field_name: rel_source_fk}
                    )
                    sort_value_max = (
                        source_queryset.aggregate(max=Max(sort_field_name))["max"] or 0
                    )

                    # Merge order matters: snapshotted extra fields override
                    # through_defaults for rows that existed before the clear.
                    bulk_data = [
                        dict(
                            through_defaults,
                            **{
                                "%s_id" % source_field_name: rel_source_fk,
                                "%s_id" % target_field_name: obj_id,
                                sort_field_name: obj_idx,
                            },
                            **existing_through_objs_by_target.get(obj_id, {}),
                        )
                        for obj_idx, obj_id in enumerate(new_ids, sort_value_max + 1)
                    ]

                    manager.bulk_create([rel_through(**data) for data in bulk_data])

                    if self.reverse or source_field_name == self.source_field_name:
                        # Don't send the signal when we are inserting the
                        # duplicate data row for symmetrical reverse entries.
                        signals.m2m_changed.send(
                            sender=self.through,
                            action="post_add",
                            instance=self.instance,
                            reverse=self.reverse,
                            model=self.model,
                            pk_set=new_ids_set,
                            using=db,
                        )

What shall I do to get the above incorporated into the main source? I appreciate this is a relatively quick & dirty solution, so any feedback is very welcome!

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions