# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ctypes
from ctypes import (
    CFUNCTYPE,
    POINTER,
    c_int,
    c_int64,
    c_size_t,
    c_uint8,
    c_uint16,
    c_void_p,
    pointer,
)

import torch


# Define data structures required for DLPack
class DLDataType(ctypes.Structure):
    """ctypes layout describing the element type of a DLPack tensor."""

    # code:  data type code, e.g. 2 for float
    # bits:  number of bits per element, e.g. 32
    # lanes: number of lanes, usually 1
    _fields_ = [("code", c_uint8), ("bits", c_uint8), ("lanes", c_uint16)]


class DLDevice(ctypes.Structure):
    """ctypes layout identifying the device on which a DLPack tensor lives."""

    # device_type: device type, typically 2 for GPU
    # device_id:   device ID, usually 0 for the default GPU
    _fields_ = [("device_type", c_int), ("device_id", c_int)]


class DLTensor(ctypes.Structure):
    """ctypes layout of a DLPack tensor descriptor (no ownership information)."""

    _fields_ = [
        # Data pointer
        ("data", c_void_p),
        # Device information
        ("device", DLDevice),
        # Number of dimensions
        ("ndim", c_int),
        # Data type
        ("dtype", DLDataType),
        # Pointer to array of dimension sizes
        ("shape", POINTER(c_int64)),
        # Pointer to strides array (can be NULL for default contiguous layout)
        ("strides", POINTER(c_int64)),
        # Byte offset (usually 0)
        ("byte_offset", c_size_t),
    ]


# Deleter type for DLManagedTensor
# Callback prototype returning nothing and taking a pointer to a void pointer.
# Not used directly here.
DLManagedTensorDeleter = CFUNCTYPE(None, POINTER(c_void_p))


# Define DLManagedTensor structure, with deleter prototype void(*deleter)(DLManagedTensor*)
class DLManagedTensor(ctypes.Structure):
    """ctypes layout of a DLPack managed tensor: a DLTensor plus a deleter callback."""


# The fields are attached after the class statement because the deleter's
# prototype, void (*deleter)(DLManagedTensor*), refers to the class itself.
DLManagedTensor._fields_ = [
    ("dl_tensor", DLTensor),
    ("manager_ctx", ctypes.c_void_p),
    ("deleter", ctypes.CFUNCTYPE(None, ctypes.POINTER(DLManagedTensor))),
]


# A no-op deleter that doesn't perform any operation
@CFUNCTYPE(None, POINTER(DLManagedTensor))
def no_op_deleter(dmt_ptr):
    """Deleter callback that deliberately does nothing with the managed tensor."""
    # You can also call cudaFree here if you want to free memory when the
    # tensor's lifecycle ends.
    return None


# Wrapper class to prevent Python garbage collection of DLPack-related objects
class CapsuleWrapper:
    """
    Bundle a DLPack PyCapsule with the ctypes objects backing it.

    Holding the shape array and the managed tensor on the same object as the
    capsule keeps them referenced, so they are not reclaimed by Python while
    this wrapper is alive.
    """

    def __init__(self, capsule, shape_array, managed_tensor):
        """
        Store the capsule and the objects it depends on.

        Parameters:
            capsule: The PyCapsule object that follows the DLPack protocol
            shape_array: The array containing tensor shape information
            managed_tensor: The DLManagedTensor instance that the capsule points to
        """
        # Public: this is the object handed to other libraries.
        self.capsule = capsule
        # Private: kept only so the references stay alive.
        self._shape_array, self._managed_tensor = shape_array, managed_tensor


def create_dlpack_capsule(
    ptr, segment_size, segment_stride, num_segments, torch_dtype, dev_id
):
    """
    Build a DLPack capsule describing a 2-D strided view over existing GPU memory.

    The described tensor has shape ``(num_segments, segment_size // itemsize)`` and
    element strides ``(segment_stride // itemsize, 1)``, where ``itemsize`` is the
    size in bytes of one ``torch_dtype`` element. A trailing partial element in
    ``segment_size`` is dropped by the floor division.

    The capsule is created without a destructor and the managed tensor uses the
    module-level ``no_op_deleter``, so nothing created here ever frees ``ptr``.

    Parameters:
      ptr: GPU memory address obtained from cudaMalloc (Python int)
      segment_size: Memory size of each segment in bytes
      segment_stride: Memory stride size between segments in bytes
      num_segments: Number of segments
      torch_dtype: torch dtype
      dev_id: device id.
    Returns:
      A CapsuleWrapper whose ``capsule`` attribute is a PyCapsule named "dltensor",
      which can be converted to a tensor using torch.utils.dlpack.from_dlpack. The
      wrapper also holds the ctypes objects the capsule points into and must be kept
      alive for as long as the capsule (or a tensor built from it) is in use.
    Raises:
      NotImplementedError: if ``torch_dtype`` is not one of the handled dtypes.
      ValueError: if ``segment_stride`` is not a whole number of elements, since
        DLPack strides are expressed in elements and cannot represent it.
    """
    # DLDataTypeCode values, see
    # https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h#L160
    # kDLInt = 0, kDLUInt = 1, kDLFloat = 2, kDLBfloat = 4.
    if torch_dtype == torch.bfloat16:
        # bfloat16 has its own type code; tagging it as kDLFloat with 16 bits
        # would make consumers interpret the memory as IEEE float16.
        bits_per_element = torch.finfo(torch_dtype).bits
        dldata_type_code = 4
    elif torch_dtype in [
        torch.float8_e5m2,
        torch.float8_e4m3fn,
        torch.float16,
        torch.float32,
        torch.float64,
    ]:
        # NOTE(review): recent DLPack revisions define dedicated float8 type
        # codes; confirm the consumer accepts kDLFloat with 8 bits.
        bits_per_element = torch.finfo(torch_dtype).bits
        dldata_type_code = 2
    elif torch_dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
        bits_per_element = torch.iinfo(torch_dtype).bits
        dldata_type_code = 0
    elif torch_dtype in [torch.uint8, torch.uint16, torch.uint32, torch.uint64]:
        bits_per_element = torch.iinfo(torch_dtype).bits
        dldata_type_code = 1
    else:
        raise NotImplementedError(torch_dtype)

    bytes_per_element = bits_per_element // 8
    if segment_stride % bytes_per_element != 0:
        raise ValueError(
            f"segment_stride ({segment_stride}) must be a multiple of the "
            f"element size ({bytes_per_element} bytes) of {torch_dtype}"
        )

    # Shape and strides of the 2-D tensor, both counted in elements.
    Int64Pair = c_int64 * 2
    shape_array = Int64Pair(num_segments, segment_size // bytes_per_element)
    stride_array = Int64Pair(segment_stride // bytes_per_element, 1)

    # Construct DLTensor on a CUDA device (device_type=2) with the given device id.
    dltensor = DLTensor()
    dltensor.data = c_void_p(ptr)
    dltensor.device = DLDevice(device_type=2, device_id=dev_id)
    dltensor.ndim = 2
    dltensor.dtype = DLDataType(
        code=dldata_type_code, bits=bits_per_element, lanes=1
    )
    dltensor.shape = ctypes.cast(shape_array, POINTER(c_int64))
    dltensor.strides = ctypes.cast(stride_array, POINTER(c_int64))
    dltensor.byte_offset = 0

    # Construct DLManagedTensor and set deleter to no-op (you can also call cudaFree here)
    managed_tensor = DLManagedTensor()
    managed_tensor.dl_tensor = dltensor
    managed_tensor.manager_ctx = None
    managed_tensor.deleter = no_op_deleter

    # Declaring the result as py_object lets ctypes take ownership of the new
    # reference returned by PyCapsule_New. Receiving it as c_void_p and casting
    # afterwards adds a second reference that is never released.
    PyCapsule_New = ctypes.pythonapi.PyCapsule_New
    PyCapsule_New.restype = ctypes.py_object
    PyCapsule_New.argtypes = [c_void_p, ctypes.c_char_p, c_void_p]
    # The capsule name must be "dltensor", as required by the DLPack specification.
    # The capsule stores the address of managed_tensor's ctypes buffer, so that
    # object has to outlive the capsule.
    capsule = PyCapsule_New(pointer(managed_tensor), b"dltensor", None)

    # Keep every buffer the capsule points into referenced from the wrapper.
    capsule_wrapper = CapsuleWrapper(capsule, shape_array, managed_tensor)
    capsule_wrapper._stride_array = stride_array
    return capsule_wrapper


def pack_strided_memory(
    ptr: int,
    segment_size: int,
    segment_stride: int,
    num_segments: int,
    dtype: torch.dtype,
    dev_id,
):
    """
    Expose strided GPU memory as a PyTorch tensor without copying.

    Parameters:
        ptr: GPU memory address obtained from cudaMalloc
        segment_size: Memory size of each segment in bytes
        segment_stride: Memory stride size between segments in bytes
        num_segments: Number of segments
        dtype: PyTorch data type for the resulting tensor
        dev_id: CUDA device ID

    Returns:
        PyTorch tensor that references the provided memory

    Note:
        A fresh DLPack capsule is built on every call, even for the same
        pointer, because each capsule is consumed only once.
    """
    wrapper = create_dlpack_capsule(
        ptr,
        segment_size,
        segment_stride,
        num_segments,
        torch_dtype=dtype,
        dev_id=dev_id,
    )
    tensor = torch.utils.dlpack.from_dlpack(wrapper.capsule)
    # Attach the wrapper to the tensor object so the capsule and the objects it
    # holds stay referenced for as long as this tensor object does.
    tensor._capsule_wrapper = wrapper
    return tensor
