-
-
Notifications
You must be signed in to change notification settings - Fork 151
Open
Description
I am using a custom through model, example below:
features = SortedManyToManyField(
...
sort_value_field_name="sort_number",
through="ProductDetailFeatureThrough",
through_fields=("productdetailv3", "productfeaturev3"),
)
This is working fine: reordering and assigning both work.
However, I have noticed that all existing through-model rows are deleted and re-created on every model save. This, of course, wipes any custom values that have been stored in them.
I have solved this by creating a fork of sortedm2m and modifying it in the following places:
First, I added a class attribute, `_through_extra_params`, to my custom through class, listing the extra fields whose values should be preserved:
class ProductDetailFeatureThrough(models.Model):
    # Names of extra through-model fields whose current values the forked
    # SortedManyToManyField ``set()`` reads from the existing through rows
    # (via ``.values(...)``) so they can be re-applied when the rows are
    # re-created. Only these fields are carried over.
    _through_extra_params = ["param_1", "param_2", "param_3", "param_4", "param_5"]
...
Then, in sortedm2m/fields.py, I modified the `set` method to temporarily store the existing extra data, and `_add_items` to re-apply it to the newly created through objects:
...
def set(self, objs, **kwargs):  # pylint: disable=arguments-differ
    """Replace the related objects with ``objs``, keeping their order.

    ``clear=True`` is always forced so the sort values are rebuilt in the
    order of ``objs``. Clearing deletes the existing through rows. If the
    through model declares ``_through_extra_params`` (a list of extra
    through-model field names), the current values of those fields are
    therefore captured first. They are keyed by target primary key and
    stashed on the manager, under the attribute named by
    ``self._extra_data_param_name()``, for ``_add_items`` to re-apply to the
    re-created rows.

    :param objs: target model instances or primary keys, in the desired order.
    :param kwargs: forwarded to the parent ``set()``; ``clear`` is overridden.
    """
    # Choosing to clear first will ensure the order is maintained.
    kwargs["clear"] = True
    stash_name = self._extra_data_param_name()
    extra_fields = getattr(self.through, "_through_extra_params", [])
    if extra_fields:
        through_meta = self.through._meta  # pylint: disable=protected-access
        # Query by attname so ForeignKey extras come back under "<name>_id".
        # The model constructor accepts a raw id for that key, whereas
        # "<name>" would require a model instance.
        attnames = [through_meta.get_field(name).attname for name in extra_fields]
        # Read from the database whose rows clear() is about to delete.
        db = router.db_for_write(self.through, instance=self.instance)
        rows = (
            self.through._default_manager.using(db)  # pylint: disable=protected-access
            .filter(**{self.source_field_name: self.related_val[0]})
            .values(self.target_field_name, *attnames)
        )
        # values() yields a fresh dict per row, so popping the target key in
        # place leaves exactly the extra columns behind.
        preserved = {row.pop(self.target_field_name): row for row in rows}
        # Picked up (and removed) by _add_items when the rows are re-created.
        setattr(self, stash_name, preserved)
    try:
        super().set(objs, **kwargs)
    finally:
        # _add_items normally consumes the stash. Never let it outlive this
        # call, e.g. when clear() raises before any rows are added.
        self.__dict__.pop(stash_name, None)
def _add_items(self, source_field_name, target_field_name, *objs, **kwargs):
    """Insert through rows for ``objs``, appended in the order given.

    Unlike Django's implementation, ids are kept in a list rather than a set
    so that ordering survives. Each new row gets a sort value continuing
    after the current maximum for this source. If ``set()`` stashed the
    extra through-model values of the rows it cleared, they are re-applied
    to the matching new rows.

    :param source_field_name: FK field name in the through table pointing at
        the source object.
    :param target_field_name: FK field name in the through table pointing at
        the target object.
    :param objs: objects to add, either model instances or their primary keys.
    :param kwargs: in Django >= 2.2 may contain ``through_defaults``.
    :raises ValueError: if an instance is on a database the router does not
        allow for this relation, or has no value for the target field.
    :raises TypeError: if a model instance of the wrong model is passed.
    """
    through_defaults = kwargs.get("through_defaults") or {}
    # Always consume the values stashed by set(), even when empty, so they
    # can never leak into a later, unrelated add() on this manager.
    preserved_by_target = (
        self.__dict__.pop(self._extra_data_param_name(), None) or {}
    )

    # If there aren't any objects, there is nothing to do.
    if not objs:
        return

    # Django uses a set here; a list is needed to keep the correct ordering.
    new_ids = []
    for obj in objs:
        if isinstance(obj, self.model):
            if not router.allow_relation(obj, self.instance):
                raise ValueError(
                    'Cannot add "%r": instance is on database "%s", value is on database "%s"'
                    % (
                        obj,
                        self.instance._state.db,  # pylint: disable=protected-access
                        obj._state.db,  # pylint: disable=protected-access
                    )
                )
            target_field = self.through._meta.get_field(  # pylint: disable=protected-access
                target_field_name
            )
            fk_val = target_field.get_foreign_related_value(obj)[0]
            if fk_val is None:
                raise ValueError(
                    'Cannot add "%r": the value for field "%s" is None'
                    % (obj, target_field_name)
                )
            new_ids.append(fk_val)
        elif isinstance(obj, Model):
            raise TypeError(
                "'%s' instance expected, got %r"
                % (self.model._meta.object_name, obj)  # pylint: disable=protected-access
            )
        else:
            new_ids.append(obj)

    db = router.db_for_write(self.through, instance=self.instance)
    manager = self.through._default_manager.using(db)  # pylint: disable=protected-access
    already_related = manager.values_list(target_field_name, flat=True).filter(
        **{
            source_field_name: self.related_val[0],
            "%s__in" % target_field_name: new_ids,
        }
    )

    # set.difference_update(), but keeping the caller's ordering.
    new_ids_set = set(new_ids)
    new_ids_set.difference_update(already_related)
    new_ids = [obj_id for obj_id in new_ids if obj_id in new_ids_set]

    # Add the ones that aren't there already.
    with transaction.atomic(using=db, savepoint=False):
        # Don't send the signals when we are inserting the duplicate data row
        # for symmetrical reverse entries.
        send_signals = self.reverse or source_field_name == self.source_field_name
        if send_signals:
            signals.m2m_changed.send(
                sender=self.through,
                action="pre_add",
                instance=self.instance,
                reverse=self.reverse,
                model=self.model,
                pk_set=new_ids_set,
                using=db,
            )

        rel_source_fk = self.related_val[0]
        rel_through = self.through
        sort_field_name = rel_through._sort_field_name  # pylint: disable=protected-access

        # Use the max of all indices as start index, so new rows are appended
        # after the existing ones. (An autoincrement field might do this more
        # efficiently.)
        source_queryset = manager.filter(
            **{"%s_id" % source_field_name: rel_source_fk}
        )
        sort_value_max = (
            source_queryset.aggregate(max=Max(sort_field_name))["max"] or 0
        )

        bulk_data = [
            {
                **through_defaults,
                # Values preserved by set() override the generic defaults...
                **preserved_by_target.get(obj_id, {}),
                # ...while the key and ordering columns come last, so they can
                # never be clobbered (and duplicates cannot raise TypeError).
                "%s_id" % source_field_name: rel_source_fk,
                "%s_id" % target_field_name: obj_id,
                sort_field_name: obj_idx,
            }
            for obj_idx, obj_id in enumerate(new_ids, sort_value_max + 1)
        ]
        manager.bulk_create([rel_through(**data) for data in bulk_data])

        if send_signals:
            signals.m2m_changed.send(
                sender=self.through,
                action="post_add",
                instance=self.instance,
                reverse=self.reverse,
                model=self.model,
                pk_set=new_ids_set,
                using=db,
            )


# What shall I do to get the above incorporated into the main source? I
# appreciate this is a relatively quick-and-dirty solution; any feedback is
# very welcome!
Metadata
Metadata
Assignees
Labels
No labels