|
| 1 | +"""Shared helpers for VWS client implementations.""" |
| 2 | + |
| 3 | +import base64 |
| 4 | +import json |
| 5 | +from datetime import date |
| 6 | +from typing import NoReturn |
| 7 | + |
| 8 | +from vws._image_utils import ImageType, get_image_data |
| 9 | +from vws.exceptions.vws_exceptions import ( |
| 10 | + AuthenticationFailureError, |
| 11 | + BadImageError, |
| 12 | + BadRequestError, |
| 13 | + DateRangeError, |
| 14 | + FailError, |
| 15 | + ImageTooLargeError, |
| 16 | + InvalidAcceptHeaderError, |
| 17 | + InvalidInstanceIdError, |
| 18 | + InvalidTargetTypeError, |
| 19 | + MetadataTooLargeError, |
| 20 | + ProjectHasNoAPIAccessError, |
| 21 | + ProjectInactiveError, |
| 22 | + ProjectSuspendedError, |
| 23 | + RequestQuotaReachedError, |
| 24 | + RequestTimeTooSkewedError, |
| 25 | + TargetNameExistError, |
| 26 | + TargetQuotaReachedError, |
| 27 | + TargetStatusNotSuccessError, |
| 28 | + TargetStatusProcessingError, |
| 29 | + UnknownTargetError, |
| 30 | +) |
| 31 | +from vws.reports import ( |
| 32 | + DatabaseSummaryReport, |
| 33 | + TargetRecord, |
| 34 | + TargetStatusAndRecord, |
| 35 | + TargetStatuses, |
| 36 | + TargetSummaryReport, |
| 37 | +) |
| 38 | +from vws.response import Response |
| 39 | + |
| 40 | + |
| 41 | +def raise_for_vws_result_code( |
| 42 | + result_code: str, response: Response |
| 43 | +) -> NoReturn: |
| 44 | + """Raise the appropriate VWS exception for the given result code.""" |
| 45 | + exception = { |
| 46 | + "AuthenticationFailure": AuthenticationFailureError, |
| 47 | + "BadImage": BadImageError, |
| 48 | + "BadRequest": BadRequestError, |
| 49 | + "DateRangeError": DateRangeError, |
| 50 | + "Fail": FailError, |
| 51 | + "ImageTooLarge": ImageTooLargeError, |
| 52 | + "MetadataTooLarge": MetadataTooLargeError, |
| 53 | + "ProjectHasNoAPIAccess": ProjectHasNoAPIAccessError, |
| 54 | + "ProjectInactive": ProjectInactiveError, |
| 55 | + "ProjectSuspended": ProjectSuspendedError, |
| 56 | + "RequestQuotaReached": RequestQuotaReachedError, |
| 57 | + "RequestTimeTooSkewed": RequestTimeTooSkewedError, |
| 58 | + "TargetNameExist": TargetNameExistError, |
| 59 | + "TargetQuotaReached": TargetQuotaReachedError, |
| 60 | + "TargetStatusNotSuccess": TargetStatusNotSuccessError, |
| 61 | + "TargetStatusProcessing": TargetStatusProcessingError, |
| 62 | + "UnknownTarget": UnknownTargetError, |
| 63 | + }[result_code] |
| 64 | + raise exception(response=response) |
| 65 | + |
| 66 | + |
| 67 | +def raise_for_vumark_result_code( |
| 68 | + result_code: str, response: Response |
| 69 | +) -> NoReturn: |
| 70 | + """Raise the appropriate VuMark exception for the given result |
| 71 | + code. |
| 72 | + """ |
| 73 | + exception = { |
| 74 | + "AuthenticationFailure": AuthenticationFailureError, |
| 75 | + "BadRequest": BadRequestError, |
| 76 | + "DateRangeError": DateRangeError, |
| 77 | + "Fail": FailError, |
| 78 | + "InvalidAcceptHeader": InvalidAcceptHeaderError, |
| 79 | + "InvalidInstanceId": InvalidInstanceIdError, |
| 80 | + "InvalidTargetType": InvalidTargetTypeError, |
| 81 | + "RequestTimeTooSkewed": RequestTimeTooSkewedError, |
| 82 | + "TargetStatusNotSuccess": TargetStatusNotSuccessError, |
| 83 | + "UnknownTarget": UnknownTargetError, |
| 84 | + }[result_code] |
| 85 | + raise exception(response=response) |
| 86 | + |
| 87 | + |
| 88 | +def parse_target_record_response(text: str) -> TargetStatusAndRecord: |
| 89 | + """Parse a get_target_record response body.""" |
| 90 | + result_data = json.loads(s=text) |
| 91 | + status = TargetStatuses(value=result_data["status"]) |
| 92 | + target_record_dict = dict(result_data["target_record"]) |
| 93 | + target_record = TargetRecord( |
| 94 | + target_id=target_record_dict["target_id"], |
| 95 | + active_flag=bool(target_record_dict["active_flag"]), |
| 96 | + name=target_record_dict["name"], |
| 97 | + width=float(target_record_dict["width"]), |
| 98 | + tracking_rating=int(target_record_dict["tracking_rating"]), |
| 99 | + reco_rating=target_record_dict["reco_rating"], |
| 100 | + ) |
| 101 | + return TargetStatusAndRecord( |
| 102 | + status=status, |
| 103 | + target_record=target_record, |
| 104 | + ) |
| 105 | + |
| 106 | + |
| 107 | +def parse_target_summary_response(text: str) -> TargetSummaryReport: |
| 108 | + """Parse a get_target_summary_report response body.""" |
| 109 | + result_data = dict(json.loads(s=text)) |
| 110 | + return TargetSummaryReport( |
| 111 | + status=TargetStatuses(value=result_data["status"]), |
| 112 | + database_name=result_data["database_name"], |
| 113 | + target_name=result_data["target_name"], |
| 114 | + upload_date=date.fromisoformat(result_data["upload_date"]), |
| 115 | + active_flag=bool(result_data["active_flag"]), |
| 116 | + tracking_rating=int(result_data["tracking_rating"]), |
| 117 | + total_recos=int(result_data["total_recos"]), |
| 118 | + current_month_recos=int(result_data["current_month_recos"]), |
| 119 | + previous_month_recos=int(result_data["previous_month_recos"]), |
| 120 | + ) |
| 121 | + |
| 122 | + |
| 123 | +def parse_database_summary_response(text: str) -> DatabaseSummaryReport: |
| 124 | + """Parse a get_database_summary_report response body.""" |
| 125 | + response_data = dict(json.loads(s=text)) |
| 126 | + return DatabaseSummaryReport( |
| 127 | + active_images=int(response_data["active_images"]), |
| 128 | + current_month_recos=int(response_data["current_month_recos"]), |
| 129 | + failed_images=int(response_data["failed_images"]), |
| 130 | + inactive_images=int(response_data["inactive_images"]), |
| 131 | + name=str(object=response_data["name"]), |
| 132 | + previous_month_recos=int(response_data["previous_month_recos"]), |
| 133 | + processing_images=int(response_data["processing_images"]), |
| 134 | + reco_threshold=int(response_data["reco_threshold"]), |
| 135 | + request_quota=int(response_data["request_quota"]), |
| 136 | + request_usage=int(response_data["request_usage"]), |
| 137 | + target_quota=int(response_data["target_quota"]), |
| 138 | + total_recos=int(response_data["total_recos"]), |
| 139 | + ) |
| 140 | + |
| 141 | + |
| 142 | +def build_add_target_content( |
| 143 | + *, |
| 144 | + name: str, |
| 145 | + width: float, |
| 146 | + image: ImageType, |
| 147 | + active_flag: bool, |
| 148 | + application_metadata: str | None, |
| 149 | +) -> bytes: |
| 150 | + """Build the request body for an add_target request.""" |
| 151 | + image_data = get_image_data(image=image) |
| 152 | + image_data_encoded = base64.b64encode(s=image_data).decode( |
| 153 | + encoding="ascii", |
| 154 | + ) |
| 155 | + data = { |
| 156 | + "name": name, |
| 157 | + "width": width, |
| 158 | + "image": image_data_encoded, |
| 159 | + "active_flag": active_flag, |
| 160 | + "application_metadata": application_metadata, |
| 161 | + } |
| 162 | + return json.dumps(obj=data).encode(encoding="utf-8") |
| 163 | + |
| 164 | + |
| 165 | +def build_update_target_content( |
| 166 | + *, |
| 167 | + name: str | None, |
| 168 | + width: float | None, |
| 169 | + image: ImageType | None, |
| 170 | + active_flag: bool | None, |
| 171 | + application_metadata: str | None, |
| 172 | +) -> bytes: |
| 173 | + """Build the request body for an update_target request.""" |
| 174 | + data: dict[str, str | bool | float | int] = {} |
| 175 | + |
| 176 | + if name is not None: |
| 177 | + data["name"] = name |
| 178 | + |
| 179 | + if width is not None: |
| 180 | + data["width"] = width |
| 181 | + |
| 182 | + if image is not None: |
| 183 | + image_data = get_image_data(image=image) |
| 184 | + image_data_encoded = base64.b64encode(s=image_data).decode( |
| 185 | + encoding="ascii", |
| 186 | + ) |
| 187 | + data["image"] = image_data_encoded |
| 188 | + |
| 189 | + if active_flag is not None: |
| 190 | + data["active_flag"] = active_flag |
| 191 | + |
| 192 | + if application_metadata is not None: |
| 193 | + data["application_metadata"] = application_metadata |
| 194 | + |
| 195 | + return json.dumps(obj=data).encode(encoding="utf-8") |
0 commit comments