diff --git a/blobsend/cli.py b/blobsend/cli.py index eb0ae47..2a66e31 100644 --- a/blobsend/cli.py +++ b/blobsend/cli.py @@ -16,13 +16,14 @@ SCHEMES = { } -def get_client(uri, is_src): +def get_client(uri, extra_args, is_src): clss = SCHEMES[uri.scheme or "file"] - return clss.from_uri(uri, is_src) + return clss.from_uri(uri, extra_args, is_src) def get_args(): parser = argparse.ArgumentParser(description="file blob copy utility") + parser.add_argument("-o", "--options", nargs="+", help="extra arguments") parser.add_argument("src", help="source file uri") parser.add_argument("dest", help="dest file uri") @@ -87,10 +88,14 @@ def send_parallel(src, dest): def main(): args, parser = get_args() - print(args) - src = get_client(urlparse(args.src), True) - dest = get_client(urlparse(args.dest), False) + extra_args = {} + for arg in args.options or []: + name, value = arg.split("=", 1) + extra_args[name] = value + + src = get_client(urlparse(args.src), extra_args, True) + dest = get_client(urlparse(args.dest), extra_args, False) # send(src, dest) send_parallel(src, dest) diff --git a/blobsend/client_file.py b/blobsend/client_file.py index 9307f04..aabc34b 100644 --- a/blobsend/client_file.py +++ b/blobsend/client_file.py @@ -68,7 +68,7 @@ class FileChunkClient(BaseChunkClient): self.fpool.close() @staticmethod - def from_uri(uri, is_src): + def from_uri(uri, extra_args, is_src): """ instantiate a client from the given uri """ diff --git a/blobsend/client_ssh.py b/blobsend/client_ssh.py index 7a8f2da..616c8ad 100644 --- a/blobsend/client_ssh.py +++ b/blobsend/client_ssh.py @@ -1,3 +1,4 @@ +import os import paramiko from blobsend.client_base import BaseChunkClient from blobsend import CHUNK_SIZE, hash_chunk, FilePool @@ -12,15 +13,23 @@ REMOTE_UTILITY = "/Users/dave/code/blobsend/testenv/bin/_blobsend_ssh_remote"# class SshChunkClient(BaseChunkClient): - def __init__(self, server, username, password, fpath, is_src, chunk_size=CHUNK_SIZE): + def __init__(self, server, username, fpath, is_src, chunk_size=CHUNK_SIZE, password=None, sshkey=None): super().__init__(chunk_size) self.fpath = fpath self.ssh = paramiko.SSHClient() self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - self.ssh.connect(hostname=server, - username=username, - password=password) + connect_args = dict( + hostname=server, + username=username + ) + if sshkey: + print("using ssh key", sshkey) + connect_args.update(pkey=paramiko.RSAKey.from_private_key_file(sshkey)) + else: + connect_args.update(password=password) + + self.ssh.connect(**connect_args) self.sftp = self.ssh.open_sftp() # If the file doesnt exist and we are the destination, create it @@ -80,23 +89,31 @@ class SshChunkClient(BaseChunkClient): self.fpool.close() @staticmethod - def from_uri(uri, is_src): + def from_uri(uri, extra_args, is_src): """ instantiate a client from the given uri """ - return SshChunkClient(uri.hostname, uri.username, uri.password, uri.path, is_src) + return SshChunkClient(uri.hostname, uri.username, uri.path, is_src, password=uri.password, sshkey=extra_args.get("sshkey")) class SshChunkClientParallelConnections(BaseChunkClient): - def __init__(self, server, username, password, fpath, is_src, chunk_size=CHUNK_SIZE): + def __init__(self, server, username, fpath, is_src, chunk_size=CHUNK_SIZE, password=None, sshkey=None): super().__init__(chunk_size) self.fpath = fpath self.ssh = paramiko.SSHClient() self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - self.ssh.connect(hostname=server, - username=username, - password=password) + connect_args = dict( + hostname=server, + username=username + ) + if sshkey: + print("using ssh key", sshkey) + connect_args.update(pkey=paramiko.RSAKey.from_private_key_file(sshkey)) + else: + connect_args.update(password=password) + + self.ssh.connect(**connect_args) self.sftp = self.ssh.open_sftp() # If the file doesnt exist and we are the destination, create it @@ -111,9 +128,7 @@ class SshChunkClientParallelConnections(BaseChunkClient): def _fpool_open(): ssh = paramiko.SSHClient() ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - ssh.connect(hostname=server, - username=username, - password=password) + ssh.connect(**connect_args) sftp = ssh.open_sftp() f = sftp.open(self.fpath, "r+") @@ -162,11 +177,11 @@ class SshChunkClientParallelConnections(BaseChunkClient): self.fpool.close() @staticmethod - def from_uri(uri, is_src): + def from_uri(uri, extra_args, is_src): """ instantiate a client from the given uri """ - return SshChunkClientParallelConnections(uri.hostname, uri.username, uri.password, uri.path, is_src) + return SshChunkClientParallelConnections(uri.hostname, uri.username, uri.path, is_src, password=uri.password, sshkey=extra_args.get("sshkey")) class SshChunkClientOld(BaseChunkClient):