#!/usr/bin/env python

# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
#
# See LICENSE for more details.

"""
Helper script that runs and interacts with the process executed by aexpect
"""

import os
import sys
import logging
import pty
import tempfile
import select

from aexpect.shared import BASE_DIR
from aexpect.shared import get_filenames
from aexpect.shared import get_reader_filename
from aexpect.shared import get_lock_fd
from aexpect.shared import unlock_fd
from aexpect.shared import wait_for_lock
from aexpect.shared import makestandard
from aexpect.shared import makeraw


def main():     # too-many-* pylint:disable=R0914,R0912,R0915
    """
    Run a command in a pty and relay data between it and the clients.

    Four lines are read from stdin, in this order: the session id, the echo
    flag (the literal string "True" enables echo), a comma-separated list of
    reader names and the command to execute.

    The command is executed by /bin/bash in a forked pty. The parent process
    then:

    * copies the child's output (with carriage returns removed) to the
      output file and to one FIFO per reader,
    * forwards data received on the input FIFO to the child,
    * switches the terminal between raw and standard mode on "raw" /
      "standard" requests received on the control FIFO,
    * writes the child's exit status to the status file once it terminates.

    The function does not return in the parent: it ends with ``sys.exit(0)``.
    """
    a_id = sys.stdin.readline().strip()
    echo = sys.stdin.readline().strip() == "True"
    readers = sys.stdin.readline().strip().split(",")
    command = sys.stdin.readline().strip() + f" && echo {a_id} > /dev/null"

    base_dir = os.path.join(BASE_DIR, f"aexpect_{a_id}")

    # Define filenames to be used for communication
    (shell_pid_filename,
     status_filename,
     output_filename,
     inpipe_filename,
     ctrlpipe_filename,
     lock_server_running_filename,
     lock_client_starting_filename,
     log_filename) = get_filenames(base_dir)

    assert os.path.isdir(base_dir)

    logging_format = '%(asctime)s %(levelname)-5.5s| %(message)s'
    date_format = '%m/%d %H:%M:%S'
    logging.basicConfig(filename=log_filename, level=logging.DEBUG,
                        format=logging_format, datefmt=date_format)
    server_log = logging.getLogger()

    server_log.info('Server %s starting with parameters', str(a_id))
    server_log.info('echo: %s', str(echo))
    server_log.info('readers: %s', str(readers))
    server_log.info('command: %s', str(command))

    # Populate the reader filenames list
    reader_filenames = [get_reader_filename(base_dir, reader)
                        for reader in readers]

    # Set $TERM = dumb
    os.putenv("TERM", "dumb")

    server_log.info('Forking child process for command')
    (shell_pid, shell_fd) = pty.fork()
    if shell_pid == 0:
        # Child process: run the command in a subshell
        if len(command) > 255:
            if len(command) > 2000000:
                # Stack size would probably not suffice (and no open files)
                # (1 + len(command) * 4 / 8290304) * 8196
                # 2MB => 8196kb, 4MB => 16392, ...
                # Floor division keeps the value an integer; "ulimit -s"
                # does not accept a fractional number.
                new_stack = (1 + len(command) // 2072576) * 8196
                command = f"ulimit -s {new_stack}\nulimit -n 819200\n{command}"
            # Long commands are written to a script file and sourced.
            # mkstemp() creates the file atomically, unlike mktemp().
            (tmp_fd, tmp_file) = tempfile.mkstemp(suffix='.sh',
                                                  prefix='aexpect-',
                                                  dir=base_dir)
            with os.fdopen(tmp_fd, "w", encoding="utf-8") as fd_cmd:
                fd_cmd.write(command)
            # execv() replaces this process (or raises), so nothing after it
            # can run here; the script file is left behind in base_dir.
            os.execv("/bin/bash", ["/bin/bash", "-c", f"source {tmp_file}"])
        else:
            os.execv("/bin/bash", ["/bin/bash", "-c", command])
    else:
        # Parent process
        server_log.info('Acquiring server lock on %s',
                        lock_server_running_filename)
        lock_server_running = get_lock_fd(lock_server_running_filename)

        # Set terminal echo on/off and disable pre- and post-processing
        makestandard(shell_fd, echo)

        server_log.info('Opening output file %s', output_filename)
        with open(output_filename, "wb") as output_file:
            server_log.info('Opening input pipe %s', inpipe_filename)
            os.mkfifo(inpipe_filename)
            inpipe_fd = os.open(inpipe_filename, os.O_RDWR)
            server_log.info('Opening control pipe %s', ctrlpipe_filename)
            os.mkfifo(ctrlpipe_filename)
            ctrlpipe_fd = os.open(ctrlpipe_filename, os.O_RDWR)
            # Open output pipes (readers)
            reader_fds = []
            for filename in reader_filenames:
                server_log.info('Opening output pipe %s', filename)
                os.mkfifo(filename)
                reader_fds.append(os.open(filename, os.O_RDWR))
            server_log.info('Reader fd list: %s', reader_fds)

            # Write shell PID to file
            server_log.info('Writing shell PID file %s', shell_pid_filename)
            with open(shell_pid_filename, "w", encoding="utf-8") as file_obj:
                file_obj.write(str(shell_pid))

            # Print something to stdout so the client can start working
            print(f"Server {a_id} ready")    # py3k pylint: disable=C0325
            sys.stdout.flush()

            # One pending-output buffer per reader pipe
            buffers = [b"" for _ in readers]

            # Read from child and write to files/pipes
            server_log.info('Entering main read loop')
            status = 1  # lgtm [py/multiple-definition]
            while True:
                check_termination = False
                # Make a list of reader pipes whose buffers are not empty
                fds = [reader_fd for (i, reader_fd) in enumerate(reader_fds)
                       if buffers[i]]
                # Wait until there's something to do
                r_fds, w_fds = select.select([shell_fd, inpipe_fd,
                                              ctrlpipe_fd],
                                             fds, [], 0.5)[:2]
                # If a reader pipe is ready for writing --
                for (i, reader_fd) in enumerate(reader_fds):
                    if reader_fd in w_fds:
                        bytes_written = os.write(reader_fd, buffers[i])
                        buffers[i] = buffers[i][bytes_written:]
                # If there's a terminal mode request from the client --
                if ctrlpipe_fd in r_fds:
                    # A request is a 10-byte length field followed by the
                    # request itself
                    cmd_len = int(os.read(ctrlpipe_fd, 10))
                    data = os.read(ctrlpipe_fd, cmd_len)
                    # os.read() returns bytes, so compare against bytes
                    if data == b"raw":
                        makeraw(shell_fd)
                    elif data == b"standard":
                        makestandard(shell_fd, echo)
                # If there's data to read from the child process --
                if shell_fd in r_fds:
                    try:
                        data = os.read(shell_fd, 16384)
                    except OSError:
                        data = b""
                    if not data:
                        check_termination = True
                    # Remove carriage returns from the data; they often cause
                    # trouble and are normally not needed
                    data = data.replace(b"\r", b"")
                    output_file.write(data)
                    output_file.flush()
                    buffers = [pending + data for pending in buffers]
                # os.read() raised an exception or there was nothing to read
                if check_termination or shell_fd not in r_fds:
                    pid, wait_status = os.waitpid(shell_pid, os.WNOHANG)
                    # pid is 0 while the child is still running
                    if pid:
                        status = os.WEXITSTATUS(wait_status)
                        break
                # If there's data to read from the client --
                if inpipe_fd in r_fds:
                    data = os.read(inpipe_fd, 1024)
                    os.write(shell_fd, data)

            server_log.info('Out of the main read loop. Writing status to %s',
                            status_filename)
            with open(status_filename, "w", encoding="utf-8") as file_obj:
                file_obj.write(str(status))

            # Wait for the client to finish initializing
            wait_for_lock(lock_client_starting_filename)

        # Close all files and pipes
        os.close(inpipe_fd)
        server_log.info('Closed input pipe')
        os.close(ctrlpipe_fd)
        server_log.info('Closed control pipe')
        for reader_fd in reader_fds:
            os.close(reader_fd)
            server_log.info('Closed reader fd %s', reader_fd)

        unlock_fd(lock_server_running)
        server_log.info('Exiting normally')
        sys.exit(0)


# Start the server only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
