# Source code for pypots.nn.functional.cuda

"""

"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


import torch


# Wrap autocast so the same call works with both torch >= 2.4 and torch < 2.4
def _torch_major_minor(version=None):
    """Return the ``(major, minor)`` release numbers of a torch version string.

    Parameters
    ----------
    version : str, optional
        Version string to parse, e.g. ``"2.4.1+cu121"``. Defaults to ``torch.__version__``.

    Returns
    -------
    tuple
        ``(major, minor)`` as integers, suitable for tuple comparison.

    Raises
    ------
    ValueError
        If ``version`` does not start with ``<major>.<minor>``.
    """
    import re

    if version is None:
        version = torch.__version__
    # Only the leading "<major>.<minor>" matters for choosing the API; local suffixes
    # ("+cu121") and patch/pre-release parts are deliberately ignored.
    match = re.match(r"\s*(\d+)\.(\d+)", str(version))
    if match is None:
        raise ValueError(f"Unrecognized torch version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def autocast(**kwargs):
    """Create a CUDA autocast context manager compatible with torch < 2.4 and >= 2.4.

    For torch < 2.4 this returns ``torch.cuda.amp.autocast(**kwargs)``; for torch >= 2.4
    it returns ``torch.amp.autocast("cuda", **kwargs)``, the replacement for the
    deprecated ``torch.cuda.amp`` entry point.

    Parameters
    ----------
    **kwargs :
        Keyword arguments forwarded unchanged to the underlying autocast constructor
        (for example ``enabled`` or ``dtype``).

    Returns
    -------
    An autocast context manager targeting the ``"cuda"`` device.
    """
    # Compare integer tuples instead of raw strings: lexicographically "2.10" < "2.4",
    # which would wrongly route newer torch releases to the legacy API.
    if _torch_major_minor() < (2, 4):
        from torch.cuda.amp import autocast as _cuda_autocast

        return _cuda_autocast(**kwargs)

    # Aliased so the import does not shadow this function's own name.
    from torch.amp import autocast as _amp_autocast

    return _amp_autocast("cuda", **kwargs)