# blobsend/blobsend/cli.py
#
# 109 lines
# 3.2 KiB
# Python

import argparse
from urllib.parse import urlparse
from concurrent.futures import ThreadPoolExecutor
from blobsend import CHUNK_SIZE
from blobsend.client_file import FileChunkClient
# from blobsend.client_basic_file import FileChunkClient
from blobsend.client_ssh import SshChunkClient
# from blobsend.client_ssh import SshChunkClientOld as SshChunkClient
from blobsend.client_ssh import SshChunkClientParallelConnections
# Registry mapping a URI scheme to the chunk-client class that serves it.
# get_client() looks the scheme up here and treats a URI without a scheme
# as "file".
SCHEMES = {
    # Local file access.
    "file": FileChunkClient,
    # Remote access over SSH.
    "ssh": SshChunkClient,
    # SSH variant; judging by the class name it uses several parallel
    # connections -- confirm against blobsend.client_ssh.
    "pssh": SshChunkClientParallelConnections,
}
def get_client(uri, extra_args, is_src):
    """Build the chunk client that matches the scheme of *uri*.

    Args:
        uri: A parsed URI exposing a ``scheme`` attribute (main() passes the
            result of ``urllib.parse.urlparse``). An empty scheme is treated
            as ``"file"``.
        extra_args: Dict of extra name/value options, forwarded unchanged to
            the client's ``from_uri``.
        is_src: True when the client is for the source side of the copy,
            False for the destination; forwarded unchanged to ``from_uri``.

    Returns:
        Whatever the selected class's ``from_uri(uri, extra_args, is_src)``
        returns.

    Raises:
        KeyError: If the scheme is not registered in SCHEMES. The message
            names the offending scheme and lists the supported ones.
    """
    scheme = uri.scheme or "file"
    try:
        client_class = SCHEMES[scheme]
    except KeyError:
        # Same exception type as the bare dict lookup, but with a message
        # the user can act on instead of just the missing key.
        raise KeyError(
            "unsupported URI scheme %r (expected one of: %s)"
            % (scheme, ", ".join(sorted(SCHEMES)))
        ) from None
    return client_class.from_uri(uri, extra_args, is_src)
def get_args(argv=None):
    """Parse the command line of the blob copy utility.

    Args:
        argv: List of argument strings to parse, without the program name.
            When None (the default) argparse reads ``sys.argv[1:]``.

    Returns:
        A ``(namespace, parser)`` tuple. The namespace carries ``src``,
        ``dest`` and ``options`` (a list of strings, or None when ``-o`` was
        not given). The parser is returned as well so the caller can report
        problems through ``parser.error()``.
    """
    parser = argparse.ArgumentParser(description="file blob copy utility")
    # nargs="+" is greedy: every plain argument following -o is taken as an
    # option, so -o has to come after the two positional URIs.
    parser.add_argument("-o", "--options", nargs="+", help="extra arguments")
    parser.add_argument("src", help="source file uri")
    parser.add_argument("dest", help="dest file uri")
    return parser.parse_args(argv), parser
def send(src, dest):
    """Copy the chunks of *src* that differ from *dest*, one at a time.

    The hash streams of both sides are walked in lockstep. A chunk is
    transferred when its source hash differs from the destination hash, or
    when the destination has no hash left for that position. Afterwards the
    destination length is set to the source length.

    Args:
        src: Source client; must provide ``get_length()``, ``get_hashes()``
            (yielding ``(chunk_number, chunk_hash)`` pairs) and
            ``get_chunk(chunk_number)``.
        dest: Destination client; must provide ``get_hashes()``,
            ``put_chunk(chunk_number, blob)`` and ``set_length(length)``.

    Raises:
        Exception: If the destination reports a chunk number different from
            the source's at the same position in the stream.
    """
    length = src.get_length()
    # Ceiling division: a trailing partial chunk still counts as a chunk, so
    # the total shown in the progress line is the real number of chunks.
    num_chunks = -(-length // CHUNK_SIZE)
    dest_hashes = dest.get_hashes()
    for chunk_number, src_hash in src.get_hashes():
        # The destination may run out of hashes before the source does; the
        # (None, None) default then forces a copy, since a missing hash
        # compares unequal to the source hash (assuming source hashes are
        # never None).
        dest_number, dest_hash = next(dest_hashes, (None, None))
        if dest_number is not None and dest_number != chunk_number:
            raise Exception("sequence mismatch?")
        if src_hash != dest_hash:
            print("Copying chunk", chunk_number, "/", num_chunks)
            dest.put_chunk(chunk_number, src.get_chunk(chunk_number))
    dest.set_length(length)
def send_parallel(src, dest, max_workers=10):
    """Copy the chunks of *src* that differ from *dest* using a thread pool.

    Hashes are compared in the calling thread; every chunk that needs to be
    transferred is handed to a worker, which reads it from *src* and writes
    it to *dest*. ``src.get_chunk`` and ``dest.put_chunk`` are therefore
    called from several threads at once, while the calling thread may still
    be iterating the hash streams. Once all copies have finished, the
    destination length is set to the source length.

    Args:
        src: Source client; must provide ``get_length()``, ``get_hashes()``
            (yielding ``(chunk_number, chunk_hash)`` pairs) and
            ``get_chunk(chunk_number)``.
        dest: Destination client; must provide ``get_hashes()``,
            ``put_chunk(chunk_number, blob)`` and ``set_length(length)``.
        max_workers: Size of the thread pool. Defaults to 10.

    Raises:
        Exception: If the destination reports a chunk number different from
            the source's at the same position in the stream, or whatever a
            failed chunk copy raised. In both cases copies that have not
            started yet are cancelled and the destination length is left
            untouched.
    """
    length = src.get_length()
    # Ceiling division: a trailing partial chunk still counts as a chunk, so
    # the total shown in the progress line is the real number of chunks.
    num_chunks = -(-length // CHUNK_SIZE)
    dest_hashes = dest.get_hashes()

    def copy_chunk(chunk_number):
        print("Copying chunk", chunk_number, "/", num_chunks)
        dest.put_chunk(chunk_number, src.get_chunk(chunk_number))

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = []
        try:
            for chunk_number, src_hash in src.get_hashes():
                # The destination may run out of hashes first; the
                # (None, None) default then forces a copy.
                dest_number, dest_hash = next(dest_hashes, (None, None))
                if dest_number is not None and dest_number != chunk_number:
                    raise Exception("sequence mismatch?")
                if src_hash != dest_hash:
                    futures.append(pool.submit(copy_chunk, chunk_number))
            # Surface the first worker failure instead of dropping it.
            for future in futures:
                future.result()
        except BaseException:
            # Do not keep transferring queued chunks after a failure;
            # cancel() is a no-op for copies already running or finished.
            for future in futures:
                future.cancel()
            raise
    dest.set_length(length)
def main():
    """Command-line entry point: copy the source blob to the destination.

    Parses the command line, turns the ``-o NAME=VALUE`` options into a
    dict, builds one client per URI and runs the parallel transfer. Both
    clients are closed even when the transfer fails.

    An option without ``=`` is reported through ``parser.error()``, which
    prints a usage message and exits.
    """
    args, parser = get_args()
    extra_args = {}
    for arg in args.options or []:
        # Split on the first "=" only, so values may themselves contain "=".
        name, sep, value = arg.partition("=")
        if not sep:
            parser.error("invalid option %r: expected NAME=VALUE" % arg)
        extra_args[name] = value
    src = get_client(urlparse(args.src), extra_args, True)
    try:
        dest = get_client(urlparse(args.dest), extra_args, False)
        try:
            # send(src, dest) is the sequential alternative.
            send_parallel(src, dest)
        finally:
            dest.close()
    finally:
        # Runs even if creating the destination client failed.
        src.close()
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()