from typing import List, Optional

import torch

try:
    import triton
    import triton.language as tl
except ImportError as err:
    raise ImportError("Triton is not installed") from err


@triton.jit
def apply_token_bitmask_inplace_kernel(
    logits_ptr,
    bitmask_ptr,
    indices_ptr,
    num_rows,
    vocab_size,
    logits_strides,
    bitmask_strides,
    NUM_SMS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Apply a bitmask to logits in-place using Triton. The bitmask is a 01 bitwise compressed tensor,
    where 0 means the token is masked and 1 means the token is not masked. After applying the bitmask,
    the masked logits will be set to -inf.

    Parameters
    ----------
    logits_ptr : tl.tensor
        Pointer to the logits tensor to apply the bitmask to.

    bitmask_ptr : tl.tensor
        Pointer to the bitmask tensor to apply.

    indices_ptr : Optional[tl.tensor]
        Optional pointer to indices tensor specifying which rows to apply the mask to.

    num_rows : int
        Number of rows to process. If indices_ptr is provided, this is the number of unique indices.

    vocab_size : int
        Size of the vocabulary dimension. If the logits does not have a vocab padding, this is the
        same as the logits's second dimension. Otherwise, this is the actual size of the vocabulary.

    logits_strides : int
        Stride between rows in the logits tensor.

    bitmask_strides : int
        Stride between rows in the bitmask tensor.

    NUM_SMS : int
        Number of streaming multiprocessors to use.

    BLOCK_SIZE : int
        Size of processing blocks.
    """

    pid = tl.program_id(0)
    # The vocab dimension of each row is split into ceil(vocab_size / BLOCK_SIZE) blocks.
    num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE)
    # Persistent-kernel pattern: the launcher uses grid == (NUM_SMS,), so each program
    # strides over the flattened (row, vocab-block) work items with step NUM_SMS,
    # covering all num_rows * num_blocks items regardless of problem size.
    for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS):
        row_id = work_id // num_blocks
        block_offset = (work_id % num_blocks) * BLOCK_SIZE
        # When indices_ptr is given, row_id is a compact index into it that maps to the
        # actual batch row; otherwise rows are processed in order.
        batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)
        offsets = block_offset + tl.arange(0, BLOCK_SIZE)
        # Each int32 word of the bitmask packs 32 token bits, so this block of
        # BLOCK_SIZE tokens corresponds to BLOCK_SIZE // 32 packed words.
        bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32)
        vocab_mask = offsets < vocab_size
        # NOTE(review): the row stride is used as the packed-row width bound here; this
        # assumes bitmask rows are unpadded (stride == ceil(vocab_size / 32)) or that any
        # padding words are safe to read — confirm against callers. Out-of-bounds lanes
        # load 0, which unpacks to "masked", but the vocab_mask in the store guards them.
        packed_bitmask_mask = bitmask_offsets < bitmask_strides
        packed_bitmask = tl.load(
            bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets, packed_bitmask_mask
        )
        # Unpack bit i of word w (covering token w * 32 + i); True marks a MASKED token,
        # i.e. a 0 bit per the convention in the docstring.
        bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
        bitmask = bitmask.reshape(BLOCK_SIZE)

        # Write -inf only to masked tokens inside the real vocabulary; any logits
        # padding past vocab_size is left untouched.
        tl.store(
            logits_ptr + batch_id * logits_strides + offsets, -float("inf"), vocab_mask & bitmask
        )


def apply_token_bitmask_inplace_triton(
    logits: torch.Tensor,
    bitmask: torch.Tensor,
    vocab_size: Optional[int] = None,
    indices: Optional[List[int]] = None,
) -> None:
    """Apply a packed 01 token bitmask to ``logits`` in place, setting masked tokens to -inf.

    Parameters
    ----------
    logits : torch.Tensor
        1D or 2D GPU tensor of logits; modified in place.

    bitmask : torch.Tensor
        int32 tensor on the same device whose last dimension packs 32 tokens per
        element; a 0 bit means the token is masked.

    vocab_size : Optional[int]
        Actual vocabulary size. Defaults to the largest size representable by both
        ``logits`` and ``bitmask``.

    indices : Optional[List[int]]
        Row indices of ``logits`` to apply the mask to. If None, all rows are processed.
    """
    assert bitmask.dtype == torch.int32, "bitmask must be of type int32"
    assert logits.is_cuda, "logits must be on a CUDA/HIP device"
    assert bitmask.device == logits.device, "bitmask must be on the same device as logits"

    # Query the device the logits actually live on; hard-coding "cuda" / 0 would
    # always read device 0's properties on multi-GPU systems.
    device_props = torch.cuda.get_device_properties(logits.device)
    NUM_SMS = device_props.multi_processor_count
    BLOCK_SIZE = 4096

    # AMD CDNA GPUs use 64-wide wavefronts; Navi ("gfx1*") and NVIDIA use 32.
    # gcnArchName is only touched on ROCm builds (short-circuit), where it exists.
    if torch.version.hip is not None and "gfx1" not in device_props.gcnArchName:
        WARP_SIZE = 64
    else:
        WARP_SIZE = 32

    # The usable vocab is bounded by both the logits width and the number of bits
    # the bitmask can represent (32 tokens per int32 word).
    detected_vocab_size = min(logits.shape[-1], bitmask.shape[-1] * 32)
    if vocab_size is None:
        vocab_size = detected_vocab_size
    else:
        assert (
            vocab_size <= detected_vocab_size
        ), f"vocab_size {vocab_size} is larger than the detected vocab_size {detected_vocab_size}"

    num_rows = len(indices) if indices is not None else logits.shape[0] if logits.ndim == 2 else 1

    if indices is not None:
        # Stage the indices in pinned host memory so the host->device copy can be
        # truly asynchronous; non_blocking on pageable memory silently degrades to
        # a synchronous copy.
        indices_cpu = torch.tensor(indices, dtype=torch.int32, pin_memory=True)
        indices = indices_cpu.to(device=logits.device, non_blocking=True)

    # One persistent program per SM; the kernel itself strides over all
    # (row, vocab-block) work items.
    grid = (NUM_SMS,)

    apply_token_bitmask_inplace_kernel[grid](
        logits,
        bitmask,
        indices,
        num_rows,
        vocab_size,
        logits.stride()[0],
        bitmask.stride()[0],
        NUM_SMS,
        BLOCK_SIZE,
        # Heuristic: scale warps with the block size, fewer for wider element types.
        num_warps=BLOCK_SIZE // WARP_SIZE // (16 // logits.element_size()),
        num_stages=3,
    )
