import argparse
import subprocess
import os
import time
import tarfile

def run_command(command):
    """Run a shell command and return its standard output as text.

    Args:
        command: Command line passed to the shell. It runs with
            ``shell=True`` so pipes, globs and redirections work; only pass
            trusted, program-defined strings.

    Returns:
        The command's stdout decoded as UTF-8, or ``None`` when the command
        exits with a non-zero status (the command and its stderr are printed
        in that case).
    """
    result = subprocess.run(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        shell=True,
    )
    # Diagnostic tools may emit bytes that are not valid UTF-8; replace them
    # rather than raising UnicodeDecodeError and losing the whole output.
    if result.returncode != 0:
        print(f"Error running command: {command}")
        print(result.stderr.decode('utf-8', errors='replace'))
        return None
    return result.stdout.decode('utf-8', errors='replace')

def collect_diagnostic_data(output_dir):
    """Collect diagnostic data and save it into an archive file.

    Runs a fixed series of diagnostic shell commands through ``run_command``,
    writes their output to ``diagnostic_data.txt`` inside *output_dir*, packs
    that text file into ``diagnostic_data.tar.gz`` in the same directory and
    prints the archive path.

    Args:
        output_dir: Directory that receives the report and the archive. It is
            created if it does not exist yet.

    A command that fails is written to the report as ``None``, because that
    is what ``run_command`` returns on a non-zero exit status.
    """
    # Fail early (before the slow sampling commands) if the target directory
    # cannot be created.
    os.makedirs(output_dir, exist_ok=True)

    # (heading, command) pairs in execution order. The heading already
    # carries its separator: ": " for one-line values, ":\n" for multi-line
    # output. A command of None marks the timestamp entry.
    sections = [
        ("Kernel Version: ", 'uname -r'),
        ("Kernel Parameters: ", 'cat /proc/cmdline'),
        ("Distribution Version: ", 'cat /etc/*-release'),
        ("Current Time: ", None),
        ("Container Count: ", 'vzlist -H| wc -l'),
        ("Memory Info:\n", 'free -mw'),
        ("CPU Info:\n", 'lscpu; mpstat -P ALL'),
        ("Slabinfo:\n", 'cat /proc/slabinfo'),
        ("Meminfo:\n", 'cat /proc/meminfo'),
        ("cgroups:\n", 'cat /proc/cgroups'),
        ("Numanodes:\n", 'numactl --hardware; numastat -mv'),
        ("Vmstat:\n", 'vmstat 5 3'),
        ("IOstat:\n", 'iostat -tx 5 3 /dev/[^p]*'),
        ("Top Output:\n", 'top -n 1 -b'),
        # NOTE(review): systemd-cgtop was disabled (commented out) in the
        # original script and is still not collected — confirm whether it
        # should be re-enabled.
        ("Lscgroup:\n", 'lscgroup -g memory:/machine.slice'),
        ("Current LA: ", 'uptime | grep -o "load average: .*"'),
        ("Running Processes:\n",
         "vzps -eLo ppid,pid,tid,state,wchan:32,veid:32,cmd|awk '$4~/D/'"),
        # Raw string: "\|" is grep alternation, not a Python escape sequence.
        ("R-D States Count: ", r'ps -eo state | grep -c "^R\|^D"'),
        # The sysrq "w" trigger must run before dmesg so that the output it
        # produces in the kernel log is included in the dmesg capture.
        ("Sysrq Command: ", 'echo w > /proc/sysrq-trigger'),
        ("Dmesg Output:\n", 'dmesg'),
    ]

    # Gather everything first so the report file is written in one go.
    report_lines = []
    for heading, command in sections:
        if command is None:
            value = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
        else:
            value = run_command(command)
        report_lines.append(f"{heading}{value}\n")

    # Save data into a text file. run_command decodes as UTF-8, so write the
    # report with the same encoding instead of the locale default.
    output_file_path = os.path.join(output_dir, 'diagnostic_data.txt')
    with open(output_file_path, 'w', encoding='utf-8') as output_file:
        output_file.writelines(report_lines)

    # Create a tar archive containing just the report, stored under its
    # base name so the archive does not embed the output directory path.
    archive_file_path = os.path.join(output_dir, 'diagnostic_data.tar.gz')
    with tarfile.open(archive_file_path, 'w:gz') as archive:
        archive.add(output_file_path, arcname=os.path.basename(output_file_path))

    print(f"Diagnostic data saved to: {archive_file_path}")

def check_load_average(now):
    """Collect diagnostics immediately or once the load average is too high.

    Args:
        now: When truthy, collect diagnostic data right away and return.
            Otherwise poll the 1-minute load average once a minute and
            collect diagnostic data the first time it exceeds 1000, then
            return.

    Diagnostic data is always written to ``/vz/tmp``.
    """
    output_directory = '/vz/tmp'

    if now:
        collect_diagnostic_data(output_directory)
        return

    while True:
        try:
            # Read the 1-minute load average directly instead of parsing
            # `uptime` output, which is locale dependent and yields None
            # (crashing float()) when the pipeline fails. Rounded to two
            # decimals, the precision `uptime` reports.
            current_la = round(os.getloadavg()[0], 2)
        except OSError as err:
            # Keep monitoring: a transient failure to read the load average
            # should not terminate the watcher.
            print(f"Unable to read system load average: {err}")
            time.sleep(60)
            continue

        if current_la > 1000:
            print("System load average exceeded 1000. Starting diagnostic data collection...")
            collect_diagnostic_data(output_directory)
            break

        print(f"System load average is {current_la}. Waiting...")
        time.sleep(60)  # Wait for 1 minute before checking again

if __name__ == "__main__":
    # Script entry point: parse the command line, then either collect
    # diagnostics right away (--now) or start watching the load average.
    cli = argparse.ArgumentParser(
        description='Collect diagnostic data if load average exceeds 1000.',
    )
    cli.add_argument(
        '--now',
        action='store_true',
        help='Collect diagnostics immediately',
    )
    options = cli.parse_args()
    check_load_average(options.now)
