diff --git a/blobsend/__init__.py b/blobsend/__init__.py index 7b2936e..1e46e51 100644 --- a/blobsend/__init__.py +++ b/blobsend/__init__.py @@ -1,10 +1,59 @@ import hashlib +from threading import Semaphore CHUNK_SIZE = 1024 * 1024 * 4 # 4 mb chunks def hash_chunk(data): - h = hashlib.md5() + h = hashlib.sha256() h.update(data) return h.hexdigest() + + +class FilePool(object): + def __init__(self, instances, open_function): + """ + Pool of pre-opened file handles + """ + self.instances = [open_function() for i in range(0, instances)] + self.lock = Semaphore(instances) + + def get(self): + """ + Return a context-manager that contains the file object + """ + return FilePoolSlice(self) + + def close(self): + for f in self.instances: + f.close() + + def _get_locked(self): + """ + Acquire and return an instance from the pool + """ + self.lock.acquire() + return self.instances.pop() + + def _return_locked(self, instance): + """ + Release an instance previously acquired via _get_locked + """ + self.instances.append(instance) + self.lock.release() + + +class FilePoolSlice(object): + def __init__(self, pool): + self.pool = pool + self.mine = None + + def __enter__(self): + # TODO: a slice is single-use; don't enter it more than once + self.mine = self.pool._get_locked() + self.mine.seek(0) + return self.mine + + def __exit__(self, exc_type, exc_value, exc_traceback): + self.pool._return_locked(self.mine) diff --git a/blobsend/cli.py b/blobsend/cli.py index 7dd2ebe..eb0ae47 100644 --- a/blobsend/cli.py +++ b/blobsend/cli.py @@ -1,13 +1,18 @@ import argparse from urllib.parse import urlparse +from concurrent.futures import ThreadPoolExecutor from blobsend import CHUNK_SIZE from blobsend.client_file import FileChunkClient +# from blobsend.client_basic_file import FileChunkClient from blobsend.client_ssh import SshChunkClient +# from blobsend.client_ssh import SshChunkClientOld as SshChunkClient +from blobsend.client_ssh import SshChunkClientParallelConnections SCHEMES = { "file": FileChunkClient, "ssh": SshChunkClient, + 
"pssh": SshChunkClientParallelConnections, } @@ -24,13 +29,7 @@ def get_args(): return parser.parse_args(), parser -def main(): - args, parser = get_args() - print(args) - - src = get_client(urlparse(args.src), True) - dest = get_client(urlparse(args.dest), False) - +def send(src, dest): num_chunks = src.get_length() // CHUNK_SIZE dest_hashes_iter = dest.get_hashes() @@ -52,6 +51,50 @@ def main(): dest.set_length(src.get_length()) + +def send_parallel(src, dest): + num_chunks = src.get_length() // CHUNK_SIZE + + dest_hashes_iter = dest.get_hashes() + + def copy_chunk(chunk_number): + print("Copying chunk", chunk_number, "/", num_chunks) + blob = src.get_chunk(chunk_number) + dest.put_chunk(chunk_number, blob) + + with ThreadPoolExecutor(max_workers=10) as pool: + futures = [] + for src_chunk_number, src_chunk_hash in src.get_hashes(): + dest_chunk_number = None + dest_chunk_hash = None + try: + dest_chunk_number, dest_chunk_hash = next(dest_hashes_iter) + except StopIteration: + pass + + if dest_chunk_number is not None and src_chunk_number != dest_chunk_number: + raise Exception("sequence mismatch?") + + # print("chunk", src_chunk_number, src_chunk_hash, "vs", dest_chunk_hash) + + if src_chunk_hash != dest_chunk_hash: + futures.append(pool.submit(copy_chunk, src_chunk_number)) + for f in futures: + f.result() + + dest.set_length(src.get_length()) + + +def main(): + args, parser = get_args() + print(args) + + src = get_client(urlparse(args.src), True) + dest = get_client(urlparse(args.dest), False) + + # send(src, dest) + send_parallel(src, dest) + src.close() dest.close() diff --git a/blobsend/client_file.py b/blobsend/client_file.py index 7cfa802..9307f04 100644 --- a/blobsend/client_file.py +++ b/blobsend/client_file.py @@ -1,18 +1,27 @@ import os from blobsend.client_base import BaseChunkClient -from blobsend import CHUNK_SIZE, hash_chunk +from blobsend import CHUNK_SIZE, hash_chunk, FilePool class FileChunkClient(BaseChunkClient): - def __init__(self, fpath, 
chunk_size=CHUNK_SIZE): + def __init__(self, fpath, is_src, chunk_size=CHUNK_SIZE): super().__init__(chunk_size) self.fpath = fpath - self.file = open(self.fpath, "ab+") # for get chunk operations, this generic file is used instead of doing lots of open/close - self.file.seek(0) + + if not is_src and not os.path.exists(self.fpath): + with open(self.fpath, "wb"): + pass + + def _fpool_open(): + f = open(self.fpath, "rb+") # pooled handle: avoids re-opening the file for every chunk operation + f.seek(0) + return f + + self.fpool = FilePool(10, _fpool_open) def get_hashes(self): i = 0 - with open(self.fpath, "rb+") as f: + with self.fpool.get() as f: while True: data = f.read(self.chunk_size) if not data: @@ -27,8 +36,9 @@ class FileChunkClient(BaseChunkClient): position = chunk_number * self.chunk_size if position > os.path.getsize(self.fpath):#TODO not sure if > or >= raise Exception("requested chunk {} is beyond EOF".format(chunk_number)) - self.file.seek(position)#TODO not thread safe - return self.file.read(self.chunk_size) + with self.fpool.get() as f: + f.seek(position) # each pool slice is a private handle, so concurrent reads are safe + return f.read(self.chunk_size) def put_chunk(self, chunk_number, contents): """ @@ -43,21 +53,23 @@ class FileChunkClient(BaseChunkClient): """ get the file size """ - self.file.seek(0, 2) # seek to end - return self.file.tell() + with self.fpool.get() as f: + f.seek(0, 2) # seek to end + return f.tell() def set_length(self, length): if length < self.get_length(): - self.file.truncate(length) + with self.fpool.get() as f: + f.truncate(length) # do nothing for the case of extending the file # put_chunk handles it def close(self): - self.file.close() + self.fpool.close() @staticmethod def from_uri(uri, is_src): """ instantiate a client from the given uri """ - return FileChunkClient(uri.path) + return FileChunkClient(uri.path, is_src) diff --git a/blobsend/client_ssh.py b/blobsend/client_ssh.py index c456e48..7a8f2da 100644 --- a/blobsend/client_ssh.py 
+++ b/blobsend/client_ssh.py @@ -1,6 +1,6 @@ import paramiko from blobsend.client_base import BaseChunkClient -from blobsend import CHUNK_SIZE, hash_chunk +from blobsend import CHUNK_SIZE, hash_chunk, FilePool """ @@ -12,6 +12,164 @@ REMOTE_UTILITY = "/Users/dave/code/blobsend/testenv/bin/_blobsend_ssh_remote"# class SshChunkClient(BaseChunkClient): + def __init__(self, server, username, password, fpath, is_src, chunk_size=CHUNK_SIZE): + super().__init__(chunk_size) + self.fpath = fpath + self.ssh = paramiko.SSHClient() + self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + self.ssh.connect(hostname=server, + username=username, + password=password) + + self.sftp = self.ssh.open_sftp() + + # If the file doesnt exist and we are the destination, create it + if not is_src: + try: + with self.sftp.open(self.fpath, "r"): + pass + except FileNotFoundError: + with self.sftp.open(self.fpath, "wb"): + pass + + def _fpool_open(): + sftp = self.ssh.open_sftp() + f = sftp.open(self.fpath, "r+") + f.seek(0) + return f + + self.fpool = FilePool(8, _fpool_open) + + def get_hashes(self): + stdin, stdout, stderr = self.ssh.exec_command("{} chunks {}".format(REMOTE_UTILITY, self.fpath))#TODO safe arg escapes + stdin.close() + for line in iter(lambda: stdout.readline(1024), ""): + chunk_number, chunk_hash = line.strip().split(" ") + yield (int(chunk_number), chunk_hash, ) + exit = stdout.channel.recv_exit_status() + if exit != 0: + raise Exception("hash command exit code was {}: {}".format(exit, stderr.read())) + + def get_chunk(self, chunk_number): + position = chunk_number * self.chunk_size + if position > self.get_length(): + raise Exception("requested chunk {} is beyond EOF".format(chunk_number)) + with self.fpool.get() as f: + f.seek(position) + return f.read(self.chunk_size) + + def put_chunk(self, chunk_number, contents): + position = chunk_number * self.chunk_size + with self.fpool.get() as f: + f.seek(position) + f.write(contents) + + def get_length(self): + with 
self.fpool.get() as f: + f.seek(0, 2) # seek to end + return f.tell() + + def set_length(self, length): + if length < self.get_length(): + with self.fpool.get() as f: + f.truncate(length) + # do nothing for the case of extending the file + # put_chunk handles it + + def close(self): + self.fpool.close() + + @staticmethod + def from_uri(uri, is_src): + """ + instantiate a client from the given uri + """ + return SshChunkClient(uri.hostname, uri.username, uri.password, uri.path, is_src) + + +class SshChunkClientParallelConnections(BaseChunkClient): + def __init__(self, server, username, password, fpath, is_src, chunk_size=CHUNK_SIZE): + super().__init__(chunk_size) + self.fpath = fpath + self.ssh = paramiko.SSHClient() + self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + self.ssh.connect(hostname=server, + username=username, + password=password) + + self.sftp = self.ssh.open_sftp() + + # If the file doesnt exist and we are the destination, create it + if not is_src: + try: + with self.sftp.open(self.fpath, "r"): + pass + except FileNotFoundError: + with self.sftp.open(self.fpath, "wb"): + pass + + def _fpool_open(): + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect(hostname=server, + username=username, + password=password) + + sftp = ssh.open_sftp() + f = sftp.open(self.fpath, "r+") + f.seek(0) + return f + + self.fpool = FilePool(8, _fpool_open) + + def get_hashes(self): + stdin, stdout, stderr = self.ssh.exec_command("{} chunks {}".format(REMOTE_UTILITY, self.fpath))#TODO safe arg escapes + stdin.close() + for line in iter(lambda: stdout.readline(1024), ""): + chunk_number, chunk_hash = line.strip().split(" ") + yield (int(chunk_number), chunk_hash, ) + exit = stdout.channel.recv_exit_status() + if exit != 0: + raise Exception("hash command exit code was {}: {}".format(exit, stderr.read())) + + def get_chunk(self, chunk_number): + position = chunk_number * self.chunk_size + if position > 
self.get_length(): + raise Exception("requested chunk {} is beyond EOF".format(chunk_number)) + with self.fpool.get() as f: + f.seek(position) + return f.read(self.chunk_size) + + def put_chunk(self, chunk_number, contents): + position = chunk_number * self.chunk_size + with self.fpool.get() as f: + f.seek(position) + f.write(contents) + + def get_length(self): + with self.fpool.get() as f: + f.seek(0, 2) # seek to end + return f.tell() + + def set_length(self, length): + if length < self.get_length(): + with self.fpool.get() as f: + f.truncate(length) + # do nothing for the case of extending the file + # put_chunk handles it + + def close(self): + self.fpool.close() + + @staticmethod + def from_uri(uri, is_src): + """ + instantiate a client from the given uri + """ + return SshChunkClientParallelConnections(uri.hostname, uri.username, uri.password, uri.path, is_src) + + +class SshChunkClientOld(BaseChunkClient): def __init__(self, server, username, password, fpath, is_src, chunk_size=CHUNK_SIZE): super().__init__(chunk_size) self.fpath = fpath @@ -32,8 +190,6 @@ class SshChunkClient(BaseChunkClient): with self.sftp.open(self.fpath, "wb") as f: pass - # it seems like mode "ab+" doesn't work the same way under paramiko - # it refuses to seek before the open point (which is the end of the file) self.file = self.sftp.open(self.fpath, "r+") def get_hashes(self): @@ -42,6 +198,9 @@ class SshChunkClient(BaseChunkClient): for line in iter(lambda: stdout.readline(1024), ""): chunk_number, chunk_hash = line.strip().split(" ") yield (int(chunk_number), chunk_hash, ) + exit = stdout.channel.recv_exit_status() + if exit != 0: + raise Exception("hash command exit code was {}: {}".format(exit, stderr.read())) def get_chunk(self, chunk_number): position = chunk_number * self.chunk_size @@ -61,7 +220,8 @@ class SshChunkClient(BaseChunkClient): def set_length(self, length): if length < self.get_length(): - self.file.truncate(length) + f = self.file + 
f.truncate(length) # do nothing for the case of extending the file # put_chunk handles it @@ -73,4 +233,4 @@ class SshChunkClient(BaseChunkClient): """ instantiate a client from the given uri """ - return SshChunkClient(uri.hostname, uri.username, uri.password, uri.path, is_src) + return SshChunkClientOld(uri.hostname, uri.username, uri.password, uri.path, is_src) diff --git a/blobsend/client_ssh_remote.py b/blobsend/client_ssh_remote.py index b6c366d..6ab798f 100644 --- a/blobsend/client_ssh_remote.py +++ b/blobsend/client_ssh_remote.py @@ -6,7 +6,7 @@ from blobsend.client_file import FileChunkClient def cmd_chunks(args, parser): - c = FileChunkClient(args.fpath) + c = FileChunkClient(args.fpath, False) for chunk_number, chunk_hash in c.get_hashes(): print(chunk_number, chunk_hash) sys.stdout.flush()