diff --git a/tm_admin/tmdb.py b/tm_admin/tmdb.py index c67d76ed..bf987599 100755 --- a/tm_admin/tmdb.py +++ b/tm_admin/tmdb.py @@ -27,7 +27,7 @@ from datetime import datetime from osm_rawdata.postgres import uriParser, PostgresClient from progress.bar import Bar, PixelBar -from tm_admin.types_tm import Userrole, Mappinglevel +from tm_admin.types_tm import Userrole, Mappinglevel, Organizationtype # Instantiate logger log = logging.getLogger(__name__) @@ -77,9 +77,9 @@ def getColumns(self, table = dict() for column in results: # print(f"FIXME: {column}") - if column[2] and column[2][:7] == 'nextval': - # log.debug(f"Dropping SEQUENCE variable '{column[2]}'") - continue + # if column[2] and column[2][:7] == 'nextval': + # log.debug(f"Dropping SEQUENCE variable '{column[2]}'") + # continue if column[1][:9] == 'timestamp': table[column[0]] = None elif column[1][:5] == 'ARRAY': @@ -158,8 +158,6 @@ def writeAllData(self, log.error(f"{val} {e}") values += f"'USER_READ_ONLY', " continue - #elif table == 'organizations': - # pass # Mapping level is another column that's an int in TM, but also is # an enum in TM, so use the correct enum instead of the integer. # Unlike role, this starts with 1. @@ -167,6 +165,17 @@ def writeAllData(self, level = Mappinglevel(val) values += f"'{level.name}', " continue + elif table == 'organizations': + if key == 'type': + org = Organizationtype(val) + values += f"'{org.name}', " + continue + if key == 'subscription_tier': + if val is None: + values += f"0, " + else: + values += f"{val}, " + continue # All tables if type(val) == str: tmp = val.replace("'", "'") @@ -211,6 +220,7 @@ def main(): parser.add_argument("-v", "--verbose", nargs="?", const="0", help="verbose output") parser.add_argument("-i", "--inuri", default='localhost/tm4', help="The URI string for the TM database") parser.add_argument("-o", "--outuri", default='localhost/tm_admin', help="The URI string for the TM Admin database") + parser.add_argument("-t", "--table", required=True, help="The table to import into") args = parser.parse_args() # if len(argv) <= 1: @@ -230,8 +240,14 @@ def main(): ) doit = TMImport(args.inuri, args.outuri) - data = doit.getAllData('users') - doit.writeAllData(data, 'users') + + table = args.table + # You have to love subtle culture spelling differences. + if table == 'organizations': + data = doit.getAllData('organisations') + else: + data = doit.getAllData(table) + doit.writeAllData(data, table) if __name__ == "__main__": """This is just a hook so this file can be run standalone during development."""