@@ -333,10 +333,10 @@ def in_shardings_jax(
333333 `jax.device_put`.
334334
335335 Example usage:
336- >>> from jax.experimental import export
336+ >>> from jax import export
337337 >>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
338338 >>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
339- ... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
339+ ... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
340340 ... )(np.arange(jax.device_count()))
341341 >>> exp.in_shardings_hlo
342342 ({devices=[8]<=[8]},)
@@ -347,7 +347,7 @@ def in_shardings_jax(
347347 # Put the args and kwargs on the appropriate devices
348348 >>> run_arg = jax.device_put(np.arange(jax.device_count()),
349349 ... exp.in_shardings_jax(run_mesh)[0])
350- >>> res = export .call(exp) (run_arg)
350+ >>> res = exp .call(run_arg)
351351 >>> res.addressable_shards
352352 [Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
353353 Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
@@ -372,19 +372,53 @@ def out_shardings_jax(
372372 for s in self .out_shardings_hlo )
373373
374374 def has_vjp (self ) -> bool :
375+ """Returns if this Exported supports VJP."""
375376 return self ._get_vjp is not None
376377
  def vjp(self) -> Exported:
    """Gets the exported VJP function.

    A VJP may be unavailable, e.g., if the Exported has been loaded from an
    external format without a VJP.

    Returns:
      the `Exported` for the VJP of this function.

    Raises:
      ValueError: if no VJP is available.
    """
    if self._get_vjp is None:
      raise ValueError("No VJP is available")
    return self._get_vjp(self)
385387
388+ def serialize (self ,
389+ vjp_order : int = 0 ) -> bytearray :
390+ """Serializes an Exported.
391+
392+ Args:
393+ vjp_order: The maximum vjp order to include. E.g., the value 2 means that we
394+ serialize the primal functions and two orders of the `vjp` function. This
395+ should allow 2nd order reverse mode differentiation of the deserialized
396+ function. i.e., `jax.grad(jax.grad(f)).`
397+ """
398+ # Lazy load the serialization module, since flatbuffers is an optional
399+ # dependency.
400+ from jax ._src .export .serialization import serialize
401+ return serialize (self , vjp_order = vjp_order )
402+
403+ def call (self , * args , ** kwargs ):
404+ return call_exported (self )(* args , ** kwargs )
405+
406+
def deserialize(blob: bytearray) -> Exported:
  """Deserializes an Exported.

  Args:
    blob: a bytearray obtained from `Exported.serialize`.

  Returns:
    the rehydrated `Exported` object.
  """
  # Import lazily: the serialization module needs flatbuffers, which is an
  # optional dependency.
  from jax._src.export.serialization import deserialize as _deserialize
  return _deserialize(blob)
417+
386418
def default_lowering_platform() -> str:
  """Retrieves the default lowering platform for the exporting machine."""
  backend = jax.default_backend()
  # Canonicalize so that 'gpu' becomes the concrete 'cuda' or 'rocm'.
  return xb.canonicalize_platform(backend)
390424
@@ -411,14 +445,20 @@ def args_specs(
411445 return shape_poly .symbolic_args_specs (args , polymorphic_shapes )
412446
413447
414- def export (fun_jax : Callable ,
415- * ,
416- lowering_platforms : Sequence [str ] | None = None ,
417- disabled_checks : Sequence [DisabledSafetyCheck ] = (),
418- _device_assignment_for_internal_jax2tf_use_only = None ,
419- ) -> Callable [..., Exported ]:
448+ # TODO(necula): remove this once we remove jax.experimental.export.
449+ def export_back_compat (
450+ fun_jax : Callable ,
451+ * ,
452+ lowering_platforms : Sequence [str ] | None = None ,
453+ disabled_checks : Sequence [DisabledSafetyCheck ] = (),
454+ _device_assignment_for_internal_jax2tf_use_only = None ,
455+ ) -> Callable [..., Exported ]:
420456 """Exports native serialization for a JAX function.
421457
458+ Note: this function exists only for internal usage by jax2tf and for
459+ backwards compatibility with jax.experimental.export. Use
460+ `jax.export` instead.
461+
422462 Args:
423463 fun_jax: the function to lower and serialize.
424464 lowering_platforms:
@@ -498,6 +538,85 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
498538 _device_assignment_for_internal_jax2tf_use_only = _device_assignment_for_internal_jax2tf_use_only )
499539 return do_export
500540
def export(
    fun_jit: stages.Wrapped,
    *,
    lowering_platforms: Sequence[str] | None = None,
    disabled_checks: Sequence[DisabledSafetyCheck] = (),
    ) -> Callable[..., Exported]:
  """Exports a JAX function for persistent serialization.

  Args:
    fun_jit: the function to export. Should be the result of `jit`.
    lowering_platforms:
      Optional sequence containing a subset of 'tpu', 'cpu',
      'cuda', 'rocm'. If more than one platform is specified, then
      the lowered code takes an argument specifying the platform.
      If None, then use the default JAX backend.
      The calling convention for multiple platforms is explained in the
      `jax_export.Exported` docstring.
    disabled_checks: the safety checks to disable. See docstring
      of `DisabledSafetyCheck`.

  Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct,
    or values with `.shape` and `.dtype` attributes, and returns an
    `Exported`.

  Usage:
    >>> from jax import export
    >>> exported: export.Exported = export.export(jnp.sin)(
    ...     np.arange(4, dtype=np.float32))

    # You can inspect the Exported object
    >>> exported.in_avals
    (ShapedArray(float32[4]),)
    >>> blob: bytearray = exported.serialize()

    # The serialized bytes are safe to use in a separate process
    >>> rehydrated: export.Exported = export.deserialize(blob)
    >>> rehydrated.fun_name
    'sin'
    >>> rehydrated.call(np.array([.1, .2, .3, .4], dtype=np.float32))
    Array([0.09983342, 0.19866933, 0.29552022, 0.38941833], dtype=float32)
  """
  if not isinstance(fun_jit, stages.Wrapped):
    raise ValueError(
        f"Function to be exported must be the result of `jit` but is: {fun_jit}")
  if lowering_platforms is not None:
    actual_lowering_platforms = tuple(lowering_platforms)
  else:
    actual_lowering_platforms = (default_lowering_platform(),)

  def do_export(*args_specs, **kwargs_specs) -> Exported:
    # TODO: move to `lower`
    # Verify that all symbolic dimensions appearing in the argument specs
    # come from a single SymbolicScope; remember the first scope seen and
    # the arg path where it was found, for error messages.
    symbolic_scope: tuple[shape_poly.SymbolicScope, tree_util.KeyPath] | None = None  # type: ignore[invalid-annotation,unused-ignore]
    for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]:
      # Static args may have no `shape` attribute.
      if not hasattr(aval, "shape"):
        continue
      for d in aval.shape:
        if shape_poly.is_symbolic_dim(d):
          if symbolic_scope is None:
            symbolic_scope = (d.scope, k_path)
            continue
          symbolic_scope[0]._check_same_scope(
              d, when=f"when exporting {util.fun_name(fun_jit)}",
              self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ",
              other_descr=shape_poly.args_kwargs_path_to_str(k_path))

    # Trace with lowering parameters that request code for all of the
    # selected platforms and mark the lowering as being for export.
    traced = fun_jit.trace(  # type: ignore
        *args_specs, **kwargs_specs,
        _experimental_lowering_parameters=mlir.LoweringParameters(
            platforms=actual_lowering_platforms,
            for_export=True,
        ))
    jaxpr, fun_name = traced.jaxpr, traced.fun_name
    lowered = traced.lower()
    return _export_lowered(
        lowered, jaxpr, fun_name,
        disabled_checks=disabled_checks)
  return do_export
619+
501620def _export_lowered (
502621 lowered : stages .Lowered ,
503622 jaxpr : core .ClosedJaxpr , fun_name : str ,
@@ -599,7 +718,7 @@ def _get_exported_vjp(exp_primal: Exported) -> Exported:
599718 device_assignment = device_assignment ,
600719 apply_jit = True ,
601720 flat_primal_fun = True )
602- return export (fun_vjp_jax ,
721+ return export (fun_vjp_jax , # type: ignore[arg-type]
603722 lowering_platforms = exp_primal .lowering_platforms ,
604723 disabled_checks = exp_primal .disabled_safety_checks )(* vjp_in_avals )
605724
@@ -816,7 +935,7 @@ def is_token(typ, attrs):
816935
817936def _check_lowering (lowering ) -> None :
818937 if not isinstance (lowering , pxla .MeshComputation ):
819- raise NotImplementedError (f"serialization is supported only for pjit . { lowering } " )
938+ raise NotImplementedError (f"serialization is supported only for jit . { lowering } " )
820939
821940 if lowering .compile_args ["host_callbacks" ] or lowering .compile_args ["keepalive" ]:
822941 raise NotImplementedError ("serialization of host_callbacks is not yet implemented" )
0 commit comments