#!/bin/bash
set -u

# Print the help text for this report to stdout.
# Any arguments passed in are ignored; the program name shown on the first
# line is the basename of $0.
print_usage() {
    local script_name
    script_name=$(basename "$0")

    cat <<USAGE
${script_name} -- CUDA GPU Trace

    No arguments.

    Output:
        Start : Start time of trace event in seconds
        Duration : Length of event in nanoseconds
        CorrId : Correlation ID
        GrdX, GrdY, GrdZ : Grid values
        BlkX, BlkY, BlkZ : Block values
        Reg/Trd : Registers per thread
        StcSMem : Size of Static Shared Memory
        DymSMem : Size of Dynamic Shared Memory
        Bytes : Size of memory operation
        Thru : Throughput in MB per Second
        SrcMemKd : Memcpy source memory kind or memset memory kind
        DstMemKd : Memcpy destination memory kind
        Device : GPU device name and ID
        Ctx : Context ID
        Strm : Stream ID
        Name : Trace event name

    This report displays a trace of CUDA kernels and memory operations.
    Items are sorted by start time.
USAGE
}

### BEGIN include inc_setup ###

# Exit statuses used by this script:
#   EXIT_HELP   - no database argument was given, usage text was printed
#   EXIT_DB     - the argument is not a regular file or not an SQLite 3 file
#   EXIT_NODATA - the database holds none of the tables this report needs
EXIT_HELP=25
EXIT_DB=26
EXIT_NODATA=27

# The database path is the only (mandatory) positional argument
if (( $# < 1 )); then
    print_usage ${BASH_SOURCE[0]}
    exit ${EXIT_HELP}
fi

DATABASE="$1"

# Must be an existing regular file...
[ -f "${DATABASE}" ] || exit ${EXIT_DB}

# ...that begins with the 16-byte SQLite 3 magic string
# (see https://sqlite.org/fileformat.html).  The magic ends in a NUL byte;
# tr turns it into a newline, which the command substitution then strips.
DB_FILE_HEADER=$(head -c 16 "$DATABASE" | tr '\0' '\n')
[ "${DB_FILE_HEADER}" = "SQLite format 3" ] || exit ${EXIT_DB}

# Write a diagnostic to stderr.
# All arguments are forwarded to echo, so they end up on one line separated
# by single spaces.
echoerr() {
    echo "$@" 1>&2
}

# Setup standard vars

# If we were run by nsys, the path to the preferred sqlite3 should have been
# passed as an env-var.  If not, hope the user has it in their path.
SQLITE3="${NSYS_STATS_SCRIPTS_SQLITE:-sqlite3}"

# Kept as a plain string on purpose: it is expanded unquoted so that it
# word-splits into the individual command-line options.
SQLITE3OPTS="-header -csv -readonly"

# Run sqlite3 with the standard options against ${DATABASE}.
# Arguments: any extra sqlite3 arguments (e.g. an SQL statement); passed
#            through verbatim.
# Input:     stdin is inherited, so SQL can be supplied via a here-document.
# Returns:   the exit status of sqlite3.
function run_sqlite()
{
    "${SQLITE3}" ${SQLITE3OPTS} "${DATABASE}" "$@"
}

# Used as `${RUN_SQLITE} ...`: the expansion yields the function name above.
# This is deliberately not an eval string with the paths spliced in, because
# eval would re-parse the database path as shell code (quotes, '$' or
# backticks in the path) and collapse consecutive spaces in it.
RUN_SQLITE="run_sqlite"

### END include inc_setup ###

### BEGIN: include from inc_table_exists ###

# Names of the tables and views present in ${DATABASE}; filled on first use.
# Note that a call made inside $(...) runs in a subshell, so the list it
# loads is not visible to later calls.
TABLE_EXISTS_TABLES=( )

# table_exists NAME
#
# Prints "true" if NAME is a table or view listed in sqlite_master of
# ${DATABASE}, otherwise prints "false".  Intended to be used as
#     if $(table_exists "SOME_TABLE")
# where the printed word is what gets executed by the `if`.
#
# Globals:  SQLITE3, SQLITE3OPTS, DATABASE (read); TABLE_EXISTS_TABLES (written)
# Returns:  1 when the name was found, 0 when it was not.  This is the
#           inverse of the usual shell convention and is kept unchanged for
#           compatibility; rely on the printed word, not on the status.
function table_exists()
{
    local TABLE_NAME=$1
    local TABLE     # keep the loop variable out of the global scope

    if [ "${#TABLE_EXISTS_TABLES[@]}" -eq 0 ]
    then
        # Read one name per line so that names containing spaces or glob
        # characters are stored verbatim instead of being word-split and
        # pathname-expanded.  -noheader overrides the -header found in
        # SQLITE3OPTS; otherwise the column title "name" would be recorded
        # as if it were a table.
        while IFS= read -r TABLE || [ -n "${TABLE}" ]
        do
            TABLE="${TABLE%$'\r'}"      # tolerate CRLF row terminators
            if [ -n "${TABLE}" ]
            then
                TABLE_EXISTS_TABLES+=( "${TABLE}" )
            fi
        done < <("${SQLITE3}" ${SQLITE3OPTS} -noheader "${DATABASE}" \
                "SELECT name FROM sqlite_master WHERE type = 'table' OR type = 'view'")
    fi

    # The ${var[@]+...} form keeps `set -u` from aborting on an empty list
    # under older bash releases.
    for TABLE in ${TABLE_EXISTS_TABLES[@]+"${TABLE_EXISTS_TABLES[@]}"}
    do
        if [ "${TABLE}" = "${TABLE_NAME}" ]
        then
            echo "true"
            return 1
        fi
    done
    echo "false"
    return 0
}

### END: include from inc_table_exists ###

### BEGIN include inc_helper_cte ###

# SQL fragment: a common table expression mapping numeric memory-kind ids to
# display names.  It ends with a comma so that further CTEs can follow it
# directly inside a WITH clause.
MemKindStrsCTE="
    MemKindStrs (id, name) AS (
    VALUES
        (0,     'Pageable'),
        (1,     'Pinned'),
        (2,     'Device'),
        (3,     'Array'),
        (4,     'Unknown')
    ),
"

# SQL fragment: a common table expression mapping numeric memcpy copy-kind
# ids to the event names shown in the report.  It ends with a comma so that
# further CTEs can follow it directly inside a WITH clause.
MemcpyOperStrsCTE="
    MemcpyOperStrs (id, name) AS (
    VALUES
        (0,     '[CUDA memcpy Unknown]'),
        (1,     '[CUDA memcpy HtoD]'),
        (2,     '[CUDA memcpy DtoH]'),
        (3,     '[CUDA memcpy HtoA]'),
        (4,     '[CUDA memcpy AtoH]'),
        (5,     '[CUDA memcpy AtoA]'),
        (6,     '[CUDA memcpy AtoD]'),
        (7,     '[CUDA memcpy DtoA]'),
        (8,     '[CUDA memcpy DtoD]'),
        (9,     '[CUDA memcpy HtoH]'),
        (10,    '[CUDA memcpy PtoP]'),
        (11,    '[CUDA Unified Memory memcpy HtoD]'),
        (12,    '[CUDA Unified Memory memcpy DtoH]'),
        (13,    '[CUDA Unified Memory memcpy DtoD]')
    ),
"

### END include inc_helper_cte ###


# Pieces of the combined trace query: one SELECT per available activity
# table, with "UNION ALL" entries in between.
QUERY=()

# Kernel launches.  The memory-operation columns (bytes, memory kinds,
# memset value) do not apply and are filled with NULL.
if [ "$(table_exists "CUPTI_ACTIVITY_KIND_KERNEL")" = "true" ]; then
    if (( ${#QUERY[@]} > 0 )); then QUERY+=("UNION ALL"); fi

    # Column aliases are bare identifiers: a literal double quote cannot
    # appear inside this double-quoted shell string without escaping.
    Q="
        SELECT
            start AS start,
            (end - start) AS duration,
            gridX AS gridX,
            gridY AS gridY,
            gridZ AS gridZ,
            blockX AS blockX,
            blockY AS blockY,
            blockZ AS blockZ,
            registersPerThread AS regsperthread,
            staticSharedMemory AS ssmembytes,
            dynamicSharedMemory AS dsmembytes,
            NULL AS bytes,
            NULL AS srcmemkind,
            NULL AS dstmemkind,
            NULL AS memsetval,
            printf('%s (%d)', gpu.name, deviceId) AS device,
            contextId AS context,
            streamId AS stream,
            dmn.value AS name,
            correlationId AS correlation
        FROM
            CUPTI_ACTIVITY_KIND_KERNEL
        LEFT JOIN
            StringIds AS dmn
            ON CUPTI_ACTIVITY_KIND_KERNEL.demangledName = dmn.id
        LEFT JOIN
            TARGET_INFO_CUDA_GPU AS gpu
            USING( deviceId )
    "
    QUERY+=("$Q")
fi

# Memory copies.  The kernel-only columns (grid, block, registers, shared
# memory) are filled with NULL; the event name and the source/destination
# memory kinds are looked up in the MemcpyOperStrs and MemKindStrs CTEs.
if [ "$(table_exists "CUPTI_ACTIVITY_KIND_MEMCPY")" = "true" ]; then
    if (( ${#QUERY[@]} > 0 )); then QUERY+=("UNION ALL"); fi

    # Column aliases are bare identifiers: a literal double quote cannot
    # appear inside this double-quoted shell string without escaping.
    Q="
        SELECT
            start AS start,
            (end - start) AS duration,
            NULL AS gridX,
            NULL AS gridY,
            NULL AS gridZ,
            NULL AS blockX,
            NULL AS blockY,
            NULL AS blockZ,
            NULL AS regsperthread,
            NULL AS ssmembytes,
            NULL AS dsmembytes,
            bytes AS bytes,
            msrck.name AS srcmemkind,
            mdstk.name AS dstmemkind,
            NULL AS memsetval,
            printf('%s (%d)', gpu.name, deviceId) AS device,
            contextId AS context,
            streamId AS stream,
            memopstr.name AS name,
            correlationId AS correlation
        FROM
            CUPTI_ACTIVITY_KIND_MEMCPY AS memcpy
        LEFT JOIN
            MemcpyOperStrs AS memopstr
            ON memcpy.copyKind = memopstr.id
        LEFT JOIN
            MemKindStrs AS msrck
            ON memcpy.srcKind = msrck.id
        LEFT JOIN
            MemKindStrs AS mdstk
            ON memcpy.dstKind = mdstk.id
        LEFT JOIN
            TARGET_INFO_CUDA_GPU AS gpu
            USING( deviceId )
    "
    QUERY+=("$Q")
fi

# Memory sets.  The kernel-only columns are filled with NULL, the memory
# kind is reported in the source-kind column, and the event name is the
# fixed string '[CUDA memset]'.
if [ "$(table_exists "CUPTI_ACTIVITY_KIND_MEMSET")" = "true" ]; then
    if (( ${#QUERY[@]} > 0 )); then QUERY+=("UNION ALL"); fi

    # Column aliases are bare identifiers: a literal double quote cannot
    # appear inside this double-quoted shell string without escaping.
    Q="
        SELECT
            start AS start,
            (end - start) AS duration,
            NULL AS gridX,
            NULL AS gridY,
            NULL AS gridZ,
            NULL AS blockX,
            NULL AS blockY,
            NULL AS blockZ,
            NULL AS regsperthread,
            NULL AS ssmembytes,
            NULL AS dsmembytes,
            bytes AS bytes,
            mk.name AS srcmemkind,
            NULL AS dstmemkind,
            value AS memsetval,
            printf('%s (%d)', gpu.name, deviceId) AS device,
            contextId AS context,
            streamId AS stream,
            '[CUDA memset]' AS name,
            correlationId AS correlation
        FROM
            CUPTI_ACTIVITY_KIND_MEMSET AS memset
        LEFT JOIN
            MemKindStrs AS mk
            ON memset.memKind = mk.id
        LEFT JOIN
            TARGET_INFO_CUDA_GPU AS gpu
            USING( deviceId )
    "
    QUERY+=("$Q")
fi

# None of the activity tables was present, so there is nothing to report:
# say so on stderr and leave with the dedicated exit status.
if (( ${#QUERY[@]} == 0 )); then
    echoerr "$DATABASE does not contain trace data."
    exit ${EXIT_NODATA}
fi

# Run the report: ${RUN_SQLITE} expands to the sqlite3 invocation prepared
# during setup, and the SQL below is fed to it on stdin.
#
# The here-document delimiter is unquoted, so the shell substitutes the two
# helper CTEs and the assembled per-table SELECTs (${QUERY[@]}, joined by
# spaces) before sqlite3 reads the text.  The helper CTE strings each end
# with a comma, which is why "recs" can follow them directly.
#
# The outer SELECT renames the columns for display and:
#   - divides "start" by 1e9 for the "Start(sec)" column,
#   - computes "Thru(MB/s)" as bytes * 1000.0 / duration, or an empty string
#     for rows whose "bytes" is NULL (the kernel rows),
#   - does not emit the "memsetval" column that the sub-queries provide.
# Rows are ordered by the raw "start" value.
${RUN_SQLITE} << EOF

WITH
    ${MemKindStrsCTE}
    ${MemcpyOperStrsCTE}
    recs AS (
        ${QUERY[@]}
    )
    SELECT
        printf('%.6f', start / 1000000000.0 ) AS "Start(sec)",
        duration AS "Duration(nsec)",
        correlation AS "CorrId",
        gridX AS "GrdX",
        gridY AS "GrdY",
        gridZ AS "GrdZ",
        blockX AS "BlkX",
        blockY AS "BlkY",
        blockZ AS "BlkZ",
        regsperthread AS "Reg/Trd",
        ssmembytes AS "StcSMem",
        dsmembytes AS "DymSMem",
        bytes AS "Bytes",
        CASE
            WHEN bytes IS NULL
                THEN ''
            ELSE
                printf('%.3f', (bytes * 1000.0) / duration)
        END AS "Thru(MB/s)",
        srcmemkind AS "SrcMemKd",
        dstmemkind AS "DstMemKd",
        device AS "Device",
        context AS "Ctx",
        stream AS "Strm",
        name AS "Name"
    FROM
            recs
    ORDER BY start;

EOF
