diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6bcf4e1..7c19304 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,13 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 
 ## [Unreleased]
 
+## [v0.1.0] - 2019-12-05
+
+### Changed
+- The s3 and stepfunctions modules are now classes; the constructor accepts a boto3 Session. If one is not provided, a default session is created
+- s3.upload now accepts an `http_url` keyword argument. If set to True, the https URL is returned instead of the S3 URL
+- s3.find now returns the complete s3 URL for each found object, not just the key
+
 ## [v0.0.3] - 2019-11-22
 
 ### Added
@@ -17,5 +24,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 Initial Release
 
 [Unreleased]: https://github.com/matthewhanson/boto3-utils/compare/master...develop
+[v0.1.0]: https://github.com/matthewhanson/boto3-utils/compare/0.0.3...0.1.0
 [v0.0.3]: https://github.com/matthewhanson/boto3-utils/compare/0.0.2...0.0.3
 [v0.0.2]: https://github.com/matthewhanson/boto3-utils/tree/0.0.2
diff --git a/boto3utils/__init__.py b/boto3utils/__init__.py
index 7152555..4622384 100644
--- a/boto3utils/__init__.py
+++ b/boto3utils/__init__.py
@@ -1 +1,3 @@
-from .version import __version__
\ No newline at end of file
+from .version import __version__
+
+from .s3 import s3
\ No newline at end of file
diff --git a/boto3utils/s3.py b/boto3utils/s3.py
index f928135..071803e 100644
--- a/boto3utils/s3.py
+++ b/boto3utils/s3.py
@@ -15,175 +15,178 @@ logger = logging.getLogger(__name__)
 
-# s3 client
-s3 = boto3.client('s3')
+class s3(object):
 
-def urlparse(url):
-    """ Split S3 URL into bucket, key, filename """
-    if url[0:5] != 's3://':
-        raise Exception('Invalid S3 url %s' % url)
+    def __init__(self, session=None):
+        if session is None:
+            self.s3 = boto3.client('s3')
+        else:
+            self.s3 = session.client('s3')
 
-    url_obj = url.replace('s3://', '').split('/')
+    @classmethod
+    def urlparse(cls, url):
+        """ Split S3 URL into bucket, key, filename """
+        if url[0:5] != 's3://':
+            raise Exception('Invalid S3 url %s' % url)
 
-    # remove empty items
-    url_obj = list(filter(lambda x: x, url_obj))
-
-    return {
-        'bucket': url_obj[0],
-        'key': '/'.join(url_obj[1:]),
-        'filename': url_obj[-1] if len(url_obj) > 1 else ''
-    }
-
-
-def s3_to_https(url, region=getenv('AWS_REGION', getenv('AWS_DEFAULT_REGION', 'us-east-1'))):
-    """ Convert an s3 URL to an s3 https URL """
-    parts = urlparse(url)
-    return 'https://%s.s3.%s.amazonaws.com/%s' % (parts['bucket'], region, parts['key'])
+        url_obj = url.replace('s3://', '').split('/')
 
+        # remove empty items
+        url_obj = list(filter(lambda x: x, url_obj))
 
-def exists(url):
-    """ Check if this URL exists on S3 """
-    parts = urlparse(url)
-    try:
-        s3.head_object(Bucket=parts['bucket'], Key=parts['key'])
-        return True
-    except ClientError as exc:
-        if exc.response['Error']['Code'] != '404':
-            raise
-        return False
+        return {
+            'bucket': url_obj[0],
+            'key': '/'.join(url_obj[1:]),
+            'filename': url_obj[-1] if len(url_obj) > 1 else ''
+        }
 
+    @classmethod
+    def s3_to_https(cls, url, region=getenv('AWS_REGION', getenv('AWS_DEFAULT_REGION', 'us-east-1'))):
+        """ Convert an s3 URL to an s3 https URL """
+        parts = cls.urlparse(url)
+        return 'https://%s.s3.%s.amazonaws.com/%s' % (parts['bucket'], region, parts['key'])
 
-def upload(filename, uri, public=False, extra={}):
-    """ Upload object to S3 uri (bucket + prefix), keeping same base filename """
-    logger.debug("Uploading %s to %s" % (filename, uri))
-    s3_uri = urlparse(uri)
-    uri_out = 's3://%s' % op.join(s3_uri['bucket'], s3_uri['key'])
-    if public:
-        extra['ACL'] = 'public-read'
-    with open(filename, 'rb') as data:
-        s3.upload_fileobj(data, s3_uri['bucket'], s3_uri['key'], ExtraArgs=extra)
-    return uri_out
-
-
-def download(uri, path=''):
-    """
-    Download object from S3
-    :param uri: URI of object to download
-    :param path: Output path
-    """
-    s3_uri = urlparse(uri)
-    fout = op.join(path, s3_uri['filename'])
-    logger.debug("Downloading %s as %s" % (uri, fout))
-    if path != '':
-        makedirs(path, exist_ok=True)
-
-    with open(fout, 'wb') as f:
-        s3.download_fileobj(
-            Bucket=s3_uri['bucket'],
-            Key=s3_uri['key'],
-            Fileobj=f
-        )
-    return fout
-
-
-def read(url):
-    """ Read object from s3 """
-    parts = urlparse(url)
-    response = s3.get_object(Bucket=parts['bucket'], Key=parts['key'])
-    body = response['Body'].read()
-    if op.splitext(parts['key'])[1] == '.gz':
-        body = GzipFile(None, 'rb', fileobj=BytesIO(body)).read()
-    return body.decode('utf-8')
-
-
-def read_json(url):
-    """ Download object from S3 as JSON """
-    return json.loads(read(url))
-
-
-# function derived from https://alexwlchan.net/2018/01/listing-s3-keys-redux/
-def find(url, suffix=''):
-    """
-    Generate objects in an S3 bucket.
-    :param url: The beginning part of the URL to match (bucket + optional prefix)
-    :param suffix: Only fetch objects whose keys end with this suffix.
-    """
-    parts = urlparse(url)
-    kwargs = {'Bucket': parts['bucket']}
-
-    # If the prefix is a single string (not a tuple of strings), we can
-    # do the filtering directly in the S3 API.
-    if isinstance(parts['key'], str):
-        kwargs['Prefix'] = parts['key']
-
-    while True:
-        # The S3 API response is a large blob of metadata.
-        # 'Contents' contains information about the listed objects.
-        resp = s3.list_objects_v2(**kwargs)
-        try:
-            contents = resp['Contents']
-        except KeyError:
-            return
-
-        for obj in contents:
-            key = obj['Key']
-            if key.startswith(parts['key']) and key.endswith(suffix):
-                yield obj['Key']
-
-        # The S3 API is paginated, returning up to 1000 keys at a time.
-        # Pass the continuation token into the next response, until we
-        # reach the final page (when this field is missing).
+    def exists(self, url):
+        """ Check if this URL exists on S3 """
+        parts = self.urlparse(url)
         try:
-            kwargs['ContinuationToken'] = resp['NextContinuationToken']
-        except KeyError:
-            break
-
-
-def latest_inventory(url, prefix=None, suffix=None, start_date=None, end_date=None, datetime_key='LastModifiedDate'):
-    """ Return generator function for objects in Bucket with suffix (all files if suffix=None) """
-    parts = urlparse(url)
-    # get latest manifest file
-    today = datetime.now()
-    manifest_key = None
-    for dt in [today, today - timedelta(1)]:
-        _key = op.join(parts['key'], dt.strftime('%Y-%m-%d'))
-        _url = 's3://%s/%s' % (parts['bucket'], _key)
-        keys = [k for k in find(_url, suffix='manifest.json')]
-        if len(keys) == 1:
-            manifest_key = keys[0]
-            break
-    # read through latest manifest looking for matches
-    if manifest_key:
-        _url = 's3://%s/%s' % (parts['bucket'], manifest_key)
-        manifest = read_json(_url)
-        # get file schema
-        keys = [str(key).strip() for key in manifest['fileSchema'].split(',')]
-
-        logger.info('Getting latest inventory from %s' % url)
-        counter = 0
-        for f in manifest.get('files', []):
-            _url = 's3://%s/%s' % (parts['bucket'], f['key'])
-            inv = read(_url).split('\n')
-            for line in inv:
-                counter += 1
-                if counter % 100000 == 0:
-                    logger.debug('%s: Scanned %s records' % (datetime.now(), str(counter)))
-                info = {keys[i]: v for i, v in enumerate(line.replace('"', '').split(','))}
-                if 'Key' not in info:
-                    continue
-                # skip to next if last modified date not between start_date and end_date
-                dt = datetime.strptime(info[datetime_key], "%Y-%m-%dT%H:%M:%S.%fZ").date()
-                if (start_date is not None and dt < start_date) or (end_date is not None and dt > end_date):
-                    continue
-                if prefix is not None:
-                    # if path doesn't match provided prefix skip to next record
-                    if info['Key'][:len(prefix)] != prefix:
+            self.s3.head_object(Bucket=parts['bucket'], Key=parts['key'])
+            return True
+        except ClientError as exc:
+            if exc.response['Error']['Code'] != '404':
+                raise
+            return False
+
+    def upload(self, filename, url, public=False, extra={}, http_url=False):
+        """ Upload object to S3 uri (bucket + prefix), keeping same base filename """
+        logger.debug("Uploading %s to %s" % (filename, url))
+        parts = self.urlparse(url)
+        url_out = 's3://%s' % op.join(parts['bucket'], parts['key'])
+        if public:
+            extra['ACL'] = 'public-read'
+        with open(filename, 'rb') as data:
+            self.s3.upload_fileobj(data, parts['bucket'], parts['key'], ExtraArgs=extra)
+        if http_url:
+            region = self.s3.get_bucket_location(Bucket=parts['bucket'])['LocationConstraint']
+            return self.s3_to_https(url_out, region)
+        else:
+            return url_out
+
+    def download(self, uri, path=''):
+        """
+        Download object from S3
+        :param uri: URI of object to download
+        :param path: Output path
+        """
+        s3_uri = self.urlparse(uri)
+        fout = op.join(path, s3_uri['filename'])
+        logger.debug("Downloading %s as %s" % (uri, fout))
+        if path != '':
+            makedirs(path, exist_ok=True)
+
+        with open(fout, 'wb') as f:
+            self.s3.download_fileobj(
+                Bucket=s3_uri['bucket'],
+                Key=s3_uri['key'],
+                Fileobj=f
+            )
+        return fout
+
+    def read(self, url):
+        """ Read object from s3 """
+        parts = self.urlparse(url)
+        response = self.s3.get_object(Bucket=parts['bucket'], Key=parts['key'])
+        body = response['Body'].read()
+        if op.splitext(parts['key'])[1] == '.gz':
+            body = GzipFile(None, 'rb', fileobj=BytesIO(body)).read()
+        return body.decode('utf-8')
+
+    def read_json(self, url):
+        """ Download object from S3 as JSON """
+        return json.loads(self.read(url))
+
+
+    # function derived from https://alexwlchan.net/2018/01/listing-s3-keys-redux/
+    def find(self, url, suffix=''):
+        """
+        Generate objects in an S3 bucket.
+        :param url: The beginning part of the URL to match (bucket + optional prefix)
+        :param suffix: Only fetch objects whose keys end with this suffix.
+        """
+        parts = self.urlparse(url)
+        kwargs = {'Bucket': parts['bucket']}
+
+        # If the prefix is a single string (not a tuple of strings), we can
+        # do the filtering directly in the S3 API.
+        if isinstance(parts['key'], str):
+            kwargs['Prefix'] = parts['key']
+
+        while True:
+            # The S3 API response is a large blob of metadata.
+            # 'Contents' contains information about the listed objects.
+            resp = self.s3.list_objects_v2(**kwargs)
+            try:
+                contents = resp['Contents']
+            except KeyError:
+                return
+
+            for obj in contents:
+                key = obj['Key']
+                if key.startswith(parts['key']) and key.endswith(suffix):
+                    yield f"s3://{parts['bucket']}/{obj['Key']}"
+
+            # The S3 API is paginated, returning up to 1000 keys at a time.
+            # Pass the continuation token into the next response, until we
+            # reach the final page (when this field is missing).
+            try:
+                kwargs['ContinuationToken'] = resp['NextContinuationToken']
+            except KeyError:
+                break
+
+    def latest_inventory(self, url, prefix=None, suffix=None, start_date=None, end_date=None, datetime_key='LastModifiedDate'):
+        """ Return generator function for objects in Bucket with suffix (all files if suffix=None) """
+        parts = self.urlparse(url)
+        # get latest manifest file
+        today = datetime.now()
+        manifest_url = None
+        for dt in [today, today - timedelta(1)]:
+            _key = op.join(parts['key'], dt.strftime('%Y-%m-%d'))
+            _url = 's3://%s/%s' % (parts['bucket'], _key)
+            manifests = [k for k in self.find(_url, suffix='manifest.json')]
+            if len(manifests) == 1:
+                manifest_url = manifests[0]
+                break
+        # read through latest manifest looking for matches
+        if manifest_url:
+            manifest = self.read_json(manifest_url)
+            # get file schema
+            keys = [str(key).strip() for key in manifest['fileSchema'].split(',')]
+
+            logger.info('Getting latest inventory from %s' % url)
+            counter = 0
+            for f in manifest.get('files', []):
+                _url = 's3://%s/%s' % (parts['bucket'], f['key'])
+                inv = self.read(_url).split('\n')
+                for line in inv:
+                    counter += 1
+                    if counter % 100000 == 0:
+                        logger.debug('%s: Scanned %s records' % (datetime.now(), str(counter)))
+                    info = {keys[i]: v for i, v in enumerate(line.replace('"', '').split(','))}
+                    if 'Key' not in info:
+                        continue
+                    # skip to next if last modified date not between start_date and end_date
+                    dt = datetime.strptime(info[datetime_key], "%Y-%m-%dT%H:%M:%S.%fZ").date()
+                    if (start_date is not None and dt < start_date) or (end_date is not None and dt > end_date):
                         continue
-                if suffix is None or info['Key'].endswith(suffix):
-                    if 'Bucket' in keys and 'Key' in keys:
-                        info['url'] = 's3://%s/%s' % (info['Bucket'], info['Key'])
-                    yield info
+                    if prefix is not None:
+                        # if path doesn't match provided prefix skip to next record
+                        if info['Key'][:len(prefix)] != prefix:
+                            continue
+                    if suffix is None or info['Key'].endswith(suffix):
+                        if 'Bucket' in keys and 'Key' in keys:
+                            info['url'] = 's3://%s/%s' % (info['Bucket'], info['Key'])
+                        yield info
 
 
 def get_presigned_url(url, aws_region=None,
@@ -199,7 +202,7 @@ def get_presigned_url(url, aws_region=None,
         logger.debug('Not using signed URL for %s' % url)
         return url, None
 
-    parts = urlparse(url)
+    parts = s3.urlparse(url)
     bucket = parts['bucket']
     key = parts['key']
 
diff --git a/boto3utils/stepfunctions.py b/boto3utils/stepfunctions.py
index dd2362a..420df35 100644
--- a/boto3utils/stepfunctions.py
+++ b/boto3utils/stepfunctions.py
@@ -9,33 +9,38 @@ logger = logging.getLogger(__name__)
 
-# module level client
-sfn = boto3.client('stepfunctions', config=Config(read_timeout=70))
+class stepfunctions(object):
 
+    def __init__(self, session=None):
+        config = Config(read_timeout=70)
+        if session is None:
+            self.sfn = boto3.client('stepfunctions', config=config)
+        else:
+            self.sfn = session.client('stepfunctions', config=config)
 
-def run_activity(process, arn, **kwargs):
-    """ Run an activity around the process function provided """
-    while True:
-        logger.info('Querying for task')
-        try:
-            task = sfn.get_activity_task(activityArn=arn)
-        except ReadTimeout:
-            logger.warning('Activity read timed out')
-            continue
-        token = task.get('taskToken', None)
-        if token is None:
-            continue
-        logger.debug('taskToken: %s' % token)
-        try:
-            payload = task.get('input', '{}')
-            logger.info('Payload: %s' % payload)
-            # run process function with payload as kwargs
-            output = process(json.loads(payload))
-            # Send task success
-            sfn.send_task_success(taskToken=token, output=json.dumps(output))
-        except Exception as e:
-            err = str(e)
-            tb = format_exc()
-            logger.error("Exception when running task: %s - %s" % (err, json.dumps(tb)))
-            err = (err[252] + ' ...') if len(err) > 252 else err
-            sfn.send_task_failure(taskToken=token, error=str(err), cause=tb)
+    def run_activity(self, process, arn, **kwargs):
+        """ Run an activity around the process function provided """
+        while True:
+            logger.info('Querying for task')
+            try:
+                task = self.sfn.get_activity_task(activityArn=arn)
+            except ReadTimeout:
+                logger.warning('Activity read timed out')
+                continue
+            token = task.get('taskToken', None)
+            if token is None:
+                continue
+            logger.debug('taskToken: %s' % token)
+            try:
+                payload = task.get('input', '{}')
+                logger.info('Payload: %s' % payload)
+                # run process function with payload as kwargs
+                output = process(json.loads(payload))
+                # Send task success
+                self.sfn.send_task_success(taskToken=token, output=json.dumps(output))
+            except Exception as e:
+                err = str(e)
+                tb = format_exc()
+                logger.error("Exception when running task: %s - %s" % (err, json.dumps(tb)))
+                err = (err[:252] + ' ...') if len(err) > 252 else err
+                self.sfn.send_task_failure(taskToken=token, error=str(err), cause=tb)
diff --git a/boto3utils/version.py b/boto3utils/version.py
index ffcc925..b794fd4 100644
--- a/boto3utils/version.py
+++ b/boto3utils/version.py
@@ -1 +1 @@
-__version__ = '0.0.3'
+__version__ = '0.1.0'
diff --git a/requirements-dev.txt b/requirements-dev.txt
index d5ee25a..5b71d4b 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,3 +1,3 @@
-pytest~=3.6.1
-pytest-cov~=2.5.1
-moto~=1.3.13
+pytest~=5.3
+pytest-cov~=2.8
+moto~=1.3
diff --git a/test/test_s3.py b/test/test_s3.py
index ffab538..6ee2ab5 100644
--- a/test/test_s3.py
+++ b/test/test_s3.py
@@ -50,46 +50,49 @@ def test_s3_to_https():
     assert(url == 'https://bucket.s3.us-west-2.amazonaws.com/prefix/filename')
 
 def test_exists(s3mock):
-    exists = s3.exists('s3://%s/%s' % (BUCKET, 'keymaster'))
+    exists = s3().exists('s3://%s/%s' % (BUCKET, 'keymaster'))
     assert(exists is False)
-    exists = s3.exists('s3://%s/%s' % (BUCKET, KEY))
+    exists = s3().exists('s3://%s/%s' % (BUCKET, KEY))
     assert(exists)
 
 def test_exists_invalid():
     with pytest.raises(Exception):
-        s3.exists('invalid')
+        s3().exists('invalid')
 
 def test_upload_download(s3mock):
     url = 's3://%s/mytestfile' % BUCKET
-    s3.upload(__file__, url, public=True)
-    exists = s3.exists(url)
+    s3().upload(__file__, url, public=True)
+    exists = s3().exists(url)
     assert(exists)
     path = os.path.join(testpath, 'test_s3/test_upload_download')
-    fname = s3.download(url, path)
+    fname = s3().download(url, path)
     assert(os.path.exists(fname))
     assert(os.path.join(path, os.path.basename(url)) == fname)
     rmtree(path)
 
 def test_read_json(s3mock):
     url = 's3://%s/test.json' % BUCKET
-    out = s3.read_json(url)
+    out = s3().read_json(url)
     assert(out['field'] == 'value')
 
 def test_find(s3mock):
-    urls = list(s3.find('s3://%s/test' % BUCKET))
+    url = 's3://%s/test' % BUCKET
+    urls = list(s3().find(url))
     assert(len(urls) > 0)
-    assert('test.json' in urls)
+    assert(url + '.json' in urls)
 
 def test_latest_inventory():
     url = 's3://sentinel-inventory/sentinel-s1-l1c/sentinel-s1-l1c-inventory'
    suffix = 'productInfo.json'
-    for f in s3.latest_inventory(url, suffix=suffix):
+    session = boto3.Session()
+    _s3 = s3(session)
+    for f in _s3.latest_inventory(url, suffix=suffix):
        dt = datetime.strptime(f['LastModifiedDate'], "%Y-%m-%dT%H:%M:%S.%fZ")
        hours = (datetime.today() - dt).seconds // 3600
        assert(hours < 24)
        assert(f['url'].endswith(suffix))
        break
-    for f in s3.latest_inventory(url):
+    for f in _s3.latest_inventory(url):
        dt = datetime.strptime(f['LastModifiedDate'], "%Y-%m-%dT%H:%M:%S.%fZ")
        hours = (datetime.today() - dt).seconds // 3600
        assert(hours < 24)
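
Usage sketch for the class-based interface introduced in this patch (not part of the diff above). The bucket, prefix, profile name, and activity ARN below are hypothetical placeholders; the constructor, upload(http_url=...), find(), and run_activity() calls follow the signatures shown in the patch.

import boto3
from boto3utils import s3
from boto3utils.stepfunctions import stepfunctions

# default credentials/session (boto3.client('s3') under the hood)
client = s3()

# or pass an explicit boto3 Session (profile name is a placeholder)
client = s3(boto3.Session(profile_name='my-profile'))

# upload returns the s3:// URL by default, or the https URL when http_url=True
url = client.upload('local-file.txt', 's3://my-bucket/some/prefix', http_url=True)

# find now yields full s3:// URLs rather than bare keys
for found in client.find('s3://my-bucket/some/prefix', suffix='.json'):
    print(found)

# stepfunctions follows the same pattern: construct it, then poll an activity ARN
# with a callable that receives each task's decoded JSON payload; run_activity
# loops until interrupted
def process(payload):
    return {'ok': True, 'input': payload}

stepfunctions().run_activity(process, 'arn:aws:states:us-east-1:123456789012:activity:placeholder')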