@@ -267,7 +267,8 @@ def example_entrypoint() -> ExampleExtension:
267267 # Test 1: Simple CPU tensor
268268 import torch
269269
270- cpu_tensor = torch.randn(3, 4)
270+ with torch.inference_mode():
271+     cpu_tensor = torch.randn(3, 4)
271272
272273 # Call extension method
273274 result_tensor = await extension.do_stuff({"operation": "process_tensor", "tensor": cpu_tensor})
@@ -284,7 +285,8 @@ def example_entrypoint() -> ExampleExtension:
284285 assert tensor_info["is_cuda"] is False
285286
286287 # Test 2: Multiple tensors
287- tensors = [torch.ones(2, 2), torch.zeros(2, 2), torch.eye(2)]
288+ with torch.inference_mode():
289+     tensors = [torch.ones(2, 2), torch.zeros(2, 2), torch.eye(2)]
288290 stacked_result = await extension.do_stuff(
289291     {"operation": "test_multiple_tensors", "tensors": tensors}
290292 )
@@ -415,7 +417,8 @@ def example_entrypoint() -> ExampleExtension:
415417 import torch
416418
417419 # Test 1: Basic tensor processing
418- input_tensor = torch.randn(4, 5)
420+ with torch.inference_mode():
421+     input_tensor = torch.randn(4, 5)
419422 normalized = await extension.do_stuff(
420423     {"operation": "process_tensor_isolated", "tensor": input_tensor}
421424 )
@@ -428,11 +431,12 @@ def example_entrypoint() -> ExampleExtension:
428431 assert abs(norm_info["output_std"] - 1.0) < 1e-6  # Should be close to 1
429432
430433 # Test 2: Different dtypes
431- tensors_dict = {
432-     "float32": torch.randn(2, 3),
433-     "int64": torch.randint(0, 10, (2, 3)),
434-     "bool": torch.tensor([[True, False], [False, True]]),
435- }
434+ with torch.inference_mode():
435+     tensors_dict = {
436+         "float32": torch.randn(2, 3),
437+         "int64": torch.randint(0, 10, (2, 3)),
438+         "bool": torch.tensor([[True, False], [False, True]]),
439+     }
436440
437441 dtype_results = await extension.do_stuff(
438442     {"operation": "test_different_dtypes", "tensors_dict": tensors_dict}
@@ -540,7 +544,8 @@ def example_entrypoint() -> ExampleExtension:
540544 import torch
541545
542546 # Test 1: GPU tensor operations
543- gpu_tensor = torch.randn(5, 5).cuda()
547+ with torch.inference_mode():
548+     gpu_tensor = torch.randn(5, 5).cuda()
544549 gpu_result = await extension.do_stuff({"operation": "process_gpu_tensor", "tensor": gpu_tensor})
545550
546551 assert isinstance(gpu_result, torch.Tensor)
@@ -552,7 +557,8 @@ def example_entrypoint() -> ExampleExtension:
552557 assert "cuda" in gpu_info["device"]
553558
554559 # Test 2: CPU to GPU transfer
555- cpu_tensor = torch.ones(3, 3)
560+ with torch.inference_mode():
561+     cpu_tensor = torch.ones(3, 3)
556562 transferred_result = await extension.do_stuff(
557563     {"operation": "transfer_between_devices", "tensor": cpu_tensor}
558564 )
@@ -637,7 +643,8 @@ def example_entrypoint() -> ExampleExtension:
637643 import torch
638644
639645 # Test GPU operations
640- gpu_tensor = torch.randn(4, 4).cuda()
646+ with torch.inference_mode():
647+     gpu_tensor = torch.randn(4, 4).cuda()
641648 squared_result = await extension.do_stuff(
642649     {"operation": "process_gpu_operations", "tensor": gpu_tensor}
643650 )
0 commit comments