#
#
# In this file you're looking at the statement returned by
# ModelTransferrer.transfer_model_stmt()
#
#

from typing import List, Tuple, Dict, Type, Iterable, Set, Optional
from urllib.parse import urljoin
from itertools import chain
from uuid import uuid4
from logging import getLogger
from datetime import datetime
from string import ascii_letters
import secrets
import requests

from django.db import models, transaction, connection
from django.urls import reverse
from django.dispatch import receiver
from django.conf import settings
from django.contrib import admin
from django.contrib.auth import models as auth
from django.db.models.signals import pre_delete
from django.contrib.contenttypes.models import ContentType

from sqlalchemy import and_, or_, literal, MetaData, func, select, case, exists
from sqlalchemy.orm import aliased
from sqlalchemy.dialects import postgresql
from aldjemy.orm import construct_models
from celery import shared_task, chord

from aspiredu.endpoint.models import Endpoint
from aspiredu.toolbox.toposort import toposort

from .pointer import EndpointPointer
from . import fdw, s3


logger = getLogger(__name__)


TRANSFER_MAX_AGE = 120  # 2 minutes, max clock skew for webhook

#
# Extension hooks
#

# Models that this process can transfer.
# SHALLOW_MODELS are automatically added.
# Will eventually be replaced by auto-discovery.
TRANSFER_MODELS: Set[Type[models.Model]] = set()

# Models that must be transferred to run a retrieval on the new backend.
SHALLOW_MODELS: Set[Type[models.Model]] = set()

# Many-to-many fields that the transfer process must ignore.
IGNORE_M2M_FIELDS: Set[models.ManyToManyField] = set()

# Models that the transfer process should not transfer.
IGNORE_MODELS: Set[Type[models.Model]] = {Endpoint, auth.User}

# Models with a foreign key to self that form a tree.
RECURSIVE_MODELS: Set[Type[models.Model]] = set()


def transfer_name_default():
    return str(uuid4())


class Transfer(models.Model):
    """A destination-driven transfer process.

    The ``destination`` is the destination endpoint. The destination
    endpoint should be created manually, though only fields directly
    on the endpoint need to be specified manually.

    The ``source`` is the endpoint from which we need to copy data.
    It is an EndpointPointer, which identifies both the backend and
    the endpoint on that particular backend.

    The ``name`` is a unique identifier that will be used to construct
    the remote server and the schema for the foreign data wrapper.

    ``shallow`` signals that only the endpoint and the product
    configuration models should be transferred.
    """

    created = models.DateTimeField(auto_now_add=True)
    updated = models.DateTimeField(auto_now=True)
    source = models.ForeignKey(EndpointPointer, on_delete=models.CASCADE)
    destination = models.ForeignKey(Endpoint, on_delete=models.CASCADE)
    name = models.CharField(default=transfer_name_default, unique=True, max_length=128)
    shallow = models.BooleanField(default=True)

    def target_models(self):
        """Determine which models are targeted by this transfer."""
        return SHALLOW_MODELS if self.shallow else SHALLOW_MODELS | TRANSFER_MODELS

    def create_fdw(self, db: dict):
        """Create a foreign data wrapper to the given server.

        This must be cleaned up with ``drop_fdw`` on completion.

        :param db: The database settings from the source backend.
        """
        fdw.create(
            server=self.name,
            host=db["HOST"],
            port=db["PORT"],
            dbname=db["NAME"],
            user=db["USER"],
            password=db["PASSWORD"],
            src="public",
            dst=self.name,
        )

    def drop_fdw(self):
        """Drop the foreign data wrapper for this transfer."""
        fdw.drop(server=self.name, dst=self.name)
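

# For reference, a postgres_fdw setup of the kind ``create_fdw`` performs
# typically looks like the following (a sketch; the internal ``fdw`` module
# may differ in details):
#
#   CREATE SERVER <name> FOREIGN DATA WRAPPER postgres_fdw
#       OPTIONS (host '<host>', port '<port>', dbname '<dbname>');
#   CREATE USER MAPPING FOR CURRENT_USER SERVER <name>
#       OPTIONS (user '<user>', password '<password>');
#   CREATE SCHEMA <name>;
#   IMPORT FOREIGN SCHEMA public FROM SERVER <name> INTO <name>;
#
# ``drop_fdw`` would then drop the schema and the server (with CASCADE).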


@receiver(
    pre_delete,
    sender=Transfer,
    dispatch_uid="aspiredu.unity.transfer.pre_delete_transfer",
)
def pre_delete_transfer(sender, instance, **kwargs):
    instance.drop_fdw()


@admin.register(Transfer)
class TransferAdmin(admin.ModelAdmin):
    list_display = ["created", "source", "destination"]
    autocomplete_fields = ["source", "destination"]

    def save_model(self, request, obj, form, change):
        super().save_model(request, obj, form, change)
        if not change:
            transaction.on_commit(lambda: transfer_signature(obj).delay())


class Checkpoint(models.Model):
    """A record of the latest transferred instance for a model."""

    created = models.DateTimeField(auto_now_add=True)
    updated = models.DateTimeField(auto_now=True)
    transfer = models.ForeignKey(
        Transfer, on_delete=models.CASCADE, related_name="checkpoints"
    )

    content_type = models.ForeignKey(
        ContentType,
        on_delete=models.CASCADE,
        related_name="checkpoints",
        help_text="The model that was transferred.",
    )
    # NOTE: Cannot be used in a GenericForeignKey because the object_id
    # comes from the source, and may not exist in the destination.
    object_id = models.CharField(
        max_length=256,
        help_text="The source primary key value of the last instance transferred.",
    )
    timestamp = models.DateTimeField(
        null=True,
        blank=True,
        help_text="The source updated timestamp of the last instance transferred.",
    )

    class Meta:
        unique_together = ("transfer", "content_type")


@admin.register(Checkpoint)
class CheckpointAdmin(admin.ModelAdmin):
    list_display = [
        "created",
        "updated",
        "transfer",
        "content_type",
        "object_id",
        "timestamp",
    ]


def discover_models(
    targets: Iterable[Type[models.Model]],
) -> Tuple[Set[Type[models.Model]], Dict[Type[models.Model], List[str]]]:
    """Discover the models required to transfer the target models.

    Recursively look at all foreign keys on each model in the target
    set of models, and figure out the full set of models that need to
    be transferred in order for the target models to transfer properly.
    The transfer ordering itself is handled separately by ``sort_models``.

    Also look for many-to-many fields that have their concrete side on
    one of the models given, and add the through tables to the set of
    models that will be transferred.

    Return the full set of models, as well as a dictionary of natural
    keys for models that do not have natural keys specified, such as
    many-to-many through tables.
    """
    discovered = set(targets)

    while True:
        more = set(
            field.related_model
            for model in discovered
            for field in model._meta.get_fields()
            if field.concrete
            and field.is_relation
            and field.related_model  # ignore GFKs
            and not field.many_to_many  # ignore M2Ms
        ) | set(
            # Go figure, the concrete side doesn't have the through table set
            field.remote_field.through
            for model in discovered
            for field in model._meta.get_fields()
            if field.concrete
            and field.is_relation
            and field.many_to_many  # get M2Ms
            and field not in IGNORE_M2M_FIELDS
        )
        if not more - IGNORE_MODELS - discovered:
            break
        discovered |= more - IGNORE_MODELS

    # Discover natural keys for many-to-many through models
    natural_keys: Dict[Type[models.Model], List[str]] = {
        field.remote_field.through: [
            field.m2m_field_name(),
            field.m2m_reverse_field_name(),
        ]
        for model in discovered
        for field in model._meta.get_fields()
        if field.concrete and field.is_relation and field.many_to_many  # get M2Ms
    }
    extra_natural_keys: Dict[Type[models.Model], List[str]] = {auth.User: ["username"]}
    return discovered, {**natural_keys, **extra_natural_keys}
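

# An illustrative sketch with hypothetical models: given ``Enrollment`` with
# foreign keys to ``Course`` and ``auth.User``, and ``Course`` with a
# many-to-many field to ``Term``:
#
#   discovered, natural_keys = discover_models({Enrollment})
#   # discovered   -> {Enrollment, Course, Term, Course.terms.through}
#   # natural_keys -> {Course.terms.through: ["course", "term"],
#   #                  auth.User: ["username"]}
#
# ``auth.User`` stays out of ``discovered`` via IGNORE_MODELS but still gets
# a natural key so references to it can be dereferenced.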


def sort_models(
    discovered: Set[Type[models.Model]],
    *,
    ignore: Optional[Set[Type[models.Model]]] = None,
) -> Iterable[Set[Type[models.Model]]]:
    """Sort the discovered models topologically for dependent relations."""
    ignore_models: Set[Type[models.Model]] = (ignore or set()) | IGNORE_MODELS
    return toposort(
        {
            model: {
                field.related_model
                for field in model._meta.get_fields()
                if field.concrete
                and field.is_relation
                and field.related_model  # ignore GFKs
                and not field.many_to_many  # ignore M2Ms
                and field.related_model is not model  # ignore self-references
                and field.related_model not in ignore_models
            }
            for model in discovered
        }
    )
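

# For example, with hypothetical foreign keys Enrollment -> Course -> Account,
# and assuming ``toposort`` yields dependency-free batches first:
#
#   list(sort_models({Enrollment, Course, Account}))
#   # -> [{Account}, {Course}, {Enrollment}]
#
# Models within a batch have no dependencies on each other, so their transfer
# tasks can safely run in parallel inside a single chord.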


def ident(Model: Type[models.Model]) -> str:
    """Find the appropriate identity string for the given model."""
    return Model._meta.db_table


def has_auto_now(Model: Type[models.Model]) -> bool:
    """Determine if this model has an auto_now datetime field."""
    return any(
        isinstance(field, models.DateTimeField) and field.auto_now
        for field in Model._meta.get_fields()
    )


def transfer_signature(transfer: Transfer):
    """Create a signature to run the transfer."""
    models, natural_keys = discover_models(transfer.target_models())
    models_map = {Model: ident(Model) for Model in models}
    auto_now_models = set(filter(has_auto_now, models))
    timeless_models = models - auto_now_models
    signature = transfer_setup.si(transfer.id)

    for model_group in chain(
        sort_models(auto_now_models),
        sort_models(timeless_models, ignore=auto_now_models),
    ):
        signature |= chord(
            [
                transfer_model.si(transfer.id, models_map[Model])
                for Model in model_group
            ],
            transfer_s3_files.si(transfer.id),
        )
    signature |= transfer_delete.si(transfer.id)
    return signature
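

# The resulting Celery canvas is roughly (a sketch; the id 42 and the
# "app_account" ident are hypothetical):
#
#   transfer_setup.si(42)
#   | chord([transfer_model.si(42, "app_account"), ...],
#           transfer_s3_files.si(42))
#   | chord([...], transfer_s3_files.si(42))  # one chord per sorted batch
#   | transfer_delete.si(42)
#
# auto_now models are batched first; the timeless models are then sorted with
# their dependencies on the already-transferred auto_now models ignored.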


@shared_task
def transfer_s3_files(transfer_id):
    """Transfer all S3 files."""
    transfer = Transfer.objects.get(id=transfer_id)
    if transfer.shallow or not all(
        [
            settings.AWS_STORAGE_BUCKET_NAME,
            settings.AWS_ACCESS_KEY_ID,
            settings.AWS_SECRET_ACCESS_KEY,
            settings.AWS_TRANSFER_ASSUME_ROLE_ARN,
        ]
    ):
        logger.info(f"event={'transfer_s3_files_skipped'!r} transfer={transfer!r}")
        return

    logger.info(f"event={'transfer_s3_files_start'!r} transfer={transfer!r}")

    token = transfer.source.backend.token

    endpoint_url = urljoin(
        transfer.source.backend.url,
        reverse("unity:endpoint-detail", kwargs={"name": transfer.source.name}),
    )
    source_endpoint = requests.get(
        endpoint_url, headers={"Authorization": f"Bearer {token}"}
    ).json()

    media_url = urljoin(transfer.source.backend.url, reverse("unity:media"))
    source_media = requests.get(
        media_url, headers={"Authorization": f"Bearer {token}"}
    ).json()

    s3_transfer = s3.S3(
        dest_bucket=settings.AWS_STORAGE_BUCKET_NAME,
        dest_uuid=str(transfer.destination.uuid),
        source_bucket=source_media["bucket"],
        source_path=source_media["path"],
        source_uuid=source_endpoint["uuid"],
    )
    s3_transfer.transfer(
        settings.AWS_ACCESS_KEY_ID,
        settings.AWS_SECRET_ACCESS_KEY,
        settings.AWS_TRANSFER_ASSUME_ROLE_ARN,
    )

    logger.info(f"event={'transfer_s3_files_stop'!r} transfer={transfer!r}")


@shared_task
def transfer_setup(transfer_id):
    """Prepare the transfer to start processing."""
    transfer = Transfer.objects.get(id=transfer_id)
    logger.info(f"event={'transfer_setup_start'!r} transfer={transfer!r}")
    url = urljoin(transfer.source.backend.url, reverse("unity:db"))
    token = transfer.source.backend.token
    db = requests.get(url, headers={"Authorization": f"Bearer {token}"}).json()
    # Won't work if sslrootcert is needed
    transfer.create_fdw(db)  # Dropped automatically when transfer is deleted
    logger.info(f"event={'transfer_setup_stop'!r} transfer={transfer!r}")


@shared_task(bind=True)
def transfer_model(self, transfer_id, model_ident):
    # Things that still need to be added for non-shallow transfers:
    # * LMS-specific models
    # * LTI models
    # * Recursive models
    # * S3 transfers
    # * M2M transfers
    # * Autodiscover models
    transfer = Transfer.objects.get(id=transfer_id)
    models, natural_keys = discover_models(transfer.target_models())
    models_map = {ident(Model): Model for Model in models}
    Model = models_map[model_ident]
    checkpoint = transfer.checkpoints.filter(
        content_type=ContentType.objects.get_for_model(Model)
    ).first()
    if not checkpoint:
        logger.info(f"event={'transfer_start'!r} model={Model!r}")
    else:
        logger.info(
            f"event={'transfer_resume'!r} model={Model!r} "
            f"timestamp={checkpoint.timestamp!r} "
            f"object_id={checkpoint.object_id!r}"
        )

    transferrer = ModelTransferrer(transfer, natural_keys, Model)
    complete, pk, timestamp = transferrer.transfer_chunk(checkpoint)

    transfer.checkpoints.update_or_create(
        content_type=ContentType.objects.get_for_model(Model),
        defaults={"object_id": pk, "timestamp": timestamp},
    )
    if not complete:
        logger.info(
            f"event={'transfer_pause'!r} model={Model!r} "
            f"pk={pk!r} timestamp={timestamp!r}"
        )
        self.replace(transfer_model.si(transfer_id, model_ident))
    logger.info(
        f"event={'transfer_stop'!r} model={Model!r} "
        f"pk={pk!r} timestamp={timestamp!r}"
    )
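

# Chunked-resume sketch: each transfer_model invocation upserts at most one
# chunk (1000 rows in natural-key order), records the last pk/timestamp in a
# Checkpoint, and, if the model is not finished, replaces itself with a fresh
# transfer_model task, so a worker crash resumes from the last checkpoint
# rather than restarting the model from scratch.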


@shared_task
def transfer_delete(transfer_id):
    """Delete the completed transfer."""
    transfer = Transfer.objects.get(id=transfer_id)
    transfer_repr = f"{transfer!r}"  # Pre-format to print after delete
    logger.info(f"event={'transfer_delete_start'!r} transfer={transfer_repr}")
    transfer.delete()
    logger.info(f"event={'transfer_delete_stop'!r} transfer={transfer_repr}")


class Fields:
    """A container for filtering fields as needed to transfer.

    Split the fields into these categories:

    1. ``endpoint`` - References to the ``Endpoint`` model.
    2. ``reference`` - References to other models.
    3. ``basic`` - Basic copyable values.
    4. ``uuid_pk`` - UUID primary keys.
    """

    def __init__(
        self, Model, natural_keys, *, natural_key: bool = False, recursive: bool = True
    ):
        self.all = [
            field
            for field in Model._meta.get_fields()
            if field.concrete
            and not field.primary_key
            and (
                not natural_key
                or field.name in (natural_keys.get(Model) or Model.natural_key.paths)
            )
        ]
        self.related = [
            field
            for field in self.all
            if field.is_relation
            and field.related_model  # ignore GFKs
            and not field.many_to_many  # ignore M2Ms
        ]

        # Fields that point to the endpoint, which is special-cased to allow
        # simply overriding the endpoint_id instead of looking it up.
        self.endpoint = [
            field for field in self.related if field.related_model == Endpoint
        ]

        # Fields that reference another model, and will need to have
        # their identifiers dereferenced into the natural key to
        # look up the appropriate related instance in the destination.
        exclude_to = [Model, Endpoint] if not recursive else [Endpoint]
        self.reference = [
            field for field in self.related if field.related_model not in exclude_to
        ]

        # Fields that directly represent a basic copyable value.
        self.basic = [field for field in self.all if not field.is_relation]
        self.uuid_pk = [
            field for field in [Model._meta.pk] if isinstance(field, models.UUIDField)
        ]
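

# A sketch with a hypothetical model:
#
#   class Course(models.Model):
#       uuid = models.UUIDField(primary_key=True, default=uuid4)
#       endpoint = models.ForeignKey(Endpoint, on_delete=models.CASCADE)
#       account = models.ForeignKey("Account", on_delete=models.CASCADE)
#       name = models.CharField(max_length=255)
#
#   fields = Fields(Course, {})
#   # fields.endpoint  -> [endpoint]
#   # fields.reference -> [account]
#   # fields.basic     -> [name]
#   # fields.uuid_pk   -> [uuid]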


class Dereference:
    """Match related ids in the destination by the natural key.

    Using the natural key to match the identity, convert references
    in the source to references in the destination.

    Based on the dereferenced reference fields, create:

    1. ``src_joins`` - Joins for the source schema.
    2. ``dst_joins`` - Joins for the destination schema.
    3. ``select`` - The field that should be added to the select
                    statement so that it can be inserted, because
                    it is part of the natural key.

    ``src_joins`` are separated from ``dst_joins`` so that the joins
    for the source (sent to the foreign server) come earlier than the
    joins for the local destination, in hopes that this makes it more
    likely that whatever filtering can be done on the foreign server
    is done there, saving time by avoiding large transfer volumes
    between the database servers.
    """

    def __init__(self, Model, on_column, src, dst, src_id, dst_id, natural_keys):
        src_model = aliased(src[Model])
        dst_model = aliased(dst[Model])
        fields = Fields(Model, natural_keys, natural_key=True)

        self.src_joins = [
            (
                src_model,
                and_(
                    on_column == getattr(src_model, Model._meta.pk.column),
                    # Add in the endpoint to the join. It makes outer joins easier.
                    *[
                        getattr(src_model, field.column) == src_id
                        for field in fields.endpoint
                    ],
                ),
            )
        ]
        self.dst_joins = []
        self.select = getattr(dst_model, Model._meta.pk.column)

        ref_cmps = []
        for field in fields.reference:
            deref = Dereference(
                field.related_model,
                getattr(src_model, field.column),
                src,
                dst,
                src_id,
                dst_id,
                natural_keys,
            )
            self.src_joins += deref.src_joins
            self.dst_joins += deref.dst_joins
            ref_cmps.append(getattr(dst_model, field.column) == deref.select)

        self.dst_joins += [
            (
                dst_model,
                and_(
                    *[
                        getattr(dst_model, field.column) == dst_id
                        for field in fields.endpoint
                    ],
                    *[
                        getattr(dst_model, field.column)
                        == getattr(src_model, field.column)
                        for field in fields.basic
                    ],
                    *ref_cmps,
                ),
            )
        ]
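

# One level of dereferencing yields (target, onclause) pairs that compile to
# joins roughly like this sketch (hypothetical ``account`` reference whose
# natural key is ``name``):
#
#   JOIN <transfer>.app_account AS src_acct
#     ON chunk.account_id = src_acct.id AND src_acct.endpoint_id = :src_id
#   JOIN public.app_account AS dst_acct
#     ON dst_acct.endpoint_id = :dst_id AND dst_acct.name = src_acct.name
#
# after which dst_acct.id stands in for the source account_id in the insert.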


class ModelTransferrer:
    """Manage the transfer of a model."""

    def __init__(
        self,
        transfer: Transfer,
        natural_keys: Dict[Type[models.Model], List[str]],
        Model: Type[models.Model],
    ):
        """Create a new transferrer.

        :param transfer: The transfer instance.
        :param natural_keys: Additional natural keys that are not specified
                             directly on models.
        :param Model: The model to transfer.
        """
        self.transfer = transfer
        self.natural_keys = natural_keys
        self.Model = Model

        self.src = construct_models(MetaData(schema=transfer.name))
        self.dst = construct_models(MetaData(schema="public"))
        self.src_id = transfer.source.remote_id
        self.dst_id = transfer.destination.id

    def transfer_chunk(self, checkpoint: Optional[Checkpoint]):
        if checkpoint and checkpoint.content_type.model_class() != self.Model:
            raise ValueError("checkpoint is not for the correct model.")

        with connection.cursor() as cursor:
            if self.Model in RECURSIVE_MODELS:
                # This is a quick way to work this out for Account.
                # If we were to add a database constraint that there
                # could only be one account with a null parent
                # (as I intend to do eventually), we will have to
                # re-work this to transfer in a loop until completion.
                # I'm not sure exactly how we'd signal that we'd
                # gotten all of them and can stop, though.
                #
                # Unset the limit so that this sparse run transfers all.
                stmt = self.transfer_model_stmt(recursive=False, limit=None)
                compiled = stmt.compile(dialect=postgresql.dialect())
                cursor.execute(compiled.string, compiled.params)
            stmt = self.transfer_model_stmt(
                limit=1000,
                pk=checkpoint and checkpoint.object_id,
                timestamp=checkpoint and checkpoint.timestamp,
            )
            compiled = stmt.compile(dialect=postgresql.dialect())
            logger.debug(compiled.string)
            cursor.execute(compiled.string, compiled.params)
            error, complete, inserted, pk, timestamp = cursor.fetchone()

        if error:
            raise Exception("'after' row was deleted unexpectedly.")

        return complete, pk, timestamp

    def transfer_model_stmt(
        self,
        *,
        recursive=True,
        limit: Optional[int] = 1000,  # can be None for the recursive hack
        pk: Optional[str] = None,  # has a default for the recursive hack
        timestamp: Optional[datetime] = None,  # has a default for the recursive hack
    ):
        """Build the statement that transfers a chunk of the given model.

        Produces an SQL statement that inserts values from the source into
        the destination, and returns a single row with five values: a
        nullable integer indicating whether an error occurred, a boolean
        indicating whether this statement completed the transfer for this
        model, a boolean indicating whether any rows were inserted, and the
        primary key and datetime of the last row transferred. The datetime
        may be NULL if this model does not have an auto_now field.

        :param bool recursive: Whether to transfer relations that target
                               this same model. When transferring a model
                               with a link like this, you should first
                               run with this as ``False``, then as ``True``
                               to update the links after all the instances
                               have been created.
        :param int limit: (optional) The maximum number of rows to transfer
                          in this iteration.
        :param pk: (optional) The primary key of the last instance on the
                   source database, as returned from the execution of the
                   statement produced by the previous run of this function.
        :param timestamp: (optional) The timestamp of the last instance from
                          the previous execution of the statement produced
                          by the previous run of this function. The timestamp
                          on the row may have changed since the execution of
                          the previous statement.
        """

        # The structure of the query is something like this:
        """
        WITH
        after AS (
            SELECT <natkey_fields> FROM <table>
            WHERE <pk> = <object_id>
        ),
        chunk AS (
            SELECT
                ROW_NUMBER() OVER () AS x_row_number,  -- avoid name clashes
                <pk>,
                <timestamp>,
                <fields_and_relation_natkeys>
            FROM <table>
            WHERE
                <table.endpoint_id> = <src_id>
                -- NOTE: Only included if ``after`` was given in the call args
                AND ROW(<<table.field> for <field> in <natkey_fields>>)
                    > (
                        SELECT <<after.field> for <field> in <natkey_fields>>
                        FROM after
                    )
                -- TODO: Figure out where to put the timestamp in the ROWs
            ORDER BY <natkey_fields>
            LIMIT <limit>
        ),
        inserted AS (
            --- ######## START ADAPTATION OF THE OLD QUERY ####################
            INSERT INTO <dst.table> (<>)
            ON CONFLICT DO UPDATE
            <...>
            RETURNING 1  -- NOTE: Unused placeholder constant. Is it even needed?
            --- ######### STOP ADAPTATION OF THE OLD QUERY ####################
        )
        SELECT
            CASE WHEN NOT EXISTS(SELECT * FROM after) THEN 1 ELSE 0 END AS error,
            x_row_number < <limit> AS completed,
            <pk>,
            <timestamp>  -- may be NULL if there are no auto_now fields
        FROM chunk
        WHERE x_row_number = (SELECT MAX(x_row_number) FROM chunk)  -- last row
        """

        fields = Fields(self.Model, self.natural_keys, recursive=recursive)
        natkey_fields = Fields(
            self.Model, self.natural_keys, recursive=recursive, natural_key=True
        )

        after = (
            self.after_query(pk=pk, natkey_fields=natkey_fields).cte() if pk else None
        )
        rownum_label = "rownum_" + "".join(
            secrets.choice(ascii_letters) for i in range(10)
        )

        chunk = self.chunk_query(
            after=after,
            rownum_label=rownum_label,
            fields=fields,
            natkey_fields=natkey_fields,
            limit=limit,
            timestamp=timestamp,
        ).cte()

        inserted = self.inserted_query(chunk=chunk, fields=fields).cte()

        # The 'after' row vanishing means the checkpointed row was deleted.
        after_error = (
            case([(~exists(select(["*"]).select_from(after)), 1)], else_=0)
            if after
            else literal(0)
        )
        return select(
            [
                after_error.label("error"),
                (getattr(chunk.c, rownum_label) < literal(limit)).label("completed"),
                exists(select(["*"]).select_from(inserted)).label("inserted"),
                getattr(chunk.c, self.Model._meta.pk.column).label("pk"),
                literal(None).label("timestamp"),
            ]
        ).where(
            (
                getattr(chunk.c, rownum_label)
                == select([func.max(getattr(chunk.c, rownum_label))])
            )
        )

    def after_query(self, pk: str, natkey_fields: Fields):
        """Generate the query that gets the after row."""
        return (
            self.src[self.Model]
            .query(
                *[
                    getattr(self.src[self.Model], field.column)
                    for field in natkey_fields.all
                ]
            )
            .filter(
                getattr(self.src[self.Model], self.Model._meta.pk.column)
                == literal(self.Model._meta.pk.to_python(pk))
            )
        )

    def chunk_query(
        self,
        *,
        after,
        rownum_label: str,
        fields: Fields,
        natkey_fields: Fields,
        limit: Optional[int],  # can be None for the recursive hack
        timestamp: Optional[datetime] = None,
    ):
        # NOTE: The rows are filtered by endpoint only if the endpoint is
        # directly on the current model. Otherwise, that filtering cannot be
        # applied until after the chunk query, when more than just the
        # current table is available.
        #
        # The current-table limitation is not inherent, but it greatly
        # simplifies the implementation. Time may prove that we cannot live
        # with the limitation, since this also affects models like ManyToMany
        # through models, which may prove difficult to impossible to add a
        # direct endpoint field to.
        return (
            self.src[self.Model]
            .query(
                func.row_number().over().label(rownum_label),
                getattr(self.src[self.Model], self.Model._meta.pk.column),
                *[getattr(self.src[self.Model], field.column) for field in fields.all],
            )
            .filter(
                *[
                    getattr(self.src[self.Model], field.column) == self.src_id
                    for field in fields.endpoint
                ],
                *(
                    [
                        # TODO: Add the timestamp to this filter if given
                        func.row(
                            *[
                                getattr(self.src[self.Model], field.column)
                                for field in natkey_fields.all
                            ]
                        )
                        > select(
                            [
                                getattr(after.c, field.column)
                                for field in natkey_fields.all
                            ]
                        ).subquery()
                    ]
                    if after
                    else []
                ),
            )
            .order_by(
                *[
                    getattr(self.src[self.Model], field.column)
                    for field in natkey_fields.all
                ]
            )
            .limit(limit)
        )

    def inserted_query(self, *, chunk, fields: Fields):
        """Generate the query that will insert into the database."""
        src_inner_joins = []
        src_outer_joins = []
        dst_inner_joins = []
        dst_outer_joins = []
        ref_selects = []
        ref_filters = []

        # Recursively dereference the natural key for each reference field.
        # Find all the joins, selects, and filters that need to be applied
        # when dealing with the natural keys, recursively.
        for field in fields.reference:
            deref = Dereference(
                field.related_model,
                getattr(chunk.c, field.column),
                self.src,
                self.dst,
                self.src_id,
                self.dst_id,
                self.natural_keys,
            )
            if field.null:
                src_outer_joins += deref.src_joins
                dst_outer_joins += deref.dst_joins
                # Ignore references not in the destination
                # without ignoring null references
                ref_filters.append(
                    or_(
                        getattr(self.src[self.Model], field.column).is_(None),
                        deref.select.isnot(None),
                    )
                )
            else:
                src_inner_joins += deref.src_joins
                dst_inner_joins += deref.dst_joins
            ref_selects.append(deref.select)

        query = (
            self.src[self.Model]
            .query(
                # UUID PK fields, basic fields, endpoint fields, then reference fields
                *[func.gen_random_uuid() for field in fields.uuid_pk],
                *[
                    getattr(self.src[self.Model], field.column)
                    for field in fields.basic
                ],
                *[literal(self.dst_id) for field in fields.endpoint],
                *ref_selects,
            )
            .select_from(self.src[self.Model])
            .filter(*ref_filters)
        )

        # Put src_joins first to hopefully make it more likely for those
        # to be done on the remote server to minimize data transfer.
        for join in src_inner_joins:
            query = query.join(*join)
        for join in src_outer_joins:
            query = query.outerjoin(*join)
        for join in dst_inner_joins:
            query = query.join(*join)
        for join in dst_outer_joins:
            query = query.outerjoin(*join)

        # Fields must be in the same order as they are selected
        insert_fields = (
            fields.uuid_pk + fields.basic + fields.endpoint + fields.reference
        )
        insert = postgresql.insert(self.dst[self.Model]).from_select(
            [getattr(self.dst[self.Model], field.column) for field in insert_fields],
            query,
        )
        index_elements = [
            getattr(self.dst[self.Model], field.column)
            for field in insert_fields
            if field.name
            in (self.natural_keys.get(self.Model) or self.Model.natural_key.paths)
            # The natural key is the conflict index
        ]
        # Map each destination column to its EXCLUDED counterpart
        set_elements = {
            getattr(self.dst[self.Model], field.column): getattr(
                insert.excluded, field.column
            )
            for field in insert_fields
            # All the fields except those that are in the conflict index
            if field.name
            not in (self.natural_keys.get(self.Model) or self.Model.natural_key.paths)
            # and fields that are a UUID primary key
            and field not in fields.uuid_pk
        }
        if set_elements:
            stmt = insert.on_conflict_do_update(
                index_elements=index_elements,
                set_={
                    current.key: excluded for current, excluded in set_elements.items()
                },
                where=or_(
                    *[
                        current.is_distinct_from(excluded)
                        for current, excluded in set_elements.items()
                    ]
                ),
            )
        else:
            stmt = insert.on_conflict_do_nothing(index_elements=index_elements)
        return stmt.returning(literal(1))
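

# The statement above compiles to an upsert roughly like this sketch
# (hypothetical table and columns):
#
#   INSERT INTO public.app_course (name, endpoint_id, account_id)
#   SELECT src.name, :dst_id, dst_acct.id
#   FROM <transfer>.app_course AS src
#   JOIN ...
#   ON CONFLICT (endpoint_id, name) DO UPDATE
#   SET account_id = EXCLUDED.account_id
#   WHERE app_course.account_id IS DISTINCT FROM EXCLUDED.account_id
#   RETURNING 1;
#
# The IS DISTINCT FROM guard skips rows that would not change, avoiding
# needless row rewrites on re-runs.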