diff --git a/segment_anything/build_sam.py b/segment_anything/build_sam.py index b280cf4..1edb21b 100644 --- a/segment_anything/build_sam.py +++ b/segment_anything/build_sam.py @@ -103,7 +103,8 @@ def _build_sam( # sam.eval() if checkpoint is not None: with open(checkpoint, "rb") as f: - state_dict = torch.load(f) + device = "cuda" if torch.cuda.is_available() else "cpu" + state_dict = torch.load(f, map_location=device) info = sam.load_state_dict(state_dict, strict=False) print(info) for n, p in sam.named_parameters():