#!/usr/bin/env python

# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import nsysstats


class NetworkDevicesCongestion(nsysstats.StatsReport):
    """Report intervals where NIC or IB switch send-wait rates exceed a threshold.

    Rows come from NET_NIC_METRIC and/or NET_IB_SWITCH_METRIC (whichever
    exist), restricted to the 'IB: Send waits' metric. The rows are combined
    with UNION ALL and ordered by start time.
    """

    DEFAULT_TICKS_THRESHOLD = 10000

    display_name = "Network Devices Congestion"
    # Column names below mirror the headers produced by ``query_stub``.
    usage = f"""{{SCRIPT}}[:ticks_threshold=<ticks_per_ms>] -- {{DISPLAY_NAME}}

    ticks_threshold=<ticks_per_ms> - Threshold in ticks/ms above which we report
        congestion. Default is {DEFAULT_TICKS_THRESHOLD}.

    Output: All time values default to nanoseconds
        Start : Start timestamp of congestion interval
        End : End timestamp of congestion interval
        Duration : Duration of congestion interval
        Send wait rate (ticks/ms) : Rate of congestion during the interval
        GUID : The device GUID, in hexadecimal
        Device name : The device name

    This report displays congestion events with a high send wait rate. By
    default, only events with a send wait rate above {DEFAULT_TICKS_THRESHOLD}
    ticks/ms are shown, but a custom threshold value can be set.

    Each event defines a period of time when the device experienced some level
    of congestion. The level of congestion is defined by the send wait rate,
    given in time ticks per millisecond (ticks/ms). The specific duration of a
    tick is device specific, but can be assumed to be nanoseconds in scale.
    Congestion is measured by counting the number of ticks during which the port
    had data to transmit, but no data was sent because of insufficient credits
    or because of lack of arbitration. The presented value of send wait rate is
    the amount of ticks counted during an event, normalized over the event's
    duration. Higher send wait rate values indicate more congestion.

    Because the specific duration of a tick is device dependent, analysis
    should focus on the relative send wait rates of events generated by the same
    device. Comparing absolute send wait rates across devices is only meaningful
    if the time tick duration is known to be similar.

    For IB Switch metrics, we do not present the device name, only the GUID.
"""

    # Outer query: wraps the UNION ALL of the per-device sub-queries and
    # renders the GUID as hex.
    query_stub = """
        WITH
            recs AS (
                {NETWORKING_SUBQUERY}
            )
            SELECT
                start AS "Start:ts_ns",
                end AS "End:ts_ns",
                duration AS "Duration:dur_ns",
                value AS "Send wait rate (ticks/ms)",
                printf('%x', guid) AS "GUID",
                label AS "Device name"
            FROM
                    recs
            ORDER BY start;
"""

    # NIC rows: resolve the device GUID and label through the NIC tables.
    query_nics_congestion = """
        SELECT
            nmetric.start AS start,
            nmetric.end AS end,
            nmetric.end - nmetric.start AS duration,
            nmetric.value AS value,
            nicinfo.GUID AS guid,
            deviceid.label AS label
        FROM
            NET_NIC_METRIC AS nmetric
        JOIN
            NIC_ID_MAP
            USING (globalId)
        JOIN
            TARGET_INFO_NIC_INFO AS nicinfo
            USING (nicId)
        JOIN
            TARGET_INFO_NETWORK_METRICS AS netmetricsinfo
            ON nmetric.metricsListId == netmetricsinfo.metricsListId
                AND nmetric.metricsIdx == netmetricsinfo.metricsIdx
        JOIN
            ENUM_NET_DEVICE_ID AS deviceid
            ON nicinfo.deviceId == deviceid.id
        WHERE netmetricsinfo.name == 'IB: Send waits'
            AND nmetric.value > {TICKS_THRESHOLD}
"""

    # IB switch rows: globalId is used as the GUID and no label is available.
    query_ib_switches_congestion = """
        SELECT
            smetric.start AS start,
            smetric.end AS end,
            smetric.end - smetric.start AS duration,
            smetric.value AS value,
            smetric.globalId AS guid,
            NULL AS label
        FROM
            NET_IB_SWITCH_METRIC as smetric
        JOIN
            TARGET_INFO_NETWORK_METRICS AS netmetricsinfo
            ON smetric.metricsListId == netmetricsinfo.metricsListId
                AND smetric.metricsIdx == netmetricsinfo.metricsIdx
        WHERE netmetricsinfo.name == 'IB: Send waits'
            AND smetric.value > {TICKS_THRESHOLD}
"""

    # Separator placed between sub-queries when more than one table is present.
    query_union = """
        UNION ALL
"""

    _arg_opts = [
        [
            ["ticks_threshold"],
            {
                "type": int,
                "help": "ticks threshold",
                "default": DEFAULT_TICKS_THRESHOLD,
            },
        ],
    ]

    def check_table_existence(self, tables):
        """Check that every table named in *tables* exists.

        :param tables: Mapping of table name to the error message to report
            when that table is missing. The entries are checked in iteration
            order.
        :return: ``(True, "")`` when all tables exist. Otherwise returns
            ``(False, message)`` for the first missing table.
        """
        for table_title, message in tables.items():
            if not self.table_exists(table_title):
                return False, message
        return True, ""

    def setup(self):
        """Build ``self.query`` from whichever network metric tables exist.

        :return: ``None`` on success, or an error message string when the
            base setup fails or a required table is missing. The messages
            contain a ``{DBFILE}`` placeholder, presumably filled in by the
            base class (confirm).
        """
        err = super().setup()
        if err is not None:
            return err

        sub_queries = []

        if self.table_exists("NET_NIC_METRIC"):
            # Every table joined by query_nics_congestion must be present.
            table_checks = {
                "ENUM_NET_DEVICE_ID": "{DBFILE} does not contain network device IDs.",
                "NIC_ID_MAP": "{DBFILE} does not contain NIC ID map.",
                "TARGET_INFO_NETWORK_METRICS": "{DBFILE} file does not contain network metrics information table.",
                "TARGET_INFO_NIC_INFO": "{DBFILE} does not contain NIC info data.",
            }
            tables_exist, message = self.check_table_existence(table_checks)
            if not tables_exist:
                return message

            sub_queries.append(
                self.query_nics_congestion.format(
                    TICKS_THRESHOLD=self.parsed_args.ticks_threshold
                )
            )

        if self.table_exists("NET_IB_SWITCH_METRIC"):
            # The switch query only joins the metric-name table.
            table_checks = {
                "TARGET_INFO_NETWORK_METRICS": "{DBFILE} file does not contain network metrics information table."
            }
            tables_exist, message = self.check_table_existence(table_checks)
            if not tables_exist:
                return message

            sub_queries.append(
                self.query_ib_switches_congestion.format(
                    TICKS_THRESHOLD=self.parsed_args.ticks_threshold
                )
            )

        if not sub_queries:
            return "{DBFILE} does not contain NIC or IB switch metrics."

        self.query = self.query_stub.format(
            NETWORKING_SUBQUERY=self.query_union.join(sub_queries)
        )


if __name__ == "__main__":
    # Allow running this report directly as a script.
    NetworkDevicesCongestion.Main()
