scp.py
463 lines
| 15.9 KiB
| text/x-python
|
PythonLexer
|
r573 | # scp.py | |
# Copyright (C) 2008 James Bardin <j.bardin@gmail.com> | |||
""" | |||
Utilities for sending files over ssh using the scp1 protocol. | |||
""" | |||
__version__ = '0.10.0' | |||
import locale | |||
import os | |||
import re | |||
from socket import timeout as SocketTimeout | |||
# this is quote from the shlex module, added in py3.3 | |||
_find_unsafe = re.compile(br'[^\w@%+=:,./~-]').search | |||
def _sh_quote(s): | |||
"""Return a shell-escaped version of the string `s`.""" | |||
if not s: | |||
return b"" | |||
if _find_unsafe(s) is None: | |||
return s | |||
# use single quotes, and put single quotes into double quotes | |||
# the string $'b is then quoted as '$'"'"'b' | |||
return b"'" + s.replace(b"'", b"'\"'\"'") + b"'" | |||
# Unicode conversion functions; assume UTF-8 | |||
def asbytes(s): | |||
"""Turns unicode into bytes, if needed. | |||
Assumes UTF-8. | |||
""" | |||
if isinstance(s, bytes): | |||
return s | |||
else: | |||
return s.encode('utf-8') | |||
def asunicode(s): | |||
"""Turns bytes into unicode, if needed. | |||
Uses UTF-8. | |||
""" | |||
if isinstance(s, bytes): | |||
return s.decode('utf-8', 'replace') | |||
else: | |||
return s | |||
# os.path.sep is unicode on Python 3, no matter the platform | |||
bytes_sep = asbytes(os.path.sep) | |||
# Unicode conversion function for Windows | |||
# Used to convert local paths if the local machine is Windows | |||
def asunicode_win(s): | |||
"""Turns bytes into unicode, if needed. | |||
""" | |||
if isinstance(s, bytes): | |||
return s.decode(locale.getpreferredencoding()) | |||
else: | |||
return s | |||
class SCPClient(object): | |||
""" | |||
An scp1 implementation, compatible with openssh scp. | |||
Raises SCPException for all transport related errors. Local filesystem | |||
and OS errors pass through. | |||
Main public methods are .put and .get | |||
The get method is controlled by the remote scp instance, and behaves | |||
accordingly. This means that symlinks are resolved, and the transfer is | |||
halted after too many levels of symlinks are detected. | |||
The put method uses os.walk for recursion, and sends files accordingly. | |||
Since scp doesn't support symlinks, we send file symlinks as the file | |||
(matching scp behaviour), but we make no attempt at symlinked directories. | |||
""" | |||
def __init__(self, transport, buff_size=16384, socket_timeout=5.0, | |||
progress=None, sanitize=_sh_quote): | |||
""" | |||
Create an scp1 client. | |||
@param transport: an existing paramiko L{Transport} | |||
@type transport: L{Transport} | |||
@param buff_size: size of the scp send buffer. | |||
@type buff_size: int | |||
@param socket_timeout: channel socket timeout in seconds | |||
@type socket_timeout: float | |||
@param progress: callback - called with (filename, size, sent) during | |||
transfers | |||
@param sanitize: function - called with filename, should return | |||
safe or escaped string. Uses _sh_quote by default. | |||
@type progress: function(string, int, int) | |||
""" | |||
self.transport = transport | |||
self.buff_size = buff_size | |||
self.socket_timeout = socket_timeout | |||
self.channel = None | |||
self.preserve_times = False | |||
self._progress = progress | |||
self._recv_dir = b'' | |||
self._rename = False | |||
self._utime = None | |||
self.sanitize = sanitize | |||
self._dirtimes = {} | |||
def __enter__(self): | |||
self.channel = self._open() | |||
return self | |||
def __exit__(self, type, value, traceback): | |||
self.close() | |||
def put(self, files, remote_path=b'.', | |||
recursive=False, preserve_times=False): | |||
""" | |||
Transfer files to remote host. | |||
@param files: A single path, or a list of paths to be transfered. | |||
recursive must be True to transfer directories. | |||
@type files: string OR list of strings | |||
@param remote_path: path in which to receive the files on the remote | |||
host. defaults to '.' | |||
@type remote_path: str | |||
@param recursive: transfer files and directories recursively | |||
@type recursive: bool | |||
@param preserve_times: preserve mtime and atime of transfered files | |||
and directories. | |||
@type preserve_times: bool | |||
""" | |||
self.preserve_times = preserve_times | |||
self.channel = self._open() | |||
self._pushed = 0 | |||
self.channel.settimeout(self.socket_timeout) | |||
scp_command = (b'scp -t ', b'scp -r -t ')[recursive] | |||
self.channel.exec_command(scp_command + | |||
self.sanitize(asbytes(remote_path))) | |||
self._recv_confirm() | |||
if not isinstance(files, (list, tuple)): | |||
files = [files] | |||
if recursive: | |||
self._send_recursive(files) | |||
else: | |||
self._send_files(files) | |||
self.close() | |||
def get(self, remote_path, local_path='', | |||
recursive=False, preserve_times=False): | |||
""" | |||
Transfer files from remote host to localhost | |||
@param remote_path: path to retreive from remote host. since this is | |||
evaluated by scp on the remote host, shell wildcards and | |||
environment variables may be used. | |||
@type remote_path: str | |||
@param local_path: path in which to receive files locally | |||
@type local_path: str | |||
@param recursive: transfer files and directories recursively | |||
@type recursive: bool | |||
@param preserve_times: preserve mtime and atime of transfered files | |||
and directories. | |||
@type preserve_times: bool | |||
""" | |||
if not isinstance(remote_path, (list, tuple)): | |||
remote_path = [remote_path] | |||
remote_path = [self.sanitize(asbytes(r)) for r in remote_path] | |||
self._recv_dir = local_path or os.getcwd() | |||
self._rename = (len(remote_path) == 1 and | |||
not os.path.isdir(os.path.abspath(local_path))) | |||
if len(remote_path) > 1: | |||
if not os.path.exists(self._recv_dir): | |||
raise SCPException("Local path '%s' does not exist" % | |||
asunicode(self._recv_dir)) | |||
elif not os.path.isdir(self._recv_dir): | |||
raise SCPException("Local path '%s' is not a directory" % | |||
asunicode(self._recv_dir)) | |||
rcsv = (b'', b' -r')[recursive] | |||
prsv = (b'', b' -p')[preserve_times] | |||
self.channel = self._open() | |||
self._pushed = 0 | |||
self.channel.settimeout(self.socket_timeout) | |||
self.channel.exec_command(b"scp" + | |||
rcsv + | |||
prsv + | |||
b" -f " + | |||
b' '.join(remote_path)) | |||
self._recv_all() | |||
self.close() | |||
def _open(self): | |||
"""open a scp channel""" | |||
if self.channel is None: | |||
self.channel = self.transport.open_session() | |||
return self.channel | |||
def close(self): | |||
"""close scp channel""" | |||
if self.channel is not None: | |||
self.channel.close() | |||
self.channel = None | |||
def _read_stats(self, name): | |||
"""return just the file stats needed for scp""" | |||
if os.name == 'nt': | |||
name = asunicode(name) | |||
stats = os.stat(name) | |||
mode = oct(stats.st_mode)[-4:] | |||
size = stats.st_size | |||
atime = int(stats.st_atime) | |||
mtime = int(stats.st_mtime) | |||
return (mode, size, mtime, atime) | |||
def _send_files(self, files): | |||
for name in files: | |||
basename = asbytes(os.path.basename(name)) | |||
(mode, size, mtime, atime) = self._read_stats(name) | |||
if self.preserve_times: | |||
self._send_time(mtime, atime) | |||
file_hdl = open(name, 'rb') | |||
# The protocol can't handle \n in the filename. | |||
# Quote them as the control sequence \^J for now, | |||
# which is how openssh handles it. | |||
self.channel.sendall(("C%s %d " % (mode, size)).encode('ascii') + | |||
basename.replace(b'\n', b'\\^J') + b"\n") | |||
self._recv_confirm() | |||
file_pos = 0 | |||
if self._progress: | |||
if size == 0: | |||
# avoid divide-by-zero | |||
self._progress(basename, 1, 1) | |||
else: | |||
self._progress(basename, size, 0) | |||
buff_size = self.buff_size | |||
chan = self.channel | |||
while file_pos < size: | |||
chan.sendall(file_hdl.read(buff_size)) | |||
file_pos = file_hdl.tell() | |||
if self._progress: | |||
self._progress(basename, size, file_pos) | |||
chan.sendall('\x00') | |||
file_hdl.close() | |||
self._recv_confirm() | |||
def _chdir(self, from_dir, to_dir): | |||
# Pop until we're one level up from our next push. | |||
# Push *once* into to_dir. | |||
# This is dependent on the depth-first traversal from os.walk | |||
# add path.sep to each when checking the prefix, so we can use | |||
# path.dirname after | |||
common = os.path.commonprefix([from_dir + bytes_sep, | |||
to_dir + bytes_sep]) | |||
# now take the dirname, since commonprefix is character based, | |||
# and we either have a seperator, or a partial name | |||
common = os.path.dirname(common) | |||
cur_dir = from_dir.rstrip(bytes_sep) | |||
while cur_dir != common: | |||
cur_dir = os.path.split(cur_dir)[0] | |||
self._send_popd() | |||
# now we're in our common base directory, so on | |||
self._send_pushd(to_dir) | |||
def _send_recursive(self, files): | |||
for base in files: | |||
if not os.path.isdir(base): | |||
# filename mixed into the bunch | |||
self._send_files([base]) | |||
continue | |||
last_dir = asbytes(base) | |||
for root, dirs, fls in os.walk(base): | |||
self._chdir(last_dir, asbytes(root)) | |||
self._send_files([os.path.join(root, f) for f in fls]) | |||
last_dir = asbytes(root) | |||
# back out of the directory | |||
while self._pushed > 0: | |||
self._send_popd() | |||
def _send_pushd(self, directory): | |||
(mode, size, mtime, atime) = self._read_stats(directory) | |||
basename = asbytes(os.path.basename(directory)) | |||
if self.preserve_times: | |||
self._send_time(mtime, atime) | |||
self.channel.sendall(('D%s 0 ' % mode).encode('ascii') + | |||
basename.replace(b'\n', b'\\^J') + b'\n') | |||
self._recv_confirm() | |||
self._pushed += 1 | |||
def _send_popd(self): | |||
self.channel.sendall('E\n') | |||
self._recv_confirm() | |||
self._pushed -= 1 | |||
def _send_time(self, mtime, atime): | |||
self.channel.sendall(('T%d 0 %d 0\n' % (mtime, atime)).encode('ascii')) | |||
self._recv_confirm() | |||
def _recv_confirm(self): | |||
# read scp response | |||
msg = b'' | |||
try: | |||
msg = self.channel.recv(512) | |||
except SocketTimeout: | |||
raise SCPException('Timout waiting for scp response') | |||
# slice off the first byte, so this compare will work in py2 and py3 | |||
if msg and msg[0:1] == b'\x00': | |||
return | |||
elif msg and msg[0:1] == b'\x01': | |||
raise SCPException(asunicode(msg[1:])) | |||
elif self.channel.recv_stderr_ready(): | |||
msg = self.channel.recv_stderr(512) | |||
raise SCPException(asunicode(msg)) | |||
elif not msg: | |||
raise SCPException('No response from server') | |||
else: | |||
raise SCPException('Invalid response from server', msg) | |||
def _recv_all(self): | |||
# loop over scp commands, and receive as necessary | |||
command = {b'C': self._recv_file, | |||
b'T': self._set_time, | |||
b'D': self._recv_pushd, | |||
b'E': self._recv_popd} | |||
while not self.channel.closed: | |||
# wait for command as long as we're open | |||
self.channel.sendall('\x00') | |||
msg = self.channel.recv(1024) | |||
if not msg: # chan closed while recving | |||
break | |||
assert msg[-1:] == b'\n' | |||
msg = msg[:-1] | |||
code = msg[0:1] | |||
try: | |||
command[code](msg[1:]) | |||
except KeyError: | |||
raise SCPException(asunicode(msg[1:])) | |||
# directory times can't be set until we're done writing files | |||
self._set_dirtimes() | |||
def _set_time(self, cmd): | |||
try: | |||
times = cmd.split(b' ') | |||
mtime = int(times[0]) | |||
atime = int(times[2]) or mtime | |||
except: | |||
self.channel.send(b'\x01') | |||
raise SCPException('Bad time format') | |||
# save for later | |||
self._utime = (atime, mtime) | |||
def _recv_file(self, cmd): | |||
chan = self.channel | |||
parts = cmd.strip().split(b' ', 2) | |||
try: | |||
mode = int(parts[0], 8) | |||
size = int(parts[1]) | |||
if self._rename: | |||
path = self._recv_dir | |||
self._rename = False | |||
elif os.name == 'nt': | |||
path = os.path.join(asunicode_win(self._recv_dir), | |||
parts[2].decode('utf-8')) | |||
else: | |||
path = os.path.join(asbytes(self._recv_dir), | |||
parts[2]) | |||
except: | |||
chan.send('\x01') | |||
chan.close() | |||
raise SCPException('Bad file format') | |||
try: | |||
file_hdl = open(path, 'wb') | |||
except IOError as e: | |||
chan.send(b'\x01' + str(e).encode('utf-8')) | |||
chan.close() | |||
raise | |||
if self._progress: | |||
if size == 0: | |||
# avoid divide-by-zero | |||
self._progress(path, 1, 1) | |||
else: | |||
self._progress(path, size, 0) | |||
buff_size = self.buff_size | |||
pos = 0 | |||
chan.send(b'\x00') | |||
try: | |||
while pos < size: | |||
# we have to make sure we don't read the final byte | |||
if size - pos <= buff_size: | |||
buff_size = size - pos | |||
file_hdl.write(chan.recv(buff_size)) | |||
pos = file_hdl.tell() | |||
if self._progress: | |||
self._progress(path, size, pos) | |||
msg = chan.recv(512) | |||
if msg and msg[0:1] != b'\x00': | |||
raise SCPException(asunicode(msg[1:])) | |||
except SocketTimeout: | |||
chan.close() | |||
raise SCPException('Error receiving, socket.timeout') | |||
file_hdl.truncate() | |||
try: | |||
os.utime(path, self._utime) | |||
self._utime = None | |||
os.chmod(path, mode) | |||
# should we notify the other end? | |||
finally: | |||
file_hdl.close() | |||
# '\x00' confirmation sent in _recv_all | |||
def _recv_pushd(self, cmd): | |||
parts = cmd.split(b' ', 2) | |||
try: | |||
mode = int(parts[0], 8) | |||
if self._rename: | |||
path = self._recv_dir | |||
self._rename = False | |||
elif os.name == 'nt': | |||
path = os.path.join(asunicode_win(self._recv_dir), | |||
parts[2].decode('utf-8')) | |||
else: | |||
path = os.path.join(asbytes(self._recv_dir), | |||
parts[2]) | |||
except: | |||
self.channel.send(b'\x01') | |||
raise SCPException('Bad directory format') | |||
try: | |||
if not os.path.exists(path): | |||
os.mkdir(path, mode) | |||
elif os.path.isdir(path): | |||
os.chmod(path, mode) | |||
else: | |||
raise SCPException('%s: Not a directory' % path) | |||
self._dirtimes[path] = (self._utime) | |||
self._utime = None | |||
self._recv_dir = path | |||
except (OSError, SCPException) as e: | |||
self.channel.send(b'\x01' + asbytes(str(e))) | |||
raise | |||
def _recv_popd(self, *cmd): | |||
self._recv_dir = os.path.split(self._recv_dir)[0] | |||
def _set_dirtimes(self): | |||
try: | |||
for d in self._dirtimes: | |||
os.utime(d, self._dirtimes[d]) | |||
finally: | |||
self._dirtimes = {} | |||
class SCPException(Exception): | |||
"""SCP exception class""" | |||
pass |