python3实现类似expect shell的交互式与SFTP的脚本

前面写过一篇关于python实现类似expect shell的交互式能力的文章,现在补全一下加上sftp的能力脚本。
例子在代码中__example()方法。

依赖paramiko库,所以需要执行pip install paramiko来安装。

import os
import queue
import re
import threading
import time
import traceback
import stat
import datetime

import paramiko
from paramiko import SSHClient, SSHException, SFTPClient, Channel


class SFTPMultipleClient(object):
    """
    支持多sftp连接拉取文件
    支持快捷执行远程命令
    """

    def __init__(self, host, port, username, pwd, work_count=1) -> None:
        super().__init__()
        # client = SSHClient()
        # client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        # client.connect(hostname=host, timeout=60, port=port, username=username, password=pwd)
        # self.ssh_client = client
        # self.sftp_client = client.open_sftp()

        self.host = host
        self.port = port
        self.username = username
        self.pwd = pwd
        self.work_count = work_count

        self.thread_local_data = threading.local()
        self.pull_queue = queue.Queue()
        self.ssh_clientlist = []
        self.sftp_clientlist = []
        self.thread_list = {}
        self.email_send_files = []
        self.ptySessions = []
        (self.ssh_client, self.sftp_client) = self._gen_sftp_client(host, port, username,
                                                                    pwd)  # type:(SSHClient, SFTPClient)

        for i in range(work_count):
            self.thread_list[self._start_back_task(self._pull_event_handler)] = False

    def _start_back_task(self, fun, args=()):
        t = threading.Thread(target=fun, daemon=True, args=args)
        t.start()
        return t

    def remote_scp_progress(self, a, b):
        """
        远程scp进度打印
        :param a: 当前已传输大小(单位字节)
        :param b: 文件总大小(单位字节)
        """
        if a > b:
            a = b
        second = self.thread_local_data.timer.last_press_time_delta
        if second > 1 or a == b:
            self.thread_local_data.timer.press()
            speed = a - self.thread_local_data.last_progress
            self.thread_local_data.last_progress = a
            remaining_time_second = (b - a) / speed
            s = self.thread_local_data.get_file_desc + "%s  %s %02d%s %dKB/s %02d:%02d " % \
                (a, b, int(a / b * 100), "%", speed / 1024, remaining_time_second / 60, remaining_time_second % 60)
            print(s)
        # print(s, end="", flush=True)
        # _back = ""
        # for i in range(len(s)):
        #     _back += '\b'
        # print(_back, end="")

    def _gen_sftp_client(self, host, port, username, pwd) -> (SSHClient, SFTPClient):
        """
        生成sftp客户端
        当连接断开会进行重试,最大重试连接5次
        """
        _try_count = 0
        _ssh_client = SSHClient()
        _ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        _sftp_client = None
        tname = threading.current_thread().name
        stime = 1
        while _try_count < 5:
            connect_failed = False
            try:
                _ssh_client.connect(hostname=host, timeout=60, port=port, username=username,
                                    password=pwd)
                _sftp_client = _ssh_client.open_sftp()
                print(tname + " ssh连接成功.")
                break
            except Exception as e:
                connect_failed = True
                traceback.print_exc()
            if connect_failed:
                try:
                    _ssh_client.close()
                    if _sftp_client:
                        _sftp_client.close()
                except Exception as e:
                    traceback.print_exc()
                _try_count += 1
                if _try_count == 5:
                    raise Exception("尝试重新连接达到最大次数.断开连接.")
                traceback.print_exc()
                print(tname + " 第%d次重新连接ssh..." % _try_count)
                stime *= 2
                time.sleep(stime)

        self.ssh_clientlist.append(_ssh_client)
        self.sftp_clientlist.append(_sftp_client)
        self.thread_local_data.ssh_client = _ssh_client
        self.thread_local_data.sftp_client = _sftp_client
        return _ssh_client, _sftp_client

    def _pull_event_handler(self):
        """
        拉取任务处理
        """
        self._gen_sftp_client(self.host, self.port, self.username, self.pwd)
        # self.thread_local_data.ssh_client = _ssh_client
        # self.thread_local_data.sftp_client = _sftp_client
        while True:
            self.thread_list[threading.current_thread()] = False
            data = self.pull_queue.get(block=True, timeout=None)
            self.thread_list[threading.current_thread()] = True
            # 通过队列插入None来结束线程,如果有n个线程在监听队列,那么需要n个None来结束
            if data is None:
                self.thread_list[threading.current_thread()] = False
                break
            try:
                self._remote_scp(self.thread_local_data.ssh_client, self.thread_local_data.sftp_client,
                                 data.remote_path,
                                 data.local_path, data.max_size, callback=self.remote_scp_progress)
                fun = data.call_bak_fun
                if fun:
                    call_bak_fun_args = ()
                    call_bak_fun_kwargs = {}
                    if type(fun) in [list, tuple]:
                        if not callable(fun[0]):
                            raise Exception("call_bak_fun 不是可调用对象")
                        if len(fun) > 1:
                            call_bak_fun_args = fun[1]
                        if len(fun) > 2:
                            call_bak_fun_kwargs = fun[2]
                    else:
                        if not callable(fun):
                            raise Exception("call_bak_fun 不是可调用对象")
                    fun[0](*call_bak_fun_args, **call_bak_fun_kwargs)
            except Exception as e:
                traceback.print_exc()
                # pull_queue.put(data)
        # print("线程%s 结束." % threading.current_thread().name)

    def _remote_scp(self, _ssh_client, _sftp_client, _remote_path, _local_path, max_size=None, callback=None):
        reconnect = False
        try:
            if os.path.exists(_local_path):
                rf_stat = _sftp_client.lstat(_remote_path)
                lf_stat = os.stat(_local_path)
                if rf_stat.st_size == lf_stat.st_size:
                    print(_local_path + " already exists.")
                    return
                if max_size and rf_stat.st_size > max_size:  # 如果大于1m则认为是空板图片
                    print(_local_path + " > " + max_size + " skipped.")
                    return
            print(threading.current_thread().name + "copy file:%s << %s\t" % (_local_path, _remote_path))
            self.thread_local_data.get_file_desc = "copy file:%s << %s\t" % (_local_path, _remote_path)
            self.thread_local_data.timer = Timer()
            self.thread_local_data.last_progress = 0
            _sftp_client.get(_remote_path, _local_path, callback=callback)
        except FileNotFoundError as e:
            traceback.print_exc()
            print("continue...")
        except SSHException as e:
            traceback.print_exc()
            reconnect = True
        except OSError as e:
            print("os error====")
            traceback.print_exc()
            reconnect = True
        if reconnect:
            print("重新连接ssh...")
            _sftp_client.close()
            _ssh_client.close()
            self.sftp_clientlist.remove(_sftp_client)
            self.ssh_clientlist.remove(_ssh_client)
            _ssh_client, _sftp_client = self._gen_sftp_client(self.host, self.port, self.username, self.pwd)
            self._remote_scp(_ssh_client, _sftp_client, _remote_path, _local_path, max_size, callback)
        if callback:
            print()

    def pull_file(self, remote_path, local_path, max_size=None, print_progress=False):
        """
        拉取文件
        :param remote_path: 远程文件路径
        :param local_path: 本地文件路径
        :param max_size: 远程文件如果超过了这个大小,则不做拉取
        :param print_progress: 是否打印文件拉取进度
        """
        if print_progress:
            self._remote_scp(self.ssh_client, self.sftp_client, remote_path, local_path, max_size,
                             self.remote_scp_progress)
        else:
            self._remote_scp(self.ssh_client, self.sftp_client, remote_path, local_path, max_size)

    def submit_pull_work(self, remote_path, local_path, max_size=None, call_bak_fun=None):
        """
        提交拉取文件任务

        :param remote_path: 远程文件路径
        :param local_path: 本地文件路径
        :param max_size: 远程文件如果超过了这个大小,则不做拉取
        :param call_bak_fun: 拉取完成回调方法
        """
        e_data = Task(remote_path, local_path, max_size, call_bak_fun)
        self.pull_queue.put(e_data)

    def exec_command(self, command, *args, raise_e=True, **kwargs):
        """
        执行远程命令
        """
        stdin, stdout, stderr = self.ssh_client.exec_command(command, *args, **kwargs)
        if raise_e and stdout.channel.recv_exit_status() != 0:
            raise Exception(stderr.readline())
        return stdin, stdout, stderr

    def mkdir(self, path, recursive=True, mode=0o777):
        """
        创建文件夹
        :param path: 远程文件夹
        :param recursive: 是否递归创建 默认true
        :param mode: 权限,默认0o777-全部权限
        """
        if not recursive:
            self.sftp_client.mkdir(path, mode)
            return
        dirs = path.split(os.path.sep)
        current_dir = "/"
        for i, d in enumerate(dirs):
            d = os.path.join(current_dir, d)
            create = True
            try:
                if stat.S_ISDIR(self.sftp_client.stat(d).st_mode):
                    create = False
            except FileNotFoundError:
                create = True

            if create:
                self.sftp_client.mkdir(d, mode)
            current_dir = d

    def destroy(self):
        """
        销毁当前实例
        """
        while not self.pull_queue.empty():  # 等待队列任务处理完毕
            time.sleep(0.2)
        # 结束线程
        for thread in self.thread_list:
            self.pull_queue.put(None)  # 通过None来停止线程
            while True:
                if thread.is_alive() and self.thread_list[thread]:
                    time.sleep(1)
                else:
                    break
                # self.thread_list.pop(thread)

        # for i in range(len(self.thread_list)):
        #     self.pull_queue.put(None)  # 通过None来停止线程
        #     self.thread_list.pop()  # 移除成员
        for ptySession in self.ptySessions:
            ptySession.destroy()
        for i in range(len(self.sftp_clientlist)):
            sftp_client = self.sftp_clientlist.pop()
            sftp_client.close()
        for i in range(len(self.ssh_clientlist)):
            ssh_client = self.ssh_clientlist.pop()
            ssh_client.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.destroy()

    def open_pty_session(self, ending_char, timeout=30):
        """
        从当前channel开启一个session并激活
        @return:
        """
        _session = self.ssh_client.get_transport().open_session(timeout=1 * 3600)  # type:Channel
        _session.get_pty()
        _session.invoke_shell()
        _ptysession = PtySession(_session, ending_char=ending_char, timeout=timeout)
        self.ptySessions.append(_ptysession)
        return _ptysession

    def listdir_attr(self, path):
        """
        列出远程目录

        :param path: 远程目录
        :return: list of `.SFTPAttributes` objects
        """
        _reconnect = False
        try:
            return self.sftp_client.listdir_attr(path)
        except SSHException as e:
            traceback.print_exc()
            _reconnect = True
        except OSError as e:
            traceback.print_exc()
            _reconnect = True
        if _reconnect:
            print("重新连接ssh...")
            self.sftp_client.close()
            self.ssh_client.close()
            self.sftp_clientlist.remove(self.sftp_client)
            self.ssh_clientlist.remove(self.ssh_client)
            self.ssh_client, self.sftp_client = self._gen_sftp_client(self.host, self.port, self.username, self.pwd)
            return self.listdir_attr(path)


# 一次读取字节长度
recv_len = 1024 * 5


class Task(object):
    def __init__(self, remote_path, local_path, max_size=None, call_bak_fun=None) -> None:
        super().__init__()
        self.remote_path = remote_path  # 远程路径
        self.local_path = local_path  # 本地路径
        self.max_size = max_size  # 最大大小,当大于这个大小将跳过不做处理
        self.call_bak_fun = call_bak_fun  # 回调方法 格式:(fn,arg,kw)


class PtySession(object):

    def __init__(self, session, ending_char, timeout=30) -> None:
        super().__init__()
        self.session = session  # type:Channel
        self.last_line = ""
        self.ending_char = ending_char  # 每执行一个命令之后,标记输出的结束字符
        self.clear_tail()
        self.timeout = timeout  # 超时时间,秒

    def clear_tail(self, _ending_char=None):
        """
        清理输出还处于缓冲区中未读取的流
        """
        if not _ending_char:
            _ending_char = self.ending_char
        while True:
            time.sleep(0.2)
            # self.session.recv_ready()在读取过程中不一定总是True,只有当读取缓冲流中有字节读取时,才会为True。所以在读取头一次后获取下次流到缓冲区中前为False
            if self.session.recv_ready():
                self.last_line = self.session.recv(recv_len)
                self.last_line = self.last_line.decode('utf-8')
                print(self.last_line)
            if re.search(_ending_char, self.last_line):
                break

    def destroy(self):
        """
        销毁并关闭session
        :return:
        :rtype:
        """
        self.clear_tail()
        self.session.close()

    def exp(self, *exp_cmds):
        """
        期望并执行,与expect的用法类似。
        session.exp(
            (
                "\$", "scp test.txt luckydog@127.0.0.1:~/\r",
                (
                    ("yes/no", "yes\r",
                     ("Password:", "luckydog\r")),
                    ("Password:", "luckydog\r")
                )
            )
        )
        :param exp_cmds: 第一个元素为获取的期望结束字符,第二个元素为需要执行的命令,如果传入的第三个元素,则第三个元素必须为元祖,并且也同父级一样,属递归结构。类似GNU的递归缩写。
        :type exp_cmds: tuple
        """
        interval = 0.2
        cur_time = 0.0
        try:
            while True:
                if self.session.recv_ready():
                    self.last_line = self.session.recv(recv_len).decode('utf-8')
                    print(self.last_line)
                elif self.session.send_ready():
                    for exp_cmd in exp_cmds:
                        _cmd = exp_cmd[1]
                        if not _cmd.endswith("\r"):
                            _cmd += "\r"
                        match = re.search(exp_cmd[0], self.last_line)
                        if match and match.group():
                            self.session.send(_cmd)
                            # 清空最后一行数据缓存,便于下个命令的读取流输出。此行代码去除,会导致无法等待命令执行完毕提前执行后续代码的问题。
                            self.last_line = ""
                            if len(exp_cmd) == 3 and exp_cmd[2]:
                                self.exp(*exp_cmd[2])
                            return
                cur_time += interval
                if cur_time >= self.timeout:
                    raise Exception("timeout...")
                time.sleep(interval)
        except Exception as e:
            traceback.print_exc()
        # finally:
        #     self.clear_tail()

    def send(self, cmd, _ending_char=None):
        """
        单纯的发送命令到目标服务器执行。
        :param cmd: 命令
        :type cmd: str
        """
        self.last_line = ""
        if not cmd.endswith("\r"):
            cmd += "\r"
        self.session.send(cmd)
        self.clear_tail(_ending_char)


# -----------------------
class Timer(object):
    def __init__(self) -> None:
        super().__init__()
        self._start_time = None
        self._end_time = None
        self._seconds_delta = None
        self._last_press_time = None  # 最后一个计次时间

    def start(self):
        if not self._start_time:
            self._start_time = datetime.datetime.now()
            self._last_press_time = self._start_time
        else:
            raise Exception("timer is already started.")
        return self._start_time

    def end(self):
        if not self._end_time:
            self._end_time = datetime.datetime.now()
            self._seconds_delta = (self._end_time - self._start_time).total_seconds()
        else:
            raise Exception("timer is already stopped.")

    def press(self):
        if not self._start_time:
            self._last_press_time = self.start()
            return 0
        now = datetime.datetime.now()
        last_time = self._last_press_time
        self._last_press_time = now
        return (now - last_time).total_seconds()

    @property
    def last_press_time_delta(self):
        if not self._last_press_time:
            self.press()
            return 0
        return (datetime.datetime.now() - self._last_press_time).total_seconds()

    @property
    def start_time(self):
        return self._start_time

    @property
    def end_time(self):
        return self._end_time

    def timedelta_seconds(self):
        if self._seconds_delta:
            return self._seconds_delta
        else:
            raise Exception("timer not stopped.")


# ===================test=========================
def _file_pull_complete(filename, **kw):
    print(filename + "文件上传完成")


def __example():
    sftp = SFTPMultipleClient(host="130.16.16.18", port=22, username="baiyang", pwd="xxx")
    # 创建目录
    sftp.mkdir("/home/username/123/")

    # 下载文件到本地(同步执行)
    sftp.pull_file(
        "/home/username/temp/rumenz.img",  # 远程文件
        "/Users/username/rumenz.img",  # 本地路径文件,需要带上文件名
        print_progress=True  # 打印文件拉取进度
    )

    # 提交下载远程文件异步任务(异步线程执行不阻塞)
    sftp.submit_pull_work(
        "/home/username/temp/rumenz.img",  # 远程文件
        "/Users/username/rumenz.img",  # 本地路径文件,需要带上文件名
        call_bak_fun=(_file_pull_complete, ("rumenz.img",))  # 文件上传完成回调,第一个元素为方法,第二个是入参
    )

    # 列出文件列表
    files = sftp.listdir_attr("/home/baiyang")
    for f in files:
        print(f)

    # 开启交互回话
    session = sftp.open_pty_session("\$")
    # 发送单条命令
    session.send("cd ~/123")
    # # 发送单条命令并等待获取_ending_char符号(结束符支持el)
    session.send("scp test.txt username@130.16.16.133:~/\r", "Password:|yes/no")

    # # 当最后一行为yes/no 则执行yes,如果是Password则输入密码
    session.exp(
        ("yes/no", "yes\r",
         ("Password:", "xxx\r")),
        ("Password:", "xxx\r")
    )

    # scp方式2
    '''
    伪代码
    if (yes/no):
        send yes
        if (Password:)
            send 密码
    elif (Password:)
        send 密码
    '''
    session.exp(
        (
            "\$", "scp test.txt username@130.16.16.133:~/\r",  # 期望$,并发送scp命令
            (
                ("yes/no", "yes\r", (  # 如果返回yes/no,则发送yes
                    "Password:", "xxx\r")),  # 发送yes后,如果返回Password:结尾,则发送密码
                ("Password:", "xxx\r")  # 如果返回的是Password:结尾,则发送密码
            )
        )
    )

    # 销毁
    session.destroy()
    sftp.destroy()

    # 支持通过with的方式使用SFTPMultipleClient,这样不需要显示的调用 sftp.destroy()
    with SFTPMultipleClient(host="130.16.16.18", port=22, username="baiyang", pwd="xxx") as sftp:
        # 列出文件列表
        files = sftp.listdir_attr("/home/baiyang")
        for f in files:
            print(f)


if __name__ == '__main__':
    __example()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值