#!/usr/bin/python3 -su

# Copyright (C) 2025 - 2025 ENCRYPTED SUPPORT LLC <adrelanos@whonix.org>
# See the file COPYING for copying conditions.

# pylint: disable=broad-exception-caught,invalid-name

"""
tor-wait-for-network - Waits to see if an IPv6 address Tor will want
to listen on becomes available and bindable. The script exits 0 as soon as
IPv6 is found to be disabled, all IPv6 addresses on the monitored network
interface are bindable, or about ten seconds have passed, whichever is
detected first. It also exits 0 if /proc/net/if_inet6 cannot be opened.

This is necessary, because Whonix-Gateway's internal network interface's IPv6
address doesn't show up immediately. If Tor starts before the IPv6 address
appears, and it attempts to listen on that address, it will crash on startup.
Tor is reconfigured by generate-tor-service-defaults-torrc-anondist to not
listen on the IPv6 address if it doesn't exist, but that would result in Tor
never listening on the IPv6 address even if it became available shortly later.
This script allows Tor to wait for the minimum amount of time necessary to be
able to bind to the IPv6 address, while not waiting forever for an IPv6 address
that will never appear.
"""

#### meta start
#### project Whonix
#### category tor
#### gateway_only yes
#### description
## Wait for network to become available so Tor can bind on it.
#### meta end

import os
import socket
import sys
import time
import traceback
import ipaddress
from typing import NoReturn, TextIO
from pathlib import Path


# pylint: disable=too-few-public-methods
class GlobalData:
    """
    Module-wide constants and mutable state shared by the functions below.
    """

    ## Kernel listing of configured IPv6 addresses, re-read on every scan.
    ipv6_addr_file_str: str = "/proc/net/if_inet6"

    ## main() treats the content "1" of this file as "IPv6 is disabled".
    ipv6_disabled_file_path: Path = Path("/proc/sys/net/ipv6/conf/all/disable_ipv6")

    ## Addresses already reported as bindable, so each one is logged only once.
    bindable_ipv6_addrs: set[str] = set()


def suppress_wait_and_exit() -> NoReturn:
    """
    Stop waiting and terminate the script with exit status 0.
    """
    raise SystemExit(0)


def can_bind_ipv6(addr_text: str, dev: str, scope_hex: str) -> bool:
    """
    Try to bind a TCP/IPv6 socket to the given address on an ephemeral port.

    Args:
        addr_text: IPv6 address in textual form, e.g. "fd19:c33d:88bc::10".
        dev: Name of the interface the address belongs to. Only used to look
            up a scope ID for scoped (e.g. link-local) addresses.
        scope_hex: Scope field from /proc/net/if_inet6 as a hex string.
            Unparsable values are treated as 0.

    Returns:
        True if the bind succeeds. False if the socket cannot be created or
        the bind fails. KeyboardInterrupt/SystemExit are not swallowed.
    """
    try:
        scope = int(scope_hex, 16)
    except ValueError:
        scope = 0

    ## Scoped addresses (link-local fe80::...) can only be bound together
    ## with the index of the interface they live on.
    scope_id = 0
    if scope != 0 or addr_text.startswith("fe80:"):
        try:
            scope_id = socket.if_nametoindex(dev)
        except (OSError, ValueError):
            scope_id = 0

    ## Socket creation is inside the try as well: on a host without IPv6
    ## socket support it raises OSError, which means "not bindable".
    try:
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind((addr_text, 0, 0, scope_id))
    except OSError:
        return False
    return True


def scan_ipv6_addrs(ipv6_addr_file: TextIO, interface: str = "eth1") -> None:
    """
    Check the IPv6 addresses listed in /proc/net/if_inet6 for one interface.

    Every IPv6 address configured on ``interface`` is tested with bind().
    An address that becomes bindable is logged once; it is tracked in
    GlobalData.bindable_ipv6_addrs. An address that is present but not yet
    bindable is logged on every scan.

    Exits the program with status 0 when at least one address is configured
    on ``interface`` and all of them are bindable. Otherwise the function
    returns, and the caller should try again later.

    Args:
        ipv6_addr_file: Open text handle of /proc/net/if_inet6. It is rewound
            and re-read on every call, so the same handle can be reused.
        interface: Network interface whose addresses are checked. Defaults
            to "eth1". Loopback ("lo") is always ignored.
    """

    ipv6_addr_file.seek(0)
    file_contents_list: list[str] = ipv6_addr_file.read().split("\n")

    ## Example file contents. Fields: address, ifindex, prefix length,
    ## scope, flags, device name.
    ##   00000000000000000000000000000001 01 80 10 80 lo
    ##   fd17625cf03700020a0027fffe987619 02 40 00 00 eth0
    ##   fe800000000000000a0027fffe987619 02 40 20 80 eth0
    ##   fe800000000000000a0027fffe2dded2 03 40 20 80 eth1
    ##   fd19c33d88bc00000000000000000010 03 60 00 80 eth1

    any_candidate: bool = False
    all_bindable: bool = True
    newly_bindable: list[str] = []

    for line in file_contents_list:
        parts = line.split()
        ## Skip blank or short lines. Only the first six fields are used, so
        ## unexpected trailing fields cannot break the unpacking.
        if len(parts) < 6:
            continue

        addr_hex, _ifindex, _plen, scope_hex, _flags, dev = parts[:6]

        ## Skip loopback; ::1 is usually bindable early and would defeat the
        ## purpose of waiting for the real interface.
        if dev == "lo":
            continue

        ## TODO: Transform this into a standalone script. Move to helper-scripts.
        ##       Supply the network interface name on the command line
        ##       (main() currently relies on the default "eth1").
        if dev != interface:
            continue

        try:
            addr_text = str(ipaddress.IPv6Address(int(addr_hex, 16)))
        except ValueError:
            continue

        any_candidate = True

        if can_bind_ipv6(addr_text, dev, scope_hex):
            if addr_text not in GlobalData.bindable_ipv6_addrs:
                GlobalData.bindable_ipv6_addrs.add(addr_text)
                newly_bindable.append(addr_text)
        else:
            print(f"tor-wait-for-network: INFO: Non-loopback IPv6 address '{addr_text}' dev '{dev}' present but not bindable yet.", file=sys.stderr)
            all_bindable = False

    for addr_text in newly_bindable:
        print(
            f"tor-wait-for-network: INFO: Non-loopback IPv6 address '{addr_text}' dev '{interface}' is bindable.",
            file=sys.stderr,
        )

    ## Report the monitored interface, not a leftover loop variable, which
    ## could name whatever device happened to be listed last.
    if any_candidate and all_bindable and GlobalData.bindable_ipv6_addrs:
        print(
            f"tor-wait-for-network: INFO: All non-loopback IPv6 addresses on dev '{interface}' are bindable. Exiting.",
            file=sys.stderr,
        )
        suppress_wait_and_exit()


def main() -> NoReturn:
    """
    Main function.

    Exits 1 when not run as root. Otherwise exits 0 once any of these is
    detected:
    - IPv6 is disabled.
    - The IPv6 address list cannot be opened.
    - scan_ipv6_addrs() finds all monitored addresses bindable.
    - The polling timeout expires.
    """

    if os.getuid() != 0:
        print("tor-wait-for-network: ERROR: Must be run as root!", file=sys.stderr)
        sys.exit(1)

    ## One initial scan plus up to 50 re-scans 0.2 seconds apart, which is
    ## roughly ten seconds of waiting in total.
    check_interval: float = 0.2
    check_max_count: int = 50

    ## Only the read is guarded; an unreadable file is not fatal.
    try:
        ipv6_disabled: bool = (
            GlobalData.ipv6_disabled_file_path.read_text(
                encoding="utf-8"
            ).strip()
            == "1"
        )
    except Exception:
        print(
            "tor-wait-for-network: WARNING: Could not read "
            f"'{GlobalData.ipv6_disabled_file_path}'! Continuing regardless.",
            file=sys.stderr,
        )
        ipv6_disabled = False

    if ipv6_disabled:
        print(
            "tor-wait-for-network: INFO: IPv6 is disabled "
            f"('{GlobalData.ipv6_disabled_file_path}' contained '1'). "
            "Exiting.",
            file=sys.stderr,
        )
        sys.exit(0)

    try:
        # pylint: disable=consider-using-with
        ipv6_addr_file: TextIO = open(
            GlobalData.ipv6_addr_file_str, "r", encoding="utf-8"
        )
    except Exception:
        print(
            f"tor-wait-for-network: INFO: Cannot open GlobalData.ipv6_addr_file_str '{GlobalData.ipv6_addr_file_str}'. Exiting.",
            file=sys.stderr,
        )
        suppress_wait_and_exit()

    ## Example file contents:
    ##   00000000000000000000000000000001 01 80 10 80       lo
    ##   fd17625cf03700020a0027fffe987619 02 40 00 00     eth0
    ##   fe800000000000000a0027fffe987619 02 40 20 80     eth0
    ##   fe800000000000000a0027fffe2dded2 03 40 20 80     eth1
    ##   fd19c33d88bc00000000000000000010 03 60 00 80     eth1

    ## `with` closes the file on every exit path, including the SystemExit
    ## raised via suppress_wait_and_exit() when scanning succeeds.
    with ipv6_addr_file:
        scan_ipv6_addrs(ipv6_addr_file)

        ## Unfortunately, the if_inet6 file is not pollable the way
        ## /proc/self/mounts is (listening for POLLIN results in a busy-loop,
        ## listening for POLLPRI results in a hang because polling if_inet6
        ## never results in a POLLPRI event), so we have to poll manually.
        for _ in range(check_max_count):
            time.sleep(check_interval)
            scan_ipv6_addrs(ipv6_addr_file)

    print(
        "tor-wait-for-network: INFO: Non-loopback IPv6 address was not found or not bindable within timeout. Exiting.",
        file=sys.stderr,
    )
    suppress_wait_and_exit()


if __name__ == "__main__":
    ## Run only when executed as a script, not when imported as a module.
    ## main() always terminates the process itself.
    main()
