#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Data types and tags used for emitting CUTLASS C++ kernels
"""

import enum
import re

# The following block provides enum.auto() on Python versions that lack it (it was added in
# Python 3.6), such as the default Python 3.5.2 on Ubuntu 16.04.
#
# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility

try:
    # Prefer the standard-library enum.auto when the interpreter provides it.
    from enum import auto as enum_auto
except ImportError:
    # Fallback: one module-level counter shared by every enum declared in this file.
    __cutlass_library_auto_enum = 0

    def enum_auto() -> int:  # type: ignore[no-redef]
        """Return the current counter value, then advance the counter by one.

        The first call returns 0. NOTE(review): the counter is shared across all
        enums in this module, so member values on this path differ from those
        produced by the standard enum.auto; confirm no code depends on specific
        numeric member values.
        """
        global __cutlass_library_auto_enum
        i = __cutlass_library_auto_enum
        __cutlass_library_auto_enum += 1
        return i


###################################################################################################


#
class GeneratorTarget(enum.Enum):
    """Generator output target(s); only ``Library`` is defined."""

    Library = enum_auto()


# Lowercase string name for each GeneratorTarget.
GeneratorTargetNames = {GeneratorTarget.Library: "library"}
#

###################################################################################################


#
class DataType(enum.Enum):
    """Element data types that the generator can emit.

    Values are assigned by enum_auto() in declaration order, so inserting or
    reordering members changes the numeric values of the members that follow.
    Names, C++ tags and bit widths are kept in DataTypeNames, DataTypeTag and
    DataTypeSize.
    """

    void = enum_auto()  # primarily used to disable C tensor for epilogues
    b1 = enum_auto()  # 1-bit (cutlass::uint1b_t)
    # Unsigned integers; the suffix is the width in bits.
    u2 = enum_auto()
    u4 = enum_auto()
    u8 = enum_auto()
    u16 = enum_auto()
    u32 = enum_auto()
    u64 = enum_auto()
    # Signed integers; the suffix is the width in bits.
    s2 = enum_auto()
    s4 = enum_auto()
    s8 = enum_auto()
    s16 = enum_auto()
    s32 = enum_auto()
    s64 = enum_auto()
    # 8-bit floats (cutlass::float_e4m3_t / cutlass::float_e5m2_t).
    e4m3 = enum_auto()
    e5m2 = enum_auto()
    # Type-erased narrow floats (cutlass::type_erased_dynamic_float{8,6,4}_t).
    f8 = enum_auto()
    f6 = enum_auto()
    f4 = enum_auto()
    # 6-bit (e3m2, e2m3) and 4-bit (e2m1) floats.
    e3m2 = enum_auto()
    e2m3 = enum_auto()
    e2m1 = enum_auto()
    # 8-bit cutlass::float_ue8m0_t / cutlass::float_ue4m3_t.
    ue8m0 = enum_auto()
    ue4m3 = enum_auto()
    # Standard floating-point types.
    f16 = enum_auto()
    bf16 = enum_auto()
    f32 = enum_auto()
    tf32 = enum_auto()
    f64 = enum_auto()
    # Complex floating-point types (cutlass::complex<...>).
    cf16 = enum_auto()
    cbf16 = enum_auto()
    cf32 = enum_auto()
    ctf32 = enum_auto()
    cf64 = enum_auto()
    # Complex signed integers.
    cs2 = enum_auto()
    cs4 = enum_auto()
    cs8 = enum_auto()
    cs16 = enum_auto()
    cs32 = enum_auto()
    cs64 = enum_auto()
    # Complex unsigned integers.
    cu2 = enum_auto()
    cu4 = enum_auto()
    cu8 = enum_auto()
    cu16 = enum_auto()
    cu32 = enum_auto()
    cu64 = enum_auto()
    # Sentinel returned by get_complex_from_real / get_real_from_complex on no match.
    invalid = enum_auto()


#
# Compact type codes for a subset of DataType members: single letters for the
# classic types (s32 -> "i", f16 -> "h", f32 -> "s", f64 -> "d", cf32 -> "c",
# cf64 -> "z") and the full name for the narrow float formats. Types that are
# absent here have no short form.
ShortDataTypeNames = {
    DataType.s32: "i",
    DataType.e4m3: "e4m3",
    DataType.e5m2: "e5m2",
    DataType.f16: "h",
    DataType.f32: "s",
    DataType.f64: "d",
    DataType.cf32: "c",
    DataType.cf64: "z",
    DataType.f8: "f8",
    DataType.f6: "f6",
    DataType.f4: "f4",
}

#
DataTypeNames = {
    # String name of each DataType; every value equals the member name.
    # DataType.invalid has no entry.
    DataType.void: "void",
    DataType.b1: "b1",
    DataType.u2: "u2",
    DataType.u4: "u4",
    DataType.u8: "u8",
    DataType.u16: "u16",
    DataType.u32: "u32",
    DataType.u64: "u64",
    DataType.s2: "s2",
    DataType.s4: "s4",
    DataType.s8: "s8",
    DataType.s16: "s16",
    DataType.s32: "s32",
    DataType.s64: "s64",
    DataType.e4m3: "e4m3",
    DataType.e5m2: "e5m2",
    DataType.f8: "f8",
    DataType.f6: "f6",
    DataType.f4: "f4",
    DataType.e2m3: "e2m3",
    DataType.e3m2: "e3m2",
    DataType.e2m1: "e2m1",
    DataType.ue8m0: "ue8m0",
    DataType.ue4m3: "ue4m3",
    DataType.f16: "f16",
    DataType.bf16: "bf16",
    DataType.f32: "f32",
    DataType.tf32: "tf32",
    DataType.f64: "f64",
    DataType.cf16: "cf16",
    DataType.cbf16: "cbf16",
    DataType.cf32: "cf32",
    DataType.ctf32: "ctf32",
    DataType.cf64: "cf64",
    DataType.cu2: "cu2",
    DataType.cu4: "cu4",
    DataType.cu8: "cu8",
    DataType.cu16: "cu16",
    DataType.cu32: "cu32",
    DataType.cu64: "cu64",
    DataType.cs2: "cs2",
    DataType.cs4: "cs4",
    DataType.cs8: "cs8",
    DataType.cs16: "cs16",
    DataType.cs32: "cs32",
    DataType.cs64: "cs64",
}

DataTypeTag = {
    # C++ type spelled in emitted CUTLASS source for each DataType.
    DataType.void: "void",
    # Sub-byte integers use CUTLASS wrapper types; wider ones use <cstdint> names.
    DataType.b1: "cutlass::uint1b_t",
    DataType.u2: "cutlass::uint2b_t",
    DataType.u4: "cutlass::uint4b_t",
    DataType.u8: "uint8_t",
    DataType.u16: "uint16_t",
    DataType.u32: "uint32_t",
    DataType.u64: "uint64_t",
    DataType.s2: "cutlass::int2b_t",
    DataType.s4: "cutlass::int4b_t",
    DataType.s8: "int8_t",
    DataType.s16: "int16_t",
    DataType.s32: "int32_t",
    DataType.s64: "int64_t",
    DataType.e4m3: "cutlass::float_e4m3_t",
    DataType.e5m2: "cutlass::float_e5m2_t",
    # f8/f6/f4 map to type-erased wrappers rather than a concrete format.
    DataType.f8: "cutlass::type_erased_dynamic_float8_t",
    DataType.f6: "cutlass::type_erased_dynamic_float6_t",
    DataType.f4: "cutlass::type_erased_dynamic_float4_t",
    DataType.e2m3: "cutlass::float_e2m3_t",
    DataType.e3m2: "cutlass::float_e3m2_t",
    DataType.e2m1: "cutlass::float_e2m1_t",
    DataType.ue8m0: "cutlass::float_ue8m0_t",
    DataType.ue4m3: "cutlass::float_ue4m3_t",
    DataType.f16: "cutlass::half_t",
    DataType.bf16: "cutlass::bfloat16_t",
    DataType.f32: "float",
    DataType.tf32: "cutlass::tfloat32_t",
    DataType.f64: "double",
    # Complex types wrap their component type in cutlass::complex<>.
    DataType.cf16: "cutlass::complex<cutlass::half_t>",
    DataType.cbf16: "cutlass::complex<cutlass::bfloat16_t>",
    DataType.cf32: "cutlass::complex<float>",
    DataType.ctf32: "cutlass::complex<cutlass::tfloat32_t>",
    DataType.cf64: "cutlass::complex<double>",
    DataType.cu2: "cutlass::complex<cutlass::uint2b_t>",
    DataType.cu4: "cutlass::complex<cutlass::uint4b_t>",
    DataType.cu8: "cutlass::complex<cutlass::uint8_t>",
    DataType.cu16: "cutlass::complex<cutlass::uint16_t>",
    DataType.cu32: "cutlass::complex<cutlass::uint32_t>",
    DataType.cu64: "cutlass::complex<cutlass::uint64_t>",
    DataType.cs2: "cutlass::complex<cutlass::int2b_t>",
    DataType.cs4: "cutlass::complex<cutlass::int4b_t>",
    DataType.cs8: "cutlass::complex<cutlass::int8_t>",
    DataType.cs16: "cutlass::complex<cutlass::int16_t>",
    DataType.cs32: "cutlass::complex<cutlass::int32_t>",
    DataType.cs64: "cutlass::complex<cutlass::int64_t>",
}

# Storage width of each DataType, in bits. void occupies no storage, and every
# complex type is exactly twice the width of its real component type.
DataTypeSize = {
    DataType.void: 0,
    DataType.b1: 1,
    DataType.u2: 2,
    DataType.u4: 4,
    DataType.u8: 8,
    DataType.u16: 16,
    DataType.u32: 32,
    DataType.u64: 64,
    DataType.s2: 2,
    DataType.s4: 4,
    DataType.s8: 8,
    DataType.s16: 16,
    DataType.s32: 32,
    DataType.s64: 64,
    DataType.e4m3: 8,
    DataType.e5m2: 8,
    DataType.f8: 8,
    DataType.f6: 6,
    DataType.f4: 4,
    DataType.e2m3: 6,
    DataType.e3m2: 6,
    DataType.e2m1: 4,
    DataType.ue8m0: 8,
    DataType.ue4m3: 8,
    DataType.f16: 16,
    DataType.bf16: 16,
    DataType.f32: 32,
    DataType.tf32: 32,
    DataType.f64: 64,
    DataType.cf16: 32,
    DataType.cbf16: 32,
    DataType.cf32: 64,
    # complex<tfloat32_t> holds two 32-bit tf32 components, the same as cf32.
    # (Previously listed as 32, which also skewed get_tma_alignment for ctf32.)
    DataType.ctf32: 64,
    DataType.cf64: 128,
    DataType.cu2: 4,
    DataType.cu4: 8,
    DataType.cu8: 16,
    DataType.cu16: 32,
    DataType.cu32: 64,
    DataType.cu64: 128,
    DataType.cs2: 4,
    DataType.cs4: 8,
    DataType.cs8: 16,
    DataType.cs16: 32,
    DataType.cs32: 64,
    DataType.cs64: 128,
}


###################################################################################################
#
class BlasMode(enum.Enum):
    """Symmetric vs. Hermitian operand mode; BlasModeTag gives the C++ enumerator."""

    symmetric = enum_auto()
    hermitian = enum_auto()


# cutlass::BlasMode enumerator emitted for each BlasMode.
BlasModeTag = {
    BlasMode.symmetric: "cutlass::BlasMode::kSymmetric",
    BlasMode.hermitian: "cutlass::BlasMode::kHermitian",
}


#
class ComplexTransform(enum.Enum):
    """Transform applied to a complex operand: none, or complex conjugation."""

    none = enum_auto()
    conj = enum_auto()


# cutlass::ComplexTransform enumerator emitted for each ComplexTransform.
ComplexTransformTag = {
    ComplexTransform.none: "cutlass::ComplexTransform::kNone",
    ComplexTransform.conj: "cutlass::ComplexTransform::kConjugate",
}

# Used for cutlass3x complex kernel collective mainloop builder instantiation:
# the transform is spelled as a cute functor (identity or conjugate) there.
ComplexTransformTag3x = {
    ComplexTransform.none: "cute::identity",
    ComplexTransform.conj: "cute::conjugate",
}

#
RealComplexBijection = [
    # (real, complex) pairs consulted by is_complex, get_complex_from_real and
    # get_real_from_complex. Only f16, f32 and f64 are paired; other complex
    # DataType members (cbf16, ctf32, integer complex types) are not listed.
    (DataType.f16, DataType.cf16),
    (DataType.f32, DataType.cf32),
    (DataType.f64, DataType.cf64),
]


#
def is_complex(data_type):
    """Report whether *data_type* is the complex half of a RealComplexBijection pair.

    Only cf16, cf32 and cf64 appear in that table, so other complex DataType
    members (cbf16, ctf32, cs*/cu*) yield False.
    """
    for _real, complex_type in RealComplexBijection:
        if data_type == complex_type:
            return True
    return False


def is_block_scaled(gemm_kind):
    """Return True for the block-scaled GEMM kinds, both plain and grouped.

    GemmKind is looked up when the function is called, not at definition time.
    """
    block_scaled_kinds = (
        GemmKind.BlockScaledUniversal3x,
        GemmKind.GroupedBlockScaledUniversal3x,
    )
    return gemm_kind in block_scaled_kinds


def is_blockwise(gemm_kind):
    """Return True for the blockwise-scaled GEMM kinds, both plain and grouped.

    GemmKind is looked up when the function is called, not at definition time.
    """
    blockwise_kinds = (
        GemmKind.BlockwiseUniversal3x,
        GemmKind.GroupedBlockwiseUniversal3x,
    )
    return gemm_kind in blockwise_kinds


def is_grouped(gemm_kind):
    """Return True for any grouped GEMM kind (plain, block-scaled, or blockwise).

    GemmKind is looked up when the function is called, not at definition time.
    """
    grouped_kinds = (
        GemmKind.GroupedUniversal3x,
        GemmKind.GroupedBlockScaledUniversal3x,
        GemmKind.GroupedBlockwiseUniversal3x,
    )
    return gemm_kind in grouped_kinds


#
def get_complex_from_real(real_type):
    """Return the complex DataType paired with *real_type* in RealComplexBijection.

    Yields DataType.invalid when *real_type* is not the real half of any pair.
    """
    matches = (cplx for real, cplx in RealComplexBijection if real_type == real)
    return next(matches, DataType.invalid)


#
def get_real_from_complex(complex_type):
    """Return the real DataType paired with *complex_type* in RealComplexBijection.

    Yields DataType.invalid when *complex_type* is not the complex half of any pair.
    """
    matches = (real for real, cplx in RealComplexBijection if complex_type == cplx)
    return next(matches, DataType.invalid)


# TMA requires an alignment of 128 bits for all data types
def get_tma_alignment(data_type):
    """Return the TMA alignment requirement for *data_type*, counted in elements.

    void has no storage and yields 0. 6-bit types yield 128 elements (96 bytes,
    the 16U6 format alignment). Every other type yields 128 bits divided by its
    width from DataTypeSize. Types missing from DataTypeSize raise KeyError.
    """
    if data_type == DataType.void:
        return 0
    width_bits = DataTypeSize[data_type]
    if width_bits == 6:
        return 128  # 96B alignment for 16U6 format
    return 128 // width_bits


#
class ComplexMultiplyOp(enum.Enum):
    """Complex multiplication strategy selector.

    NOTE(review): presumably multiply_add is the standard complex product and
    gaussian the Gaussian (three-multiply) variant, cf. the
    MathOperation.multiply_add_complex_gaussian tag; confirm at use sites.
    """

    multiply_add = enum_auto()
    gaussian = enum_auto()


###################################################################################################


#
class MathOperation(enum.Enum):
    """Math operator variants; MathOperationTag gives the cutlass::arch tag for each."""

    multiply_add = enum_auto()
    multiply_add_saturate = enum_auto()
    multiply_add_mixed_input_upcast = enum_auto()
    # Binary (1-bit) operators.
    xor_popc = enum_auto()
    and_popc = enum_auto()
    multiply_add_fast_bf16 = enum_auto()
    multiply_add_fast_f16 = enum_auto()
    multiply_add_fast_f32 = enum_auto()
    # Complex-valued operators.
    multiply_add_complex_fast_f32 = enum_auto()
    multiply_add_complex = enum_auto()
    multiply_add_complex_gaussian = enum_auto()
    multiply_add_fast_accum = enum_auto()


# cutlass::arch operator tag emitted for each MathOperation.
MathOperationTag = {
    MathOperation.multiply_add: "cutlass::arch::OpMultiplyAdd",
    MathOperation.multiply_add_saturate: "cutlass::arch::OpMultiplyAddSaturate",
    MathOperation.multiply_add_mixed_input_upcast: "cutlass::arch::OpMultiplyAddMixedInputUpcast",
    MathOperation.xor_popc: "cutlass::arch::OpXorPopc",
    MathOperation.and_popc: "cutlass::arch::OpAndPopc",
    MathOperation.multiply_add_fast_bf16: "cutlass::arch::OpMultiplyAddFastBF16",
    MathOperation.multiply_add_fast_f16: "cutlass::arch::OpMultiplyAddFastF16",
    MathOperation.multiply_add_fast_f32: "cutlass::arch::OpMultiplyAddFastF32",
    MathOperation.multiply_add_complex_fast_f32: "cutlass::arch::OpMultiplyAddComplexFastF32",
    MathOperation.multiply_add_complex: "cutlass::arch::OpMultiplyAddComplex",
    MathOperation.multiply_add_complex_gaussian: "cutlass::arch::OpMultiplyAddGaussianComplex",
    MathOperation.multiply_add_fast_accum: "cutlass::arch::OpMultiplyAddFastAccum",
}

###################################################################################################


#
class LayoutType(enum.Enum):
    """Matrix and tensor memory layouts; LayoutTag gives the C++ layout type for each."""

    # Matrix layouts; the InterleavedN forms map to *Interleaved<N>.
    ColumnMajor = enum_auto()
    RowMajor = enum_auto()
    ColumnMajorInterleaved2 = enum_auto()
    RowMajorInterleaved2 = enum_auto()
    ColumnMajorInterleaved32 = enum_auto()
    RowMajorInterleaved32 = enum_auto()
    ColumnMajorInterleaved64 = enum_auto()
    RowMajorInterleaved64 = enum_auto()
    # Activation tensor layouts.
    TensorNWC = enum_auto()
    TensorNHWC = enum_auto()
    TensorNDHWC = enum_auto()
    TensorNCHW = enum_auto()
    TensorNGHWC = enum_auto()
    # Channel-interleaved layouts (TensorNCxHWx<N> / TensorCxRSKx<N>).
    TensorNC32HW32 = enum_auto()
    TensorNC64HW64 = enum_auto()
    TensorC32RSK32 = enum_auto()
    TensorC64RSK64 = enum_auto()
    # K-C-spatial filter layouts.
    TensorKCS = enum_auto()
    TensorKCSR = enum_auto()
    TensorKCSRT = enum_auto()

#
LayoutTag = {
    # cutlass::layout type emitted for each LayoutType.
    LayoutType.ColumnMajor: "cutlass::layout::ColumnMajor",
    LayoutType.RowMajor: "cutlass::layout::RowMajor",
    LayoutType.ColumnMajorInterleaved2: "cutlass::layout::ColumnMajorInterleaved<2>",
    LayoutType.RowMajorInterleaved2: "cutlass::layout::RowMajorInterleaved<2>",
    LayoutType.ColumnMajorInterleaved32: "cutlass::layout::ColumnMajorInterleaved<32>",
    LayoutType.RowMajorInterleaved32: "cutlass::layout::RowMajorInterleaved<32>",
    LayoutType.ColumnMajorInterleaved64: "cutlass::layout::ColumnMajorInterleaved<64>",
    LayoutType.RowMajorInterleaved64: "cutlass::layout::RowMajorInterleaved<64>",
    LayoutType.TensorNWC: "cutlass::layout::TensorNWC",
    LayoutType.TensorNHWC: "cutlass::layout::TensorNHWC",
    LayoutType.TensorNDHWC: "cutlass::layout::TensorNDHWC",
    LayoutType.TensorNCHW: "cutlass::layout::TensorNCHW",
    LayoutType.TensorNGHWC: "cutlass::layout::TensorNGHWC",
    # Interleaved tensor layouts share one templated C++ type per family.
    LayoutType.TensorNC32HW32: "cutlass::layout::TensorNCxHWx<32>",
    LayoutType.TensorC32RSK32: "cutlass::layout::TensorCxRSKx<32>",
    LayoutType.TensorNC64HW64: "cutlass::layout::TensorNCxHWx<64>",
    LayoutType.TensorC64RSK64: "cutlass::layout::TensorCxRSKx<64>",
    LayoutType.TensorKCS: "cutlass::layout::TensorKCS",
    LayoutType.TensorKCSR: "cutlass::layout::TensorKCSR",
    LayoutType.TensorKCSRT: "cutlass::layout::TensorKCSRT",
}

#
TransposedLayout = {
    # Layout obtained by transposing each matrix layout (column <-> row major,
    # keeping the interleave factor). TensorNHWC maps to itself; other tensor
    # layouts have no entry.
    LayoutType.ColumnMajor: LayoutType.RowMajor,
    LayoutType.RowMajor: LayoutType.ColumnMajor,
    LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2,
    LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2,
    LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32,
    LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32,
    LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64,
    LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64,
    LayoutType.TensorNHWC: LayoutType.TensorNHWC,
}

#
ShortLayoutTypeNames = {
    # Compact layout codes: "n" for column-major, "t" for row-major (with the
    # interleave factor appended); tensor layouts use their lowercase name.
    LayoutType.ColumnMajor: "n",
    LayoutType.ColumnMajorInterleaved2: "n2",
    LayoutType.ColumnMajorInterleaved32: "n32",
    LayoutType.ColumnMajorInterleaved64: "n64",
    LayoutType.RowMajor: "t",
    LayoutType.RowMajorInterleaved2: "t2",
    LayoutType.RowMajorInterleaved32: "t32",
    LayoutType.RowMajorInterleaved64: "t64",
    LayoutType.TensorNWC: "nwc",
    LayoutType.TensorNHWC: "nhwc",
    LayoutType.TensorNDHWC: "ndhwc",
    LayoutType.TensorNCHW: "nchw",
    LayoutType.TensorNGHWC: "nghwc",
    LayoutType.TensorNC32HW32: "nc32hw32",
    LayoutType.TensorNC64HW64: "nc64hw64",
    LayoutType.TensorC32RSK32: "c32rsk32",
    LayoutType.TensorC64RSK64: "c64rsk64",
    LayoutType.TensorKCS: "kcs",
    LayoutType.TensorKCSR: "kcsr",
    LayoutType.TensorKCSRT: "kcsrt",
}

#
ShortComplexLayoutNames = {
    # Compact code for a (layout, complex transform) pair: "n"/"t" without
    # conjugation, "c"/"h" for column-/row-major with conjugation.
    (LayoutType.ColumnMajor, ComplexTransform.none): "n",
    (LayoutType.ColumnMajor, ComplexTransform.conj): "c",
    (LayoutType.RowMajor, ComplexTransform.none): "t",
    (LayoutType.RowMajor, ComplexTransform.conj): "h",
}


###################################################################################################
class KernelScheduleType(enum.Enum):
    """Mainloop kernel schedules selectable by the generator.

    KernelScheduleTag maps each member to the C++ schedule type that is emitted,
    and KernelScheduleSuffixes maps each member to a short string suffix.
    Values come from enum_auto() in declaration order.
    """

    ScheduleAuto = enum_auto()
    Multistage = enum_auto()
    CpAsyncWarpSpecialized = enum_auto()
    CpAsyncWarpSpecializedPingpong = enum_auto()
    CpAsyncWarpSpecializedCooperative = enum_auto()
    Tma = enum_auto()
    TmaWarpSpecialized = enum_auto()
    TmaWarpSpecializedPingpong = enum_auto()
    TmaWarpSpecializedCooperative = enum_auto()
    TmaWarpSpecializedFP8FastAccum = enum_auto()
    TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
    TmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
    ImplicitTmaWarpSpecializedSm90 = enum_auto()
    # Pointer-array variants (tags are cutlass::gemm::KernelPtrArray...).
    PtrArrayTmaWarpSpecializedCooperative = enum_auto()
    PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
    PtrArrayTmaWarpSpecializedPingpong = enum_auto()
    PtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto()

    # Blockwise schedules whose tags end in FP8BlockScaledAccum.
    BlockwiseTmaWarpSpecializedCooperative = enum_auto()
    PtrArrayBlockwiseTmaWarpSpecializedCooperative = enum_auto()

    # Sm100 schedules: 1Sm / 2Sm variants.
    TmaWarpSpecialized1SmSm100 = enum_auto()
    TmaWarpSpecialized2SmSm100 = enum_auto()
    ImplicitTmaWarpSpecialized1SmSm100 = enum_auto()
    ImplicitTmaWarpSpecialized2SmSm100 = enum_auto()

    PtrArrayTmaWarpSpecialized1SmSm100 = enum_auto()
    PtrArrayTmaWarpSpecialized2SmSm100 = enum_auto()

    PtrArrayTmaWarpSpecialized1SmBlockScaledSm100 = enum_auto()
    PtrArrayTmaWarpSpecialized2SmBlockScaledSm100 = enum_auto()
    PtrArrayNvf4TmaWarpSpecialized1SmSm100 = enum_auto()
    PtrArrayNvf4TmaWarpSpecialized2SmSm100 = enum_auto()
    PtrArrayMxf4TmaWarpSpecialized1SmSm100 = enum_auto()
    PtrArrayMxf4TmaWarpSpecialized2SmSm100 = enum_auto()
    PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto()
    PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto()

    SparseTmaWarpSpecialized1SmSm100 = enum_auto()
    SparseTmaWarpSpecialized2SmSm100 = enum_auto()

    BlockScaledTmaWarpSpecialized1SmSm100 = enum_auto()
    BlockScaledTmaWarpSpecialized2SmSm100 = enum_auto()
    Mxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto()
    Mxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto()

    BlockwiseTmaWarpSpecialized1SmSm100 = enum_auto()
    BlockwiseTmaWarpSpecialized2SmSm100 = enum_auto()

    PtrArrayBlockwiseTmaWarpSpecialized1SmSm100 = enum_auto()
    PtrArrayBlockwiseTmaWarpSpecialized2SmSm100 = enum_auto()

    Mxf4TmaWarpSpecialized1SmSm100 = enum_auto()
    Mxf4TmaWarpSpecialized2SmSm100 = enum_auto()
    Nvf4TmaWarpSpecialized1SmSm100 = enum_auto()
    Nvf4TmaWarpSpecialized2SmSm100 = enum_auto()

    # Sm120 schedules: Cooperative / Pingpong variants.
    Mxf8f6f4TmaWarpSpecializedCooperativeSm120 = enum_auto()
    Mxf8f6f4TmaWarpSpecializedPingpongSm120 = enum_auto()
    Nvf4TmaWarpSpecializedCooperativeSm120 = enum_auto()
    Nvf4TmaWarpSpecializedPingpongSm120 = enum_auto()
    Mxf4TmaWarpSpecializedCooperativeSm120 = enum_auto()
    Mxf4TmaWarpSpecializedPingpongSm120 = enum_auto()

    F8f6f4SparseTmaWarpSpecializedCooperativeSm120 = enum_auto()

    BlockwiseTmaWarpSpecializedCooperativeSm120 = enum_auto()
    BlockwiseTmaWarpSpecializedPingpongSm120 = enum_auto()


KernelScheduleTag = {
    # C++ schedule type emitted for each KernelScheduleType. Every member of
    # KernelScheduleType has an entry. Implicit (conv) schedules live in
    # cutlass::conv, ScheduleAuto in cutlass::gemm::collective, the rest in
    # cutlass::gemm.
    KernelScheduleType.ScheduleAuto: "cutlass::gemm::collective::KernelScheduleAuto",
    KernelScheduleType.Multistage: "cutlass::gemm::KernelMultistage",
    KernelScheduleType.CpAsyncWarpSpecialized: "cutlass::gemm::KernelCpAsyncWarpSpecialized",
    KernelScheduleType.CpAsyncWarpSpecializedPingpong: "cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong",
    KernelScheduleType.CpAsyncWarpSpecializedCooperative: "cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative",
    KernelScheduleType.Tma: "cutlass::gemm::KernelTma",
    KernelScheduleType.TmaWarpSpecialized: "cutlass::gemm::KernelTmaWarpSpecialized",
    KernelScheduleType.TmaWarpSpecializedPingpong: "cutlass::gemm::KernelTmaWarpSpecializedPingpong",
    KernelScheduleType.TmaWarpSpecializedCooperative: "cutlass::gemm::KernelTmaWarpSpecializedCooperative",
    KernelScheduleType.TmaWarpSpecializedFP8FastAccum: "cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum",
    KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: "cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum",
    KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: "cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum",
    KernelScheduleType.ImplicitTmaWarpSpecializedSm90: "cutlass::conv::KernelImplicitTmaWarpSpecializedSm90",
    KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: "cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum",
    KernelScheduleType.TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelTmaWarpSpecialized1SmSm100",
    KernelScheduleType.TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelTmaWarpSpecialized2SmSm100",
    KernelScheduleType.ImplicitTmaWarpSpecialized1SmSm100: "cutlass::conv::KernelImplicitTmaWarpSpecialized1SmSm100",
    KernelScheduleType.ImplicitTmaWarpSpecialized2SmSm100: "cutlass::conv::KernelImplicitTmaWarpSpecialized2SmSm100",
    KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100",
    KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100",
    KernelScheduleType.SparseTmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100",
    KernelScheduleType.SparseTmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100",
    KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100",
    KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100",
    KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100",
    KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100",
    KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100",
    KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100",
    KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100",
    KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100",
    KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelTmaWarpSpecialized1SmMxf4Sm100",
    KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelTmaWarpSpecialized2SmMxf4Sm100",
    KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100",
    KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100",
    KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: "cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative",
    KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: "cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum",
    KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: "cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong",
    KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: "cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum",
    KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: "cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum",
    KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100",
    KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100",
    KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100",
    KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmNvf4Sm100",
    KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf4Sm100",
    KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf4Sm100",
    KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100",
    KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100",
    # Sm120: the Cooperative members map to tags without "Cooperative" in the
    # name; the Pingpong members map to explicit ...Pingpong... tags.
    KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120: "cutlass::gemm::KernelTmaWarpSpecializedMxf8f6f4Sm120",
    KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: "cutlass::gemm::KernelTmaWarpSpecializedPingpongMxf8f6f4Sm120",
    KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120: "cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120",
    KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120: "cutlass::gemm::KernelTmaWarpSpecializedPingpongNvf4Sm120",
    KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: "cutlass::gemm::KernelTmaWarpSpecializedMxf4Sm120",
    KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: "cutlass::gemm::KernelTmaWarpSpecializedPingpongMxf4Sm120",
    KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: "cutlass::gemm::KernelScheduleSparseF8f6f4Sm120",
    KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: "cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120",
    KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: "cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120",
}

#
KernelScheduleSuffixes = {
    # Short string suffix for each KernelScheduleType (presumably appended to
    # generated kernel names; confirm at the use site). Suffixes are not unique:
    # e.g. many Sm100 schedules share "_1sm"/"_2sm", and several cooperative
    # schedules share "_warpspecialized_cooperative".
    KernelScheduleType.ScheduleAuto: "",
    KernelScheduleType.Multistage: "_cpasync",
    KernelScheduleType.CpAsyncWarpSpecialized: "_cpasync_warpspecialized",
    KernelScheduleType.CpAsyncWarpSpecializedPingpong: "_cpasync_warpspecialized_pingpong",
    KernelScheduleType.CpAsyncWarpSpecializedCooperative: "_cpasync_warpspecialized_cooperative",
    KernelScheduleType.Tma: "_unspecialized",
    KernelScheduleType.TmaWarpSpecialized: "_warpspecialized",
    KernelScheduleType.TmaWarpSpecializedPingpong: "_warpspecialized_pingpong",
    KernelScheduleType.TmaWarpSpecializedCooperative: "_warpspecialized_cooperative",
    KernelScheduleType.TmaWarpSpecializedFP8FastAccum: "_warpspecialized_fp8_fastaccum",
    KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: "_warpspecialized_cooperative_fp8_fastaccum",
    KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: "_warpspecialized_pingpong_fp8_fastaccum",
    KernelScheduleType.ImplicitTmaWarpSpecializedSm90: "_warpspecialized",
    KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: "_warpspecialized_cooperative",
    KernelScheduleType.TmaWarpSpecialized1SmSm100: "_1sm",
    KernelScheduleType.TmaWarpSpecialized2SmSm100: "_2sm",
    KernelScheduleType.ImplicitTmaWarpSpecialized1SmSm100: "_1sm",
    KernelScheduleType.ImplicitTmaWarpSpecialized2SmSm100: "_2sm",
    KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100: "_1sm",
    KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100: "_2sm",
    KernelScheduleType.SparseTmaWarpSpecialized1SmSm100: "_1sm",
    KernelScheduleType.SparseTmaWarpSpecialized2SmSm100: "_2sm",
    KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: "_1sm",
    KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: "_2sm",
    KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: "_q_1sm",
    KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: "_q_2sm",
    KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100: "_1sm",
    KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100: "_2sm",
    KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100: "_1sm",
    KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100: "_2sm",
    # Mxf4 schedules carry "_o_vs32", Nvf4 schedules "_o_vs16".
    KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: "_o_vs32_1sm",
    KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: "_o_vs32_2sm",
    KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: "_o_vs16_1sm",
    KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: "_o_vs16_2sm",
    KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: "_warpspecialized_cooperative",
    KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: "_warpspecialized_cooperative_fp8_fastaccum",
    KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: "_warpspecialized_pingpong",
    KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: "_warpspecialized_pingpong_fp8_fastaccum",
    KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: "_warpspecialized_cooperative",
    KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: "_1sm",
    KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: "_2sm",
    KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: "_o_vs16_1sm",
    KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100: "_o_vs16_2sm",
    KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100: "_o_vs32_1sm",
    KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100: "_o_vs32_2sm",
    # NOTE(review): the non-ptr-array Mxf8f6f4 schedules use "_q_1sm"/"_q_2sm",
    # while these ptr-array ones reuse the Mxf4 suffixes "_o_vs32_*". Confirm
    # this is intended and cannot make kernel names collide with PtrArrayMxf4.
    KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: "_o_vs32_1sm",
    KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: "_o_vs32_2sm",
    KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120: "_cooperative_q",
    KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: "_pingpong_q",
    KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120: "_cooperative_o_vs16",
    KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120: "_pingpong_o_vs16",
    KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: "_cooperative_o_vs32",
    KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: "_pingpong_o_vs32",
    KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: "_q",
    KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: "_cooperative_q",
    KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: "_pingpong_q",
}


class EpilogueScheduleType(enum.Enum):
    """Epilogue schedules selectable for generated kernels.

    EpilogueScheduleTag maps each member to the C++ schedule type that is
    emitted, and EpilogueScheduleSuffixes maps it to a name fragment.
    ``PtrArray*`` members are the grouped counterparts that
    to_grouped_schedule() maps the plain schedules to.
    Values come from enum_auto(); do not rely on specific numeric values.
    """

    ScheduleAuto = enum_auto()
    EpilogueTransposed = enum_auto()
    NoSmemWarpSpecialized = enum_auto()
    PtrArrayNoSmemWarpSpecialized = enum_auto()
    NoSmemWarpSpecialized1Sm = enum_auto()
    NoSmemWarpSpecialized2Sm = enum_auto()
    PtrArrayNoSmemWarpSpecialized1Sm = enum_auto()
    PtrArrayNoSmemWarpSpecialized2Sm = enum_auto()
    TmaWarpSpecialized = enum_auto()
    TmaWarpSpecializedCooperative = enum_auto()
    TmaWarpSpecialized1Sm = enum_auto()
    TmaWarpSpecialized2Sm = enum_auto()
    PtrArrayTmaWarpSpecialized1Sm = enum_auto()
    PtrArrayTmaWarpSpecialized2Sm = enum_auto()
    PtrArrayTmaWarpSpecializedPingpong = enum_auto()
    PtrArrayTmaWarpSpecializedCooperative = enum_auto()


# Fully-qualified C++ type emitted for each EpilogueScheduleType.
EpilogueScheduleTag = {
    EpilogueScheduleType.ScheduleAuto: "cutlass::epilogue::collective::EpilogueScheduleAuto",
    EpilogueScheduleType.EpilogueTransposed: "cutlass::gemm::EpilogueTransposed",
    EpilogueScheduleType.NoSmemWarpSpecialized: "cutlass::epilogue::NoSmemWarpSpecialized",
    EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: "cutlass::epilogue::PtrArrayNoSmemWarpSpecialized",
    EpilogueScheduleType.NoSmemWarpSpecialized1Sm: "cutlass::epilogue::NoSmemWarpSpecialized1Sm",
    EpilogueScheduleType.NoSmemWarpSpecialized2Sm: "cutlass::epilogue::NoSmemWarpSpecialized2Sm",
    EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: "cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm",
    EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: "cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm",
    EpilogueScheduleType.TmaWarpSpecialized: "cutlass::epilogue::TmaWarpSpecialized",
    EpilogueScheduleType.TmaWarpSpecializedCooperative: "cutlass::epilogue::TmaWarpSpecializedCooperative",
    EpilogueScheduleType.TmaWarpSpecialized1Sm: "cutlass::epilogue::TmaWarpSpecialized1Sm",
    EpilogueScheduleType.TmaWarpSpecialized2Sm: "cutlass::epilogue::TmaWarpSpecialized2Sm",
    EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: "cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm",
    EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: "cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm",
    EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: "cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative",
    EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: "cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong",
}

# Name fragment for each EpilogueScheduleType (presumably appended to generated
# kernel names -- confirm against the emitters).
# NOTE(review): TmaWarpSpecialized1Sm maps to "" while TmaWarpSpecialized2Sm maps
# to "_epi_tma". Changing either would rename generated kernels, so confirm the
# asymmetry is intended before touching it.
EpilogueScheduleSuffixes = {
    EpilogueScheduleType.ScheduleAuto: "",
    EpilogueScheduleType.EpilogueTransposed: "",
    EpilogueScheduleType.NoSmemWarpSpecialized: "_epi_nosmem",
    EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: "_epi_nosmem",
    EpilogueScheduleType.NoSmemWarpSpecialized1Sm: "_epi_nosmem",
    EpilogueScheduleType.NoSmemWarpSpecialized2Sm: "_epi_nosmem",
    EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: "_epi_nosmem",
    EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: "_epi_nosmem",
    EpilogueScheduleType.TmaWarpSpecialized: "_epi_tma",
    EpilogueScheduleType.TmaWarpSpecializedCooperative: "_epi_tma",
    EpilogueScheduleType.TmaWarpSpecialized1Sm: "",
    EpilogueScheduleType.TmaWarpSpecialized2Sm: "_epi_tma",
    EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: "_tma_1sm",
    EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: "_tma_2sm",
    EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: "_epi_tma",
    EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: "_epi_tma",
}


class EpilogueFunctor3x(enum.Enum):
    """Epilogue fusion operations for 3.x-style kernels.

    EpilogueFunctor3xTag maps each member to its
    ``cutlass::epilogue::fusion`` C++ type. Values come from enum_auto().
    """

    LinearCombination = enum_auto()
    LinearCombinationBlockScaleFactor = enum_auto()


# C++ fusion type emitted for each EpilogueFunctor3x member.
EpilogueFunctor3xTag = {
    EpilogueFunctor3x.LinearCombination: "cutlass::epilogue::fusion::LinearCombination",
    EpilogueFunctor3x.LinearCombinationBlockScaleFactor: "cutlass::epilogue::fusion::LinCombBlockScaleFactor",
}


def is_tma_epilogue(epilogue_schedule_type):
    """Return True if ``epilogue_schedule_type`` is handled as a TMA epilogue.

    TMA epilogues have extra alignment requirements, as calculated in
    get_tma_alignment(data_type). ScheduleAuto is deliberately part of the set,
    so an automatically chosen epilogue is constrained as if it were TMA-based.

    :param epilogue_schedule_type: an EpilogueScheduleType member
    :return: bool
    """
    tma_based = (
        EpilogueScheduleType.ScheduleAuto,
        # Single-kernel TMA epilogues.
        EpilogueScheduleType.TmaWarpSpecialized,
        EpilogueScheduleType.TmaWarpSpecializedCooperative,
        EpilogueScheduleType.TmaWarpSpecialized1Sm,
        EpilogueScheduleType.TmaWarpSpecialized2Sm,
        # Ptr-array (grouped) TMA epilogues.
        EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm,
        EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm,
        EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
        EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong,
    )
    return epilogue_schedule_type in tma_based


def to_grouped_schedule(schedule, grouped):
    """Map a kernel or epilogue schedule to its ptr-array (grouped) counterpart.

    :param schedule: a KernelScheduleType or EpilogueScheduleType member
    :param grouped: when falsy, ``schedule`` is returned unchanged
    :return: the PtrArray equivalent of ``schedule`` when ``grouped`` is truthy
    :raises KeyError: if ``schedule`` has no grouped equivalent
    """
    if not grouped:
        return schedule

    group_schedule_map = {
        # SM90
        KernelScheduleType.TmaWarpSpecializedCooperative: KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
        KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative,
        KernelScheduleType.TmaWarpSpecializedPingpong: KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong,
        KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
        KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum,
        # There is no PtrArrayTmaWarpSpecialized epilogue; the plain TMA
        # epilogue maps to the pingpong ptr-array epilogue.
        EpilogueScheduleType.TmaWarpSpecialized: EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong,
        EpilogueScheduleType.TmaWarpSpecializedCooperative: EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
        EpilogueScheduleType.NoSmemWarpSpecialized: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized,
        # SM100
        KernelScheduleType.TmaWarpSpecialized1SmSm100: KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100,
        KernelScheduleType.TmaWarpSpecialized2SmSm100: KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100,
        KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100,
        KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100,
        KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100,
        KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100,
        KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100,
        KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100,
        KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100: KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100,
        KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100: KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100,
        EpilogueScheduleType.TmaWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm,
        EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm,
        # The no-smem 1SM/2SM epilogues have declared and tagged PtrArray
        # counterparts, mirroring the TMA 1SM/2SM pair above.
        EpilogueScheduleType.NoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm,
        EpilogueScheduleType.NoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm,
    }

    try:
        return group_schedule_map[schedule]
    except KeyError:
        # Same exception type as before, but name the offending schedule.
        raise KeyError(
            "No grouped (PtrArray) equivalent for schedule {}".format(schedule)
        ) from None


class TileSchedulerType(enum.Enum):
    """Tile scheduler selection for generated kernels.

    TileSchedulerTag gives the C++ scheduler type for each member, and
    TileSchedulerSuffixes gives its name fragment. Values come from enum_auto().
    """

    Default = enum_auto()
    Persistent = enum_auto()
    StreamK = enum_auto()


# C++ scheduler type per TileSchedulerType. Default emits "void"
# (NOTE(review): presumably leaves the choice to the kernel -- confirm).
TileSchedulerTag = {
    TileSchedulerType.Default: "void",
    TileSchedulerType.Persistent: "cutlass::gemm::PersistentScheduler",
    TileSchedulerType.StreamK: "cutlass::gemm::StreamKScheduler",
}

# Name fragment per TileSchedulerType; only StreamK contributes one.
TileSchedulerSuffixes = {
    TileSchedulerType.Default: "",
    TileSchedulerType.Persistent: "",
    TileSchedulerType.StreamK: "_stream_k",
}

###################################################################################################


#
class SideMode(enum.Enum):
    """Operand side (left / right), mirroring ``cutlass::SideMode``.

    Used by SymmetricTensorDescription and TriangularTensorDescription.
    Values come from enum_auto().
    """

    Left = enum_auto()
    Right = enum_auto()


# C++ enumerator per SideMode.
SideModeTag = {
    SideMode.Left: "cutlass::SideMode::kLeft",
    SideMode.Right: "cutlass::SideMode::kRight",
}

# Short name fragment per SideMode.
ShortSideModeNames = {SideMode.Left: "ls", SideMode.Right: "rs"}

###################################################################################################


#
class FillMode(enum.Enum):
    """Fill mode (lower / upper), mirroring ``cutlass::FillMode``.

    Used by SymmetricTensorDescription and TriangularTensorDescription.
    Values come from enum_auto().
    """

    Lower = enum_auto()
    Upper = enum_auto()


# C++ enumerator per FillMode.
FillModeTag = {
    FillMode.Lower: "cutlass::FillMode::kLower",
    FillMode.Upper: "cutlass::FillMode::kUpper",
}

# Short name fragment per FillMode.
ShortFillModeNames = {FillMode.Lower: "l", FillMode.Upper: "u"}

###################################################################################################


#
class DiagType(enum.Enum):
    """Diagonal type (non-unit / unit), mirroring ``cutlass::DiagType``.

    Used by TriangularTensorDescription. Values come from enum_auto().
    """

    NonUnit = enum_auto()
    Unit = enum_auto()


# C++ enumerator per DiagType.
DiagTypeTag = {
    DiagType.NonUnit: "cutlass::DiagType::kNonUnit",
    DiagType.Unit: "cutlass::DiagType::kUnit",
}

# Short name fragment per DiagType.
ShortDiagTypeNames = {DiagType.NonUnit: "nu", DiagType.Unit: "un"}

###################################################################################################


#
class OpcodeClass(enum.Enum):
    """Instruction class a kernel's math is built on.

    OpcodeClassNames gives the short name fragment for each member, and
    OpcodeClassTag gives its ``cutlass::arch::OpClass*`` C++ tag.
    Values come from enum_auto().
    """

    Simt = enum_auto()
    TensorOp = enum_auto()
    WmmaTensorOp = enum_auto()
    SparseTensorOp = enum_auto()
    BlockScaledTensorOp = enum_auto()


# Short name fragment per OpcodeClass.
OpcodeClassNames = {
    OpcodeClass.Simt: "simt",
    OpcodeClass.TensorOp: "tensorop",
    OpcodeClass.WmmaTensorOp: "wmma_tensorop",
    OpcodeClass.SparseTensorOp: "sptensorop",
    OpcodeClass.BlockScaledTensorOp: "bstensorop",
}

# C++ opcode-class tag per OpcodeClass.
OpcodeClassTag = {
    OpcodeClass.Simt: "cutlass::arch::OpClassSimt",
    OpcodeClass.TensorOp: "cutlass::arch::OpClassTensorOp",
    OpcodeClass.WmmaTensorOp: "cutlass::arch::OpClassWmmaTensorOp",
    OpcodeClass.SparseTensorOp: "cutlass::arch::OpClassSparseTensorOp",
    OpcodeClass.BlockScaledTensorOp: "cutlass::arch::OpClassBlockScaledTensorOp",
}

###################################################################################################


#
class OperationKind(enum.Enum):
    """Top-level kind of operation (GEMM, BLAS3 variants, convolutions).

    OperationKindNames gives the lowercase name of each kind.
    Values come from enum_auto().
    """

    Gemm = enum_auto()
    RankK = enum_auto()
    Rank2K = enum_auto()
    Trmm = enum_auto()
    Symm = enum_auto()
    Conv2d = enum_auto()
    Conv3d = enum_auto()


# Lowercase name per OperationKind.
OperationKindNames = {
    OperationKind.Gemm: "gemm",
    OperationKind.RankK: "rank_k",
    OperationKind.Rank2K: "rank_2k",
    OperationKind.Trmm: "trmm",
    OperationKind.Symm: "symm",
    OperationKind.Conv2d: "conv2d",
    OperationKind.Conv3d: "conv3d",
}


#
class Target(enum.Enum):
    """Emission target; only ``library`` is defined. Value comes from enum_auto()."""

    library = enum_auto()


#
# Architecture family name for each SM compute capability.
ArchitectureNames = dict(
    [
        (50, "maxwell"),
        (60, "pascal"),
        (61, "pascal"),
        (70, "volta"),
        (75, "turing"),
        (80, "ampere"),
        (89, "ada"),
        (90, "hopper"),
    ]
)

# Shared memory, in KB, for each SM compute capability.
SharedMemPerCC = dict(
    [
        (70, 96),  #  96KB of SMEM
        (72, 96),  #  96KB of SMEM
        (75, 64),  #  64KB of SMEM
        (80, 163),  # 163KB of SMEM - 1KB reserved for the driver
        (86, 99),  #  99KB of SMEM - 1KB reserved for the driver
        (87, 163),  # 163KB of SMEM - 1KB reserved for the driver
        (89, 99),  #  99KB of SMEM - 1KB reserved for the driver
        (90, 227),  # 227KB of SMEM - 1KB reserved for the driver
    ]
)

###################################################################################################


#
def SubstituteTemplate(template, values):
    """Expand ``${key}`` placeholders in ``template`` using ``values``.

    Substitution repeats until a full pass over ``values`` changes nothing. A
    value may therefore contain placeholders that are expanded on a later pass.
    Keys and values are used literally: regex metacharacters in a key and
    backslashes in a value have no special meaning.

    :param template: text containing ``${key}`` placeholders
    :param values: mapping from placeholder name to replacement string
    :return: the expanded text
    """
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            # str.replace instead of re.sub has two effects. The key is not
            # compiled as a regex ("a+" must match "${a+}" only). Backslashes in
            # the value are not treated as group references or escapes; a value
            # such as "\1" would otherwise raise re.error or be mangled.
            newtext = text.replace("${%s}" % key, value)
            if newtext != text:
                changed = True
            text = newtext
    return text


###################################################################################################


#
class GemmKind(enum.Enum):
    """Kind of GEMM kernel (2.x-style, 3.x universal, sparse, grouped, block-scaled, ...).

    GemmKindNames maps each kind to a base operation name. Names are shared
    across kinds; e.g. Universal, Universal3x, BlockScaledUniversal3x and
    BlockwiseUniversal3x all map to "gemm". Values come from enum_auto().
    """

    Gemm = enum_auto()
    Sparse = enum_auto()
    Universal = enum_auto()
    Universal3x = enum_auto()
    SparseUniversal3x = enum_auto()
    PlanarComplex = enum_auto()
    PlanarComplexArray = enum_auto()
    Grouped = enum_auto()
    BlockScaledUniversal3x = enum_auto()
    GroupedUniversal3x = enum_auto()
    GroupedBlockScaledUniversal3x = enum_auto()
    BlockwiseUniversal3x = enum_auto()
    GroupedBlockwiseUniversal3x = enum_auto()


# Base operation name per GemmKind (not unique across kinds).
GemmKindNames = {
    GemmKind.Gemm: "gemm",
    GemmKind.Sparse: "spgemm",
    GemmKind.Universal: "gemm",
    GemmKind.Universal3x: "gemm",
    GemmKind.SparseUniversal3x: "spgemm",
    GemmKind.PlanarComplex: "gemm_planar_complex",
    GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
    GemmKind.Grouped: "gemm_grouped",
    GemmKind.BlockScaledUniversal3x: "gemm",
    GemmKind.GroupedUniversal3x: "gemm_grouped",
    GemmKind.GroupedBlockScaledUniversal3x: "gemm_grouped",
    GemmKind.BlockwiseUniversal3x: "gemm",
    GemmKind.GroupedBlockwiseUniversal3x: "gemm_grouped",
}


#
class RankKKind(enum.Enum):
    """Kernel kinds for RankK operations (only Universal is defined)."""

    Universal = enum_auto()


# Name per RankKKind.
RankKKindNames = {RankKKind.Universal: "rank_k"}


# Kernel kinds for TRMM operations (only Universal is defined).
class TrmmKind(enum.Enum):
    Universal = enum_auto()


# Name per TrmmKind.
TrmmKindNames = {TrmmKind.Universal: "trmm"}


# Kernel kinds for SYMM operations (only Universal is defined).
class SymmKind(enum.Enum):
    Universal = enum_auto()


# Name per SymmKind.
SymmKindNames = {SymmKind.Universal: "symm"}


#
class EpilogueFunctor(enum.Enum):
    """Thread-level epilogue functors (``cutlass::epilogue::thread``).

    EpilogueFunctorTag maps each member to its C++ type. For the
    ``cutlass::epilogue::fusion`` counterparts see EpilogueFunctor3x.
    Values come from enum_auto().
    """

    LinearCombination = enum_auto()
    LinearCombinationClamp = enum_auto()


# C++ functor type per EpilogueFunctor.
EpilogueFunctorTag = {
    EpilogueFunctor.LinearCombination: "cutlass::epilogue::thread::LinearCombination",
    EpilogueFunctor.LinearCombinationClamp: "cutlass::epilogue::thread::LinearCombinationClamp",
}


#
class MixedInputMode(enum.Enum):
    """Handling modes for mixed-input GEMMs.

    Judging by the member names: convert only, apply a scale, or apply a scale
    and a zero point. NOTE(review): semantics inferred from names -- confirm
    against the emitter. Values come from enum_auto().
    """

    ConvertOnly = enum_auto()
    ScaleOnly = enum_auto()
    ScaleWithZeroPoint = enum_auto()


#
class SwizzlingFunctor(enum.Enum):
    """Threadblock swizzle functors.

    SwizzlingFunctorTag maps each member to its C++ type: the
    ``cutlass::gemm::threadblock`` swizzles for GEMM and stream-K, and the
    ``cutlass::conv::threadblock`` ones for strided dgrad. The number in the
    IdentityN / StridedDgradIdentityN names is the template argument.
    Values come from enum_auto().
    """

    Identity1 = enum_auto()
    Identity2 = enum_auto()
    Identity4 = enum_auto()
    Identity8 = enum_auto()
    Horizontal = enum_auto()
    StridedDgradIdentity1 = enum_auto()
    StridedDgradIdentity4 = enum_auto()
    StridedDgradHorizontal = enum_auto()
    StreamK = enum_auto()


# C++ swizzle type per SwizzlingFunctor.
SwizzlingFunctorTag = {
    SwizzlingFunctor.Identity1: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>",
    SwizzlingFunctor.Identity2: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>",
    SwizzlingFunctor.Identity4: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>",
    SwizzlingFunctor.Identity8: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>",
    SwizzlingFunctor.Horizontal: "cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle",
    SwizzlingFunctor.StridedDgradIdentity1: "cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>",
    SwizzlingFunctor.StridedDgradIdentity4: "cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>",
    SwizzlingFunctor.StridedDgradHorizontal: "cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle",
    SwizzlingFunctor.StreamK: "cutlass::gemm::threadblock::ThreadblockSwizzleStreamK",
}


#
class GroupScheduleMode(enum.Enum):
    """Scheduling modes for grouped GEMM kernels.

    GroupScheduleModeTag maps each member to its C++
    ``cutlass::gemm::kernel::GroupScheduleMode`` enumerator (kDeviceOnly /
    kHostPrecompute). ShortGroupScheduleModeNames gives a short display name.
    Values come from enum_auto(), like every other enum in this module.
    """

    Device = enum_auto()
    Host = enum_auto()


# C++ enumerator per GroupScheduleMode.
GroupScheduleModeTag = {
    GroupScheduleMode.Device: "cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly",
    GroupScheduleMode.Host: "cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute",
}

# Short display name per GroupScheduleMode.
ShortGroupScheduleModeNames = {
    GroupScheduleMode.Device: "Device",
    GroupScheduleMode.Host: "Host",
}

###################################################################################################


#
class ConvKind(enum.IntEnum):
    """Convolution operator (fprop / dgrad / wgrad), mirroring ``cutlass::conv::Operator``."""

    Fprop = 0
    Dgrad = 1
    Wgrad = 2


# C++ operator tag per ConvKind, e.g. ConvKind.Fprop -> "cutlass::conv::Operator::kFprop".
ConvKindTag = {kind: "cutlass::conv::Operator::k" + kind.name for kind in ConvKind}

# Lowercase name per ConvKind, e.g. ConvKind.Dgrad -> "dgrad".
ConvKindNames = {kind: kind.name.lower() for kind in ConvKind}


class ConvMode(enum.IntEnum):
    """Convolution mode: CrossCorrelation (0) or Convolution (1)."""

    CrossCorrelation = 0
    Convolution = 1


#
class IteratorAlgorithm(enum.Enum):
    """Convolution iterator algorithm, mirroring ``cutlass::conv::IteratorAlgorithm``."""

    Analytic = 0
    Optimized = 1
    FixedChannels = 2
    FewChannels = 3
    FixedStrideDilation = 4


# C++ enumerator per algorithm: "cutlass::conv::IteratorAlgorithm::k" + member name.
IteratorAlgorithmTag = {
    algorithm: "cutlass::conv::IteratorAlgorithm::k" + algorithm.name
    for algorithm in IteratorAlgorithm
}

# snake_case name fragment per algorithm.
IteratorAlgorithmNames = dict(
    [
        (IteratorAlgorithm.Analytic, "analytic"),
        (IteratorAlgorithm.Optimized, "optimized"),
        (IteratorAlgorithm.FixedChannels, "fixed_channels"),
        (IteratorAlgorithm.FewChannels, "few_channels"),
        (IteratorAlgorithm.FixedStrideDilation, "fixed_stride_dilation"),
    ]
)


#
class StrideSupport(enum.Enum):
    """Stride support of a convolution kernel, mirroring ``cutlass::conv::StrideSupport``."""

    Strided = 0
    Unity = 1
    Fixed = 2


# C++ enumerator per StrideSupport: "cutlass::conv::StrideSupport::k" + member name.
StrideSupportTag = {
    support: "cutlass::conv::StrideSupport::k" + support.name
    for support in StrideSupport
}

# Name fragment per StrideSupport; the general strided case adds none.
StrideSupportNames = dict(
    [
        (StrideSupport.Strided, ""),
        (StrideSupport.Unity, "unity_stride"),
        (StrideSupport.Fixed, "fixed_stride"),
    ]
)


#
class GroupMode(enum.Enum):
    """Convolution grouping mode, mirroring ``cutlass::conv::GroupMode``.

    GroupModeTag gives the C++ enumerator for each member, and GroupModeNames
    gives its name fragment (empty for dense convolution). Values come from
    enum_auto().
    """

    NoneGroup = enum_auto()  # dense conv (G=1)
    SingleGroup = enum_auto()  # grouped convolution (single group per CTA)
    MultipleGroup = enum_auto()  # grouped convolution (multiple groups per CTA)
    Depthwise = enum_auto()  # Depthwise convolution (C=K=G)


# C++ enumerator per GroupMode.
GroupModeTag = {
    GroupMode.NoneGroup: "cutlass::conv::GroupMode::kNone",
    GroupMode.SingleGroup: "cutlass::conv::GroupMode::kSingleGroup",
    GroupMode.MultipleGroup: "cutlass::conv::GroupMode::kMultipleGroup",
    GroupMode.Depthwise: "cutlass::conv::GroupMode::kDepthwise",
}

# Name fragment per GroupMode; dense convolution contributes none.
GroupModeNames = {
    GroupMode.NoneGroup: "",
    GroupMode.SingleGroup: "single_group",
    GroupMode.MultipleGroup: "multiple_group",
    GroupMode.Depthwise: "depthwise",
}

# NOTE(review): presumably a sentinel cluster shape for kernels whose cluster
# shape is chosen at runtime (0 in the first two modes); confirm against its users.
DynamicClusterShape = [0, 0, 1]

###################################################################################################


#
class MathInstruction:
    """Describes the math instruction a kernel is built on.

    Records the instruction shape, the A/B operand and accumulator element
    types, the opcode class, the math operation (multiply_add by default), and
    an optional scale-factor element type (None unless supplied).
    """

    def __init__(
        self,
        instruction_shape,
        element_a,
        element_b,
        element_accumulator,
        opcode_class,
        math_operation=MathOperation.multiply_add,
        element_scale_factor=None,
    ):
        # Operand and accumulator element types.
        self.element_a, self.element_b = element_a, element_b
        self.element_accumulator = element_accumulator
        # Shape of the instruction and how it is executed.
        self.instruction_shape = instruction_shape
        self.opcode_class = opcode_class
        self.math_operation = math_operation
        # Optional scale-factor element type (presumably only used by
        # block-scaled kernels -- confirm against callers).
        self.element_scale_factor = element_scale_factor


#
class TileDescription:
    """Tile configuration of a kernel.

    Holds the CTA (threadblock) tile shape, pipeline stage count, warp count,
    the math instruction used, the supported compute-capability range, the
    cluster shape (default ``(1, 1, 1)``) and optional explicit vector sizes.
    ``tile_shape`` is an alias of ``threadblock_shape`` (the same object).
    """

    def __init__(
        self,
        threadblock_shape,
        stages,
        warp_count,
        math_instruction,
        min_compute,
        max_compute,
        cluster_shape=(1, 1, 1),
        explicit_vector_sizes=None,
    ):
        self.threadblock_shape = self.tile_shape = threadblock_shape
        self.stages, self.warp_count = stages, warp_count
        self.math_instruction = math_instruction
        self.minimum_compute_capability, self.maximum_compute_capability = (
            min_compute,
            max_compute,
        )
        self.cluster_shape = cluster_shape
        self.explicit_vector_sizes = explicit_vector_sizes

    def procedural_name(self):
        """Return the tile portion of a generated kernel name.

        Minimum compute capability >= 90: ``{M}x{N}x{K}_{cM}x{cN}x{cK}_{stages}``.
        Otherwise: ``{M}x{N}_{K}x{stages}``.
        """
        shape = self.threadblock_shape
        if self.minimum_compute_capability >= 90:
            cluster = self.cluster_shape
            return "{}x{}x{}_{}x{}x{}_{}".format(
                shape[0],
                shape[1],
                shape[2],
                cluster[0],
                cluster[1],
                cluster[2],
                self.stages,
            )
        # Legacy naming has no cluster component and folds K in with stages.
        legacy_fields = (shape[0], shape[1], shape[2], self.stages)
        return "%dx%d_%dx%d" % legacy_fields


class Direct2dConvFixedStrideDilationTileDescription:
    """Tile description for direct 2-D convolution kernels.

    ``threadblock_shape`` is derived from the inputs as
    ``[out[0] * out[1] * out[2], out[3], filter[0] * filter[1]]``.
    A stride or dilation of ``[-1, -1]`` is treated by procedural_name() as
    "not fixed": the stride/dilation fragment is only added when both differ
    from ``[-1, -1]``.
    """

    def __init__(
        self,
        threadblock_output_shape,
        filter_shape,
        stages,
        stride,
        dilation,
        warp_count,
        math_instruction,
        min_compute,
        max_compute,
    ):
        out = threadblock_output_shape
        flattened_output = out[0] * out[1] * out[2]
        filter_taps = filter_shape[0] * filter_shape[1]
        self.threadblock_shape = [flattened_output, out[3], filter_taps]
        self.threadblock_output_shape = out
        self.filter_shape = filter_shape
        self.stages, self.warp_count = stages, warp_count
        self.stride, self.dilation = stride, dilation
        self.math_instruction = math_instruction
        self.minimum_compute_capability = min_compute
        self.maximum_compute_capability = max_compute

    def procedural_name(self):
        """Return the tile portion of a generated kernel name."""
        tb = self.threadblock_shape
        out = self.threadblock_output_shape
        fs = self.filter_shape
        name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (
            tb[0], tb[1], tb[2],
            out[0], out[1], out[2], out[3],
            self.stages,
            fs[0], fs[1],
        )
        # Only kernels with both a fixed stride and a fixed dilation encode them.
        both_fixed = self.stride != [-1, -1] and self.dilation != [-1, -1]
        if not both_fixed:
            return name
        return name + "_stride%dx%d_dilation%dx%d" % (
            self.stride[0], self.stride[1], self.dilation[0], self.dilation[1],
        )


#
class TensorDescription:
    """Describes one tensor operand: element type, layout, alignment (default 1)
    and complex transform (default ``ComplexTransform.none``)."""

    def __init__(
        self, element, layout, alignment=1, complex_transform=ComplexTransform.none
    ):
        self.element, self.layout = element, layout
        self.alignment, self.complex_transform = alignment, complex_transform


#
class SymmetricTensorDescription:
    """Describes a symmetric tensor operand.

    Beyond element type and layout, records the fill mode, alignment
    (default 1), complex transform (default ``ComplexTransform.none``) and
    side mode (default ``SideMode.Left``).
    """

    def __init__(
        self,
        element,
        layout,
        fill_mode,
        alignment=1,
        complex_transform=ComplexTransform.none,
        side_mode=SideMode.Left,
    ):
        self.element, self.layout = element, layout
        self.fill_mode, self.side_mode = fill_mode, side_mode
        self.alignment = alignment
        self.complex_transform = complex_transform


#
class TriangularTensorDescription:
    """Describes a triangular tensor operand.

    Beyond element type and layout, records the side mode, fill mode, diagonal
    type, alignment (default 1) and complex transform (default
    ``ComplexTransform.none``).
    """

    def __init__(
        self,
        element,
        layout,
        side_mode,
        fill_mode,
        diag_type,
        alignment=1,
        complex_transform=ComplexTransform.none,
    ):
        self.element, self.layout = element, layout
        # Triangular structure: which side, which triangle, and the diagonal kind.
        self.side_mode, self.fill_mode, self.diag_type = side_mode, fill_mode, diag_type
        self.alignment = alignment
        self.complex_transform = complex_transform


#
def CalculateSmemUsage(operation):
    """Estimate the shared memory used by ``operation``'s staged A/B tiles, in KB (floor).

    The estimate is ``stages`` times the bytes per stage. Each stage holds one
    CTA tile of A and of B, plus a sparsity-metadata tile for sparse GEMMs.
    Element widths from DataTypeSize are divided by 8, so they are presumably
    in bits.
    """
    tile = operation.tile_description
    cta_shape = tile.threadblock_shape
    stages = tile.stages
    cta_m, cta_n, cta_k = cta_shape[0], cta_shape[1], cta_shape[2]

    # gemm_kind is only consulted for GEMM operations (short-circuit).
    is_sparse_gemm = (
        operation.operation_kind == OperationKind.Gemm
        and operation.gemm_kind == GemmKind.Sparse
    )

    if is_sparse_gemm:
        bits_a = DataTypeSize[operation.A.element]
        # Elements covered by one byte of metadata (4:8, 2:4 or 1:2 sparsity).
        if bits_a == 32:
            elements_per_md_byte = 2
        elif bits_a == 4:
            elements_per_md_byte = 8
        else:
            elements_per_md_byte = 4
        # A and its metadata span only half of the K extent.
        half_k = cta_k // 2
        bytes_per_stage = (
            bits_a * cta_m * half_k // 8
            + DataTypeSize[operation.B.element] * cta_n * cta_k // 8
            + cta_m * half_k // elements_per_md_byte
        )
    else:
        # Some BLAS3 operations only have an A tensor, so B reuses A's width
        # unless this is a mixed-input operation.
        bits_a = DataTypeSize[operation.A.element]
        if operation.is_mixed_input():
            bits_b = DataTypeSize[operation.B.element]
        else:
            bits_b = bits_a
        bytes_per_stage = bits_a * cta_m * cta_k // 8 + bits_b * cta_n * cta_k // 8

    return (bytes_per_stage * stages) >> 10


class GemmUniversalMode(enum.IntEnum):
    """
    Types corresponding to GemmUniversalMode.

    Modes: plain GEMM, parallel split-K, batched, and array. Values are
    explicit integers. NOTE(review): they presumably must match the C++
    GemmUniversalMode enumeration; confirm before renumbering.
    """

    Gemm = 0
    GemmSplitKParallel = 1
    Batched = 2
    Array = 3


class SplitKMode(enum.IntEnum):
    """
    Types corresponding to SplitKMode.

    No split-K, serial split-K, or parallel split-K. Values are explicit
    integers. NOTE(review): presumably kept in sync with a C++/runtime
    counterpart; confirm before renumbering.
    """

    NoneSplitK = 0
    Serial = 1
    Parallel = 2
