之前写过一篇用 Python 实现类似 expect 的 shell 交互能力的文章，现在对脚本做了补全，加上了 SFTP 能力。
使用示例见代码中的 `__example()` 方法。
脚本依赖 paramiko 库，需要先执行 `pip install paramiko` 安装。
import os
import queue
import re
import threading
import time
import traceback
import stat
import datetime
import paramiko
from paramiko import SSHClient, SSHException, SFTPClient, Channel
class SFTPMultipleClient(object):
    """
    SFTP client that pulls files over several SSH/SFTP connections and runs remote commands.

    A "primary" SSH/SFTP connection (``self.ssh_client`` / ``self.sftp_client``) is opened
    in the constructor and used by the synchronous helpers (``pull_file``, ``exec_command``,
    ``mkdir``, ``listdir_attr``, ``open_pty_session``).  In addition, ``work_count`` daemon
    worker threads each open their own connection and consume pull tasks queued by
    ``submit_pull_work``.
    """

    # Maximum number of reconnect-and-retry rounds before a transfer/listing gives up.
    _MAX_RECONNECTS = 3

    def __init__(self, host, port, username, pwd, work_count=1) -> None:
        """
        Open the primary connection and start the pull worker threads.

        :param host: SSH host name or address
        :param port: SSH port
        :param username: login user name
        :param pwd: login password
        :param work_count: number of background pull worker threads (each opens its own connection)
        """
        super().__init__()
        self.host = host
        self.port = port
        self.username = username
        self.pwd = pwd
        self.work_count = work_count
        # Per-thread state: the thread's own ssh/sftp clients plus progress bookkeeping.
        self.thread_local_data = threading.local()
        self.pull_queue = queue.Queue()
        # Every client currently open, so destroy() can close them all.
        self.ssh_clientlist = []
        self.sftp_clientlist = []
        # worker thread -> True while it is processing a task, False while idle.
        self.thread_list = {}
        self.email_send_files = []
        self.ptySessions = []
        (self.ssh_client, self.sftp_client) = self._gen_sftp_client(
            host, port, username, pwd)  # type:(SSHClient, SFTPClient)
        for _ in range(work_count):
            self.thread_list[self._start_back_task(self._pull_event_handler)] = False

    def _start_back_task(self, fun, args=()):
        """Start ``fun(*args)`` in a new daemon thread and return the thread."""
        t = threading.Thread(target=fun, daemon=True, args=args)
        t.start()
        return t

    def remote_scp_progress(self, a, b):
        """
        Progress callback for ``SFTPClient.get``: print progress roughly once per second.

        Relies on ``thread_local_data.timer``, ``last_progress`` and ``get_file_desc``
        having been initialised for the current thread (``_fetch_once`` does this before
        each download).

        :param a: bytes transferred so far
        :param b: total file size in bytes
        """
        if a > b:
            a = b
        local = self.thread_local_data
        elapsed = local.timer.last_press_time_delta
        if elapsed > 1 or a == b:
            local.timer.press()
            transferred = a - local.last_progress
            local.last_progress = a
            # Bytes per second since the previous report; when no time has elapsed
            # (e.g. the timer was only just started) fall back to the raw byte delta.
            speed = transferred / elapsed if elapsed > 0 else float(transferred)
            # Guard both divisions: speed is 0 on a final call with no new bytes,
            # and b is 0 for an empty file.
            remaining_time_second = (b - a) / speed if speed > 0 else 0
            percent = int(a / b * 100) if b else 100
            s = local.get_file_desc + "%s %s %02d%s %dKB/s %02d:%02d " % \
                (a, b, percent, "%", speed / 1024, remaining_time_second / 60, remaining_time_second % 60)
            print(s)

    @staticmethod
    def _close_quietly(*clients):
        """Close each non-None client, printing (not raising) any error."""
        for client in clients:
            if client is None:
                continue
            try:
                client.close()
            except Exception:
                traceback.print_exc()

    def _gen_sftp_client(self, host, port, username, pwd) -> "tuple[SSHClient, SFTPClient]":
        """
        Open a new SSH connection plus SFTP session, retrying up to 5 times.

        Between failed attempts it sleeps 2, 4, 8 and 16 seconds.  On success the pair is
        registered in ``ssh_clientlist``/``sftp_clientlist`` and stored in the calling
        thread's ``thread_local_data``.

        :return: (ssh_client, sftp_client)
        :raises Exception: after the 5th failed attempt (chained to the last error)
        """
        tname = threading.current_thread().name
        max_attempts = 5
        delay = 1
        for attempt in range(1, max_attempts + 1):
            ssh_client = SSHClient()
            # NOTE: AutoAddPolicy accepts unknown host keys automatically (no host key verification).
            ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            sftp_client = None
            try:
                ssh_client.connect(hostname=host, timeout=60, port=port, username=username,
                                   password=pwd)
                sftp_client = ssh_client.open_sftp()
            except Exception as err:
                traceback.print_exc()
                self._close_quietly(sftp_client, ssh_client)
                if attempt == max_attempts:
                    raise Exception("尝试重新连接达到最大次数.断开连接.") from err
                print(tname + " 第%d次重新连接ssh..." % attempt)
                delay *= 2
                time.sleep(delay)
                continue
            print(tname + " ssh连接成功.")
            self.ssh_clientlist.append(ssh_client)
            self.sftp_clientlist.append(sftp_client)
            self.thread_local_data.ssh_client = ssh_client
            self.thread_local_data.sftp_client = sftp_client
            return ssh_client, sftp_client

    @staticmethod
    def _invoke_callback(fun):
        """
        Call a task's completion callback.

        :param fun: falsy (no callback), a callable, or a list/tuple ``(fn[, args[, kwargs]])``
        :raises Exception: if the resolved target is not callable
        """
        if not fun:
            return
        args, kwargs = (), {}
        if isinstance(fun, (list, tuple)):
            target = fun[0]
            if len(fun) > 1:
                args = fun[1]
            if len(fun) > 2:
                kwargs = fun[2]
        else:
            target = fun
        if not callable(target):
            raise Exception("call_bak_fun 不是可调用对象")
        target(*args, **kwargs)

    def _pull_event_handler(self):
        """
        Worker-thread loop: open this thread's own connection, then process queued pull
        tasks until a ``None`` sentinel is received.
        """
        me = threading.current_thread()
        try:
            self._gen_sftp_client(self.host, self.port, self.username, self.pwd)
        except Exception:
            # Could not connect after all retries; other workers (if any) keep serving the queue.
            traceback.print_exc()
            self.thread_list[me] = False
            return
        local = self.thread_local_data
        while True:
            self.thread_list[me] = False
            task = self.pull_queue.get(block=True, timeout=None)
            # destroy() queues one None per worker to stop the threads.
            if task is None:
                break
            self.thread_list[me] = True
            try:
                # Read the thread-local clients each time: a reconnect inside _remote_scp replaces them.
                self._remote_scp(local.ssh_client, local.sftp_client, task.remote_path,
                                 task.local_path, task.max_size, callback=self.remote_scp_progress)
                self._invoke_callback(task.call_bak_fun)
            except Exception:
                traceback.print_exc()
        self.thread_list[me] = False

    def _reconnect(self, ssh_client, sftp_client):
        """
        Replace a (possibly broken) client pair with a freshly opened one.

        The old pair is closed and dropped from the client registries.  If it was the
        primary pair, ``self.ssh_client``/``self.sftp_client`` are updated too, so later
        synchronous calls do not keep using closed clients.

        :return: (new_ssh_client, new_sftp_client)
        """
        print("重新连接ssh...")
        self._close_quietly(sftp_client, ssh_client)
        for client, registry in ((sftp_client, self.sftp_clientlist),
                                 (ssh_client, self.ssh_clientlist)):
            if client in registry:
                registry.remove(client)
        new_ssh, new_sftp = self._gen_sftp_client(self.host, self.port, self.username, self.pwd)
        if ssh_client is self.ssh_client:
            self.ssh_client, self.sftp_client = new_ssh, new_sftp
        return new_ssh, new_sftp

    def _fetch_once(self, sftp_client, remote_path, local_path, max_size, callback):
        """
        Try one download without any retry handling.

        :return: True if ``sftp_client.get`` ran, False if the file was skipped
        """
        local_exists = os.path.exists(local_path)
        if local_exists or max_size:
            remote_size = sftp_client.stat(remote_path).st_size
            if local_exists and remote_size == os.stat(local_path).st_size:
                print(local_path + " already exists.")
                return False
            if max_size and remote_size > max_size:
                print(local_path + " > " + str(max_size) + " skipped.")
                return False
        desc = "copy file:%s << %s\t" % (local_path, remote_path)
        print(threading.current_thread().name + desc)
        # State read by remote_scp_progress for this thread.
        self.thread_local_data.get_file_desc = desc
        self.thread_local_data.timer = Timer()
        self.thread_local_data.last_progress = 0
        sftp_client.get(remote_path, local_path, callback=callback)
        return True

    def _remote_scp(self, _ssh_client, _sftp_client, _remote_path, _local_path, max_size=None, callback=None):
        """
        Download one remote file, reconnecting and retrying on connection errors.

        The download is skipped when the local file already has the remote file's size.
        It is also skipped when ``max_size`` is given and the remote file is larger.
        A ``FileNotFoundError`` is printed and ignored, and a ``PermissionError`` is
        re-raised.  Other ``SSHException``/``OSError`` failures reconnect and retry, at most
        ``_MAX_RECONNECTS`` times, after which the last error is re-raised.

        :param callback: optional ``(transferred, total)`` progress callback passed to ``SFTPClient.get``
        """
        reconnects = 0
        while True:
            try:
                downloaded = self._fetch_once(_sftp_client, _remote_path, _local_path, max_size, callback)
                if downloaded and callback:
                    print()  # end the progress output with a newline
                return
            except FileNotFoundError:
                traceback.print_exc()
                print("continue...")
                return
            except PermissionError:
                # Access denied (remote or local): a new connection cannot fix this.
                raise
            except SSHException:
                traceback.print_exc()
                if reconnects >= self._MAX_RECONNECTS:
                    raise
            except OSError:
                print("os error====")
                traceback.print_exc()
                if reconnects >= self._MAX_RECONNECTS:
                    raise
            reconnects += 1
            _ssh_client, _sftp_client = self._reconnect(_ssh_client, _sftp_client)

    def pull_file(self, remote_path, local_path, max_size=None, print_progress=False):
        """
        Pull a file synchronously over the primary connection.

        :param remote_path: remote file path
        :param local_path: local file path (including the file name)
        :param max_size: skip the download if the remote file is larger than this (bytes)
        :param print_progress: whether to print transfer progress
        """
        callback = self.remote_scp_progress if print_progress else None
        self._remote_scp(self.ssh_client, self.sftp_client, remote_path, local_path, max_size, callback)

    def submit_pull_work(self, remote_path, local_path, max_size=None, call_bak_fun=None):
        """
        Queue a file pull for the background worker threads.

        :param remote_path: remote file path
        :param local_path: local file path (including the file name)
        :param max_size: skip the download if the remote file is larger than this (bytes)
        :param call_bak_fun: completion callback: a callable or ``(fn[, args[, kwargs]])``
        """
        e_data = Task(remote_path, local_path, max_size, call_bak_fun)
        self.pull_queue.put(e_data)

    def exec_command(self, command, *args, raise_e=True, **kwargs):
        """
        Run a command over the primary SSH connection.

        :param raise_e: if True, wait for the exit status and raise on non-zero,
            using the first stderr line as the message
        :return: (stdin, stdout, stderr) from ``SSHClient.exec_command``
        """
        stdin, stdout, stderr = self.ssh_client.exec_command(command, *args, **kwargs)
        # NOTE(review): waiting for the exit status before draining stdout may block for
        # commands with very large output (paramiko channel window) — confirm if relevant.
        if raise_e and stdout.channel.recv_exit_status() != 0:
            raise Exception(stderr.readline())
        return stdin, stdout, stderr

    def mkdir(self, path, recursive=True, mode=0o777):
        """
        Create a remote directory.

        Remote paths are always POSIX ("/"-separated), independent of the local OS.
        Relative paths stay relative to the SFTP session's working directory.

        :param path: remote directory
        :param recursive: create missing parent directories too (default True)
        :param mode: permission bits, default 0o777
        """
        if not recursive:
            self.sftp_client.mkdir(path, mode)
            return
        prefix = "/" if path.startswith("/") else ""
        built = []
        for part in path.split("/"):
            if not part:
                continue  # skip empty components from leading/trailing/double slashes
            built.append(part)
            current = prefix + "/".join(built)
            try:
                is_dir = stat.S_ISDIR(self.sftp_client.stat(current).st_mode)
            except FileNotFoundError:
                is_dir = False
            if not is_dir:
                self.sftp_client.mkdir(current, mode)

    def destroy(self):
        """
        Shut everything down and release the instance's resources.

        The workers finish every queued task and then stop.  Open pty sessions are
        destroyed, and all ssh/sftp clients are closed.  Safe to call more than once.
        """
        # One None per worker; the queue is FIFO, so every earlier task is processed first.
        for _ in self.thread_list:
            self.pull_queue.put(None)
        current = threading.current_thread()
        for worker in self.thread_list:
            if worker is not current:  # joining oneself (destroy from a callback) would raise
                worker.join()
        sessions, self.ptySessions = self.ptySessions, []
        for pty_session in sessions:
            # A session the caller already destroyed has a closed channel; destroying it
            # again could wait for output that never arrives.
            if not getattr(pty_session.session, "closed", False):
                pty_session.destroy()
        while self.sftp_clientlist:
            self.sftp_clientlist.pop().close()
        while self.ssh_clientlist:
            self.ssh_clientlist.pop().close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.destroy()

    def open_pty_session(self, ending_char, timeout=30):
        """
        Open an interactive shell (with PTY) on the primary connection.

        :param ending_char: regex marking the end of a command's output (e.g. the prompt)
        :param timeout: timeout in seconds used by ``PtySession.exp``
        :return: the new PtySession (also destroyed by ``destroy()``)
        """
        _session = self.ssh_client.get_transport().open_session(timeout=1 * 3600)  # type:Channel
        _session.get_pty()
        _session.invoke_shell()
        _ptysession = PtySession(_session, ending_char=ending_char, timeout=timeout)
        self.ptySessions.append(_ptysession)
        return _ptysession

    def listdir_attr(self, path):
        """
        List a remote directory over the primary connection.

        Connection errors reconnect and retry, at most ``_MAX_RECONNECTS`` times.
        ``FileNotFoundError``/``PermissionError`` are raised immediately, because
        reconnecting cannot fix them.

        :param path: remote directory
        :return: list of `.SFTPAttributes` objects
        """
        reconnects = 0
        while True:
            try:
                return self.sftp_client.listdir_attr(path)
            except (FileNotFoundError, PermissionError):
                raise
            except (SSHException, OSError):
                traceback.print_exc()
                if reconnects >= self._MAX_RECONNECTS:
                    raise
            reconnects += 1
            self._reconnect(self.ssh_client, self.sftp_client)
# Maximum number of bytes pulled from a channel per recv() call (5 KiB).
recv_len = 5 * 1024
class Task:
    """A pull job placed on SFTPMultipleClient's queue by submit_pull_work."""

    def __init__(self, remote_path, local_path, max_size=None, call_bak_fun=None) -> None:
        """
        :param remote_path: remote file path
        :param local_path: local destination path
        :param max_size: skip the file if it is larger than this (bytes)
        :param call_bak_fun: completion callback, format (fn, args, kwargs)
        """
        super().__init__()
        self.remote_path, self.local_path = remote_path, local_path
        self.max_size = max_size
        self.call_bak_fun = call_bak_fun
class PtySession(object):
    """
    Expect-style wrapper around an interactive paramiko shell channel.

    Output is read in chunks of ``recv_len`` bytes and the most recent chunk is kept in
    ``last_line``; commands are sent when ``last_line`` matches an expected regex.
    """

    def __init__(self, session, ending_char, timeout=30) -> None:
        """
        :param session: interactive channel with a PTY and a shell already started
        :param ending_char: regex marking the end of a command's output (e.g. the prompt)
        :param timeout: seconds ``exp`` waits for a match before giving up
        """
        super().__init__()
        self.session = session  # type:Channel
        self.last_line = ""
        self.ending_char = ending_char  # regex that marks the end of each command's output
        self.timeout = timeout  # seconds; set before clear_tail() runs
        self.clear_tail()

    def clear_tail(self, _ending_char=None):
        """
        Drain and print buffered output until a chunk matches the ending regex.

        It also stops when the channel is closed or has received EOF with nothing left to
        read, because no further output can arrive then.

        :param _ending_char: regex to wait for; defaults to ``self.ending_char``
        """
        pattern = _ending_char or self.ending_char
        while True:
            time.sleep(0.2)
            # recv_ready() is only True while bytes are buffered, so it may be False between chunks.
            if self.session.recv_ready():
                # 'replace' keeps a multi-byte character split across chunks from raising.
                self.last_line = self.session.recv(recv_len).decode('utf-8', errors='replace')
                print(self.last_line)
                if re.search(pattern, self.last_line):
                    break
            elif self.session.closed or self.session.eof_received:
                break

    def destroy(self):
        """
        Drain remaining output and close the channel; a no-op if it is already closed.
        """
        if self.session.closed:
            return
        self.clear_tail()
        self.session.close()

    def exp(self, *exp_cmds):
        r"""
        Expect-and-send, similar to ``expect``.

        Each element of ``exp_cmds`` is ``(pattern, command[, nested])``.  When the latest
        output matches ``pattern``, ``command`` is sent (a trailing "\r" is added if
        missing).  If ``nested`` is given, ``exp`` then continues with it; it may be a
        tuple of such tuples, or a single ``(pattern, command[, nested])`` tuple.

        session.exp(
            (
                r"\$", "scp test.txt luckydog@127.0.0.1:~/\r",
                (
                    ("yes/no", "yes\r",
                        ("Password:", "luckydog\r")),
                    ("Password:", "luckydog\r")
                )
            )
        )

        :param exp_cmds: (pattern, command[, nested]) tuples
        :type exp_cmds: tuple
        :return: True if a command chain completed, False on timeout or error
            (the error's traceback is printed)
        """
        interval = 0.2
        waited = 0.0
        try:
            while True:
                if self.session.recv_ready():
                    self.last_line = self.session.recv(recv_len).decode('utf-8', errors='replace')
                    print(self.last_line)
                elif self.session.send_ready():
                    for exp_cmd in exp_cmds:
                        pattern, cmd = exp_cmd[0], exp_cmd[1]
                        if not cmd.endswith("\r"):
                            cmd += "\r"
                        match = re.search(pattern, self.last_line)
                        if match and match.group():
                            self.session.send(cmd)
                            # Clear the cached output so the next expectation waits for the
                            # new command's output instead of matching stale text.
                            self.last_line = ""
                            if len(exp_cmd) == 3 and exp_cmd[2]:
                                nested = exp_cmd[2]
                                # Accept a single (pattern, cmd[, nested]) tuple as well as a tuple of them.
                                if isinstance(nested[0], str):
                                    nested = (nested,)
                                return self.exp(*nested)
                            return True
                waited += interval
                if waited >= self.timeout:
                    raise Exception("timeout...")
                time.sleep(interval)
        except Exception:
            traceback.print_exc()
            return False

    def send(self, cmd, _ending_char=None):
        """
        Send a single command and wait until its output matches the ending regex.

        :param cmd: command; a trailing carriage return is appended if missing
        :type cmd: str
        :param _ending_char: regex to wait for instead of ``self.ending_char``
        """
        self.last_line = ""
        if not cmd.endswith("\r"):
            cmd += "\r"
        self.session.send(cmd)
        self.clear_tail(_ending_char)
# -----------------------
class Timer(object):
    """
    Wall-clock stopwatch with lap support (``press``) based on ``datetime.now()``.
    """

    def __init__(self) -> None:
        super().__init__()
        self._start_time = None
        self._end_time = None
        self._seconds_delta = None
        self._last_press_time = None  # time of the most recent press (lap)

    def start(self):
        """
        Start the timer and return the start time.

        :raises Exception: if the timer was already started
        """
        if not self._start_time:
            self._start_time = datetime.datetime.now()
            self._last_press_time = self._start_time
        else:
            raise Exception("timer is already started.")
        return self._start_time

    def end(self):
        """
        Stop the timer and record the elapsed seconds.

        :raises Exception: if the timer was never started or is already stopped
        """
        if self._end_time:
            raise Exception("timer is already stopped.")
        if self._start_time is None:
            raise Exception("timer not started.")
        self._end_time = datetime.datetime.now()
        self._seconds_delta = (self._end_time - self._start_time).total_seconds()

    def press(self):
        """
        Record a lap.

        :return: seconds since the previous press (or start); 0 if this call started the timer
        """
        if not self._start_time:
            self._last_press_time = self.start()
            return 0
        now = datetime.datetime.now()
        last_time = self._last_press_time
        self._last_press_time = now
        return (now - last_time).total_seconds()

    @property
    def last_press_time_delta(self):
        """Seconds since the last press; starts the timer (returning 0) if it was never pressed."""
        if not self._last_press_time:
            self.press()
            return 0
        return (datetime.datetime.now() - self._last_press_time).total_seconds()

    @property
    def start_time(self):
        return self._start_time

    @property
    def end_time(self):
        return self._end_time

    def timedelta_seconds(self):
        """
        Return the seconds between start and end.

        :raises Exception: if ``end()`` has not been called yet
        """
        # Compare against None: an elapsed time of exactly 0.0 is a valid result.
        if self._seconds_delta is not None:
            return self._seconds_delta
        raise Exception("timer not stopped.")
# ===================test=========================
def _file_pull_complete(filename, **kw):
    """Example completion callback: print a notice for *filename*; extra keyword args are ignored."""
    notice = filename + "文件上传完成"
    print(notice)
def __example():
    """
    Usage walkthrough for SFTPMultipleClient and PtySession.

    Needs a reachable SSH host and real credentials; it is not an automated test.
    """
    sftp = SFTPMultipleClient(host="130.16.16.18", port=22, username="baiyang", pwd="xxx")
    # Create a remote directory (recursively)
    sftp.mkdir("/home/username/123/")
    # Download a file to the local machine (synchronous)
    sftp.pull_file(
        "/home/username/temp/rumenz.img",  # remote file
        "/Users/username/rumenz.img",  # local path, including the file name
        print_progress=True  # print transfer progress
    )
    # Queue an asynchronous download handled by a worker thread (does not block)
    sftp.submit_pull_work(
        "/home/username/temp/rumenz.img",  # remote file
        "/Users/username/rumenz.img",  # local path, including the file name
        call_bak_fun=(_file_pull_complete, ("rumenz.img",))  # completion callback: (function, positional args)
    )
    # List a remote directory
    files = sftp.listdir_attr("/home/baiyang")
    for f in files:
        print(f)
    # Open an interactive shell session; "$" (regex) marks the end of a command's output
    session = sftp.open_pty_session(r"\$")
    # Send a single command
    session.send("cd ~/123")
    # Send a command and wait until the output matches _ending_char (a regex)
    session.send("scp test.txt username@130.16.16.133:~/\r", "Password:|yes/no")
    # If the last output is yes/no, answer yes (then send the password on Password:);
    # if it is Password:, send the password. A nested third element is a tuple of
    # (pattern, command[, nested]) tuples.
    session.exp(
        ("yes/no", "yes\r",
         (("Password:", "xxx\r"),)),
        ("Password:", "xxx\r")
    )
    # scp, variant 2 — equivalent pseudo code:
    #   if (yes/no):
    #       send yes
    #       if (Password:)
    #           send password
    #   elif (Password:)
    #       send password
    session.exp(
        (
            r"\$", "scp test.txt username@130.16.16.133:~/\r",  # expect "$", then send the scp command
            (
                ("yes/no", "yes\r", (  # on yes/no send yes ...
                    ("Password:", "xxx\r"),)),  # ... then, on Password:, send the password
                ("Password:", "xxx\r")  # on Password: send the password
            )
        )
    )
    # Clean up. destroy() also destroys every pty session opened via open_pty_session, so the
    # session is not destroyed separately here (a second destroy could wait for output that
    # never arrives on the closed channel).
    sftp.destroy()
    # SFTPMultipleClient also works as a context manager, so destroy() needn't be called explicitly
    with SFTPMultipleClient(host="130.16.16.18", port=22, username="baiyang", pwd="xxx") as sftp:
        # List a remote directory
        files = sftp.listdir_attr("/home/baiyang")
        for f in files:
            print(f)
# Run the usage walkthrough only when executed as a script, not on import.
if __name__ == "__main__":
    __example()