diff --git a/pymilvus/orm/connections.py b/pymilvus/orm/connections.py index 49b257243..4c8be7e30 100644 --- a/pymilvus/orm/connections.py +++ b/pymilvus/orm/connections.py @@ -80,54 +80,19 @@ def __init__(self): """ self._alias = {} self._connected_alias = {} - self._env_uri = None + self._connection_references = {} + self._con_lock = threading.RLock() - if Config.MILVUS_URI != "": - address, parsed_uri = self.__parse_address_from_uri(Config.MILVUS_URI) - self._env_uri = (address, parsed_uri) + address, user, _, db_name = self.__parse_info(Config.MILVUS_URI) - default_conn_config = { - "user": parsed_uri.username if parsed_uri.username is not None else "", - "address": address, - } - else: - default_conn_config = { - "user": "", - "address": f"{Config.DEFAULT_HOST}:{Config.DEFAULT_PORT}", - } + default_conn_config = { + "user": user, + "address": address, + "db_name": db_name, + } self.add_connection(**{Config.MILVUS_CONN_ALIAS: default_conn_config}) - def __verify_host_port(self, host, port): - if not is_legal_host(host): - raise ConnectionConfigException(message=ExceptionsMessage.HostType) - if not is_legal_port(port): - raise ConnectionConfigException(message=ExceptionsMessage.PortType) - if not 0 <= int(port) < 65535: - raise ConnectionConfigException(message=f"port number {port} out of range, valid range [0, 65535)") - - def __parse_address_from_uri(self, uri: str) -> (str, parse.ParseResult): - illegal_uri_msg = "Illegal uri: [{}], expected form 'https://user:pwd@example.com:12345'" - try: - parsed_uri = parse.urlparse(uri) - except (Exception) as e: - raise ConnectionConfigException( - message=f"{illegal_uri_msg.format(uri)}: <{type(e).__name__}, {e}>") from None - - if len(parsed_uri.netloc) == 0: - raise ConnectionConfigException(message=f"{illegal_uri_msg.format(uri)}") from None - - host = parsed_uri.hostname if parsed_uri.hostname is not None else Config.DEFAULT_HOST - port = parsed_uri.port if parsed_uri.port is not None else Config.DEFAULT_PORT - addr = f"{host}:{port}" - - self.__verify_host_port(host, port) - - if not is_legal_address(addr): - raise ConnectionConfigException(message=illegal_uri_msg.format(uri)) - - return addr, parsed_uri - def add_connection(self, **kwargs): """ Configures a milvus connection. @@ -159,41 +124,24 @@ def add_connection(self, **kwargs): ) """ for alias, config in kwargs.items(): - addr, _ = self.__get_full_address( - config.get("address", ""), - config.get("uri", ""), - config.get("host", ""), - config.get("port", "")) + address, user, _, db_name = self.__parse_info(**config) if alias in self._connected_alias: - if self._alias[alias].get("address") != addr: + if ( + self._alias[alias].get("address") != address + or self._alias[alias].get("user") != user + or self._alias[alias].get("db_name") != db_name + ): raise ConnectionConfigException(message=ExceptionsMessage.ConnDiffConf % alias) alias_config = { - "address": addr, - "user": config.get("user", ""), + "address": address, + "user": user, + "db_name": db_name, } self._alias[alias] = alias_config - def __get_full_address(self, address: str = "", uri: str = "", host: str = "", port: str = "") -> ( - str, parse.ParseResult): - if address != "": - if not is_legal_address(address): - raise ConnectionConfigException( - message=f"Illegal address: {address}, should be in form 'localhost:19530'") - return address, None - - if uri != "": - address, parsed = self.__parse_address_from_uri(uri) - return address, parsed - - host = host if host != "" else Config.DEFAULT_HOST - port = port if port != "" else Config.DEFAULT_PORT - self.__verify_host_port(host, port) - - return f"{host}:{port}", None - def disconnect(self, alias: str): """ Disconnects connection from the registry. @@ -203,8 +151,13 @@ def disconnect(self, alias: str): if not isinstance(alias, str): raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) - if alias in self._connected_alias: - self._connected_alias.pop(alias).close() + with self._con_lock: + if alias in self._connected_alias: + gh = self._connected_alias.pop(alias) + self._connection_references[id(gh)] -= 1 + if self._connection_references[id(gh)] <= 0: + gh.close() + del self._connection_references[id(gh)] def remove_connection(self, alias: str): """ Removes connection from the registry. @@ -272,107 +225,74 @@ def connect(self, alias=Config.MILVUS_CONN_ALIAS, user="", password="", token="" >>> from pymilvus import connections >>> connections.connect("test", host="localhost", port="19530") """ + # pylint: disable=too-many-statements def connect_milvus(**kwargs): - gh = GrpcHandler(**kwargs) + with self._con_lock: + gh = None + for key, connection_details in self._alias.items(): + + if ( + key in self._connected_alias + and connection_details["address"] == kwargs["address"] + and connection_details["user"] == kwargs["user"] + and connection_details["db_name"] == kwargs["db_name"] + ): + gh = self._connected_alias[key] + break - t = kwargs.get("timeout") - timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT + if gh is None: + gh = GrpcHandler(**kwargs) + t = kwargs.get("timeout") + timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT + gh._wait_for_channel_ready(timeout=timeout) - gh._wait_for_channel_ready(timeout=timeout) - kwargs.pop('password') - kwargs.pop("token", None) - kwargs.pop('db_name', None) - kwargs.pop('secure', None) - kwargs.pop("db_name", "") + kwargs.pop('password', None) + kwargs.pop('secure', None) - self._connected_alias[alias] = gh - self._alias[alias] = copy.deepcopy(kwargs) + self._connected_alias[alias] = gh - def with_config(config: Tuple) -> bool: - for c in config: - if c != "": - return True + self._alias[alias] = copy.deepcopy(kwargs) - return False + if id(gh) not in self._connection_references: + self._connection_references[id(gh)] = 1 + else: + self._connection_references[id(gh)] += 1 if not isinstance(alias, str): raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) - # Set port if server type is zilliz cloud serverless - uri = kwargs.get("uri") - if uri is not None: - server_type = get_server_type(uri) - if server_type == ZILLIZ and ":" not in token: - kwargs["uri"] = uri+":"+str(VIRTUAL_PORT) - - config = ( - kwargs.pop("address", ""), - kwargs.pop("uri", ""), - kwargs.pop("host", ""), - kwargs.pop("port", "") - ) - - # Make sure passed in None doesnt break - user = user or "" - password = password or "" - token = token or "" - # Make sure passed in are Strings - user = str(user) - password = str(password) - token = str(token) - - # 1st Priority: connection from params - if with_config(config): - in_addr, parsed_uri = self.__get_full_address(*config) - kwargs["address"] = in_addr - - if self.has_connection(alias): - if self._alias[alias].get("address") != in_addr: - raise ConnectionConfigException(message=ExceptionsMessage.ConnDiffConf % alias) - - # uri might take extra info - if parsed_uri is not None: - user = parsed_uri.username if parsed_uri.username is not None else user - password = parsed_uri.password if parsed_uri.password is not None else password - - group = parsed_uri.path.split("/") - db_name = "default" - if len(group) > 1: - db_name = group[1] - - # Set secure=True if https scheme - if parsed_uri.scheme == "https": - kwargs["secure"] = True - - - connect_milvus(**kwargs, user=user, password=password, token=token, db_name=db_name) - return + address = kwargs.pop("address", "") + uri = kwargs.pop("uri", "") + host = kwargs.pop("host", "") + port = kwargs.pop("port", "") + user = '' if user is None else str(user) + password = '' if password is None else str(password) + db_name = '' if db_name is None else str(db_name) + + if set([address, uri, host, port]) != {''}: + address, user, password, db_name = self.__parse_info(address, uri, host, port, db_name, user, password) + kwargs["address"] = address + + elif alias in self._alias: + kwargs = dict(self._alias[alias].items()) + # If user is passed in, use it, if not, use previous connections user. + prev_user = kwargs.pop("user") + user = user if user != "" else prev_user + # If db_name is passed in, use it, if not, use previous db_name. + prev_db_name = kwargs.pop("db_name") + db_name = db_name if db_name != "" else prev_db_name - # 2nd Priority, connection configs from env - if self._env_uri is not None: - addr, parsed_uri = self._env_uri - kwargs["address"] = addr - - user = parsed_uri.username if parsed_uri.username is not None else "" - password = parsed_uri.password if parsed_uri.password is not None else "" - - # Set secure=True if https scheme - if parsed_uri.scheme == "https": - kwargs["secure"] = True + # No params, env, and cached configs for the alias + else: + raise ConnectionConfigException(message=ExceptionsMessage.ConnLackConf % alias) - connect_milvus(**kwargs, user=user, password=password, db_name=db_name) - return + # Set secure=True if username and password are provided + if len(user) > 0 and len(password) > 0: + kwargs["secure"] = True - # 3rd Priority, connect to cached configs with provided user and password - if alias in self._alias: - connect_alias = dict(self._alias[alias].items()) - connect_alias["user"] = user - connect_milvus(**connect_alias, password=password, db_name=db_name, **kwargs) - return + connect_milvus(**kwargs, user=user, password=password, db_name=db_name) - # No params, env, and cached configs for the alias - raise ConnectionConfigException(message=ExceptionsMessage.ConnLackConf % alias) def list_connections(self) -> list: """ List names of all connections. @@ -386,7 +306,8 @@ def list_connections(self) -> list: >>> connections.list_connections() // TODO [('default', None), ('test', )] """ - return [(k, self._connected_alias.get(k, None)) for k in self._alias] + with self._con_lock: + return [(k, self._connected_alias.get(k, None)) for k in self._alias] def get_connection_addr(self, alias: str): """ @@ -431,7 +352,99 @@ def has_connection(self, alias: str) -> bool: """ if not isinstance(alias, str): raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) - return alias in self._connected_alias + with self._con_lock: + return alias in self._connected_alias + + def __parse_info( + self, + address: str = "", + uri: str = "", + host: str = "", + port: str = "", + db_name: str = "", + user: str = "", + password: str = "", + **kwargs) -> dict: + + passed_in_address = "" + passed_in_user = "" + passed_in_password = "" + passed_in_db_name = "" + + # If uri + if uri != "": + passed_in_address, passed_in_user, passed_in_password, passed_in_db_name = ( + self.__parse_address_from_uri(uri) + ) + + elif address != "": + if not is_legal_address(address): + raise ConnectionConfigException( + message=f"Illegal address: {address}, should be in form 'localhost:19530'") + passed_in_address = address + + else: + if host == "": + host = Config.DEFAULT_HOST + if port == "": + port = Config.DEFAULT_PORT + self.__verify_host_port(host, port) + passed_in_address = f"{host}:{port}" + + passed_in_user = user if passed_in_user == "" else str(passed_in_user) + passed_in_user = Config.MILVUS_USER if passed_in_user == "" else str(passed_in_user) + + passed_in_password = password if passed_in_password == "" else str(passed_in_password) + passed_in_password = Config.MILVUS_PASSWORD if passed_in_password == "" else str(passed_in_password) + + passed_in_db_name = db_name if passed_in_db_name == "" else str(passed_in_db_name) + passed_in_db_name = Config.MILVUS_DB_NAME if passed_in_db_name == "" else str(passed_in_db_name) + + return passed_in_address, passed_in_user, passed_in_password, passed_in_db_name + + def __verify_host_port(self, host, port): + if not is_legal_host(host): + raise ConnectionConfigException(message=ExceptionsMessage.HostType) + if not is_legal_port(port): + raise ConnectionConfigException(message=ExceptionsMessage.PortType) + if not 0 <= int(port) < 65535: + raise ConnectionConfigException(message=f"port number {port} out of range, valid range [0, 65535)") + + def __parse_address_from_uri(self, uri: str) -> Tuple[str, str, str, str]: + illegal_uri_msg = "Illegal uri: [{}], expected form 'https://user:pwd@example.com:12345'" + try: + parsed_uri = parse.urlparse(uri) + except (Exception) as e: + raise ConnectionConfigException( + message=f"{illegal_uri_msg.format(uri)}: <{type(e).__name__}, {e}>") from None + + if len(parsed_uri.netloc) == 0: + raise ConnectionConfigException(message=f"{illegal_uri_msg.format(uri)}") from None + + group = parsed_uri.path.split("/") + if len(group) > 1: + db_name = group[1] + else: + db_name = "" + + host = parsed_uri.hostname if parsed_uri.hostname is not None else "" + port = parsed_uri.port if parsed_uri.port is not None else "" + user = parsed_uri.username if parsed_uri.username is not None else "" + password = parsed_uri.password if parsed_uri.password is not None else "" + + if host == "": + raise ConnectionConfigException(message=f"Illegal uri: URI is missing host address: {uri}") + if port == "": + raise ConnectionConfigException(message=f"Illegal uri: URI is missing port: {uri}") + + self.__verify_host_port(host, port) + addr = f"{host}:{port}" + + if not is_legal_address(addr): + raise ConnectionConfigException(message=illegal_uri_msg.format(uri)) + + return addr, user, password, db_name + def _fetch_handler(self, alias=Config.MILVUS_CONN_ALIAS) -> GrpcHandler: """ Retrieves a GrpcHandler by alias. """ diff --git a/pymilvus/settings.py b/pymilvus/settings.py index d54826731..1325b6a34 100644 --- a/pymilvus/settings.py +++ b/pymilvus/settings.py @@ -15,6 +15,11 @@ class Config: MILVUS_CONN_ALIAS = env.str("MILVUS_CONN_ALIAS", "default") MILVUS_CONN_TIMEOUT = env.float("MILVUS_CONN_TIMEOUT", 10) + MILVUS_USER = env.str("MILVUS_USER", "") + MILVUS_PASSWORD = env.str("MILVUS_PASSWORD", "") + + MILVUS_DB_NAME = env.str("MILVUS_DB_NAME", "") + # legacy configs: DEFAULT_USING = MILVUS_CONN_ALIAS DEFAULT_CONNECT_TIMEOUT = MILVUS_CONN_TIMEOUT diff --git a/tests/test_connections.py b/tests/test_connections.py index cb862d51e..8a0b6ec47 100644 --- a/tests/test_connections.py +++ b/tests/test_connections.py @@ -55,7 +55,7 @@ def uri(self, request): def test_connect_with_default_config(self): alias = "default" - default_addr = {"address": "localhost:19530", "user": ""} + default_addr = {"address": "localhost:19530", "user": "", "db_name": ""} assert connections.has_connection(alias) is False addr = connections.get_connection_addr(alias) @@ -109,7 +109,7 @@ def test_connect_with_default_config_from_environment(self, env_result): def test_connect_new_alias_with_configs(self): alias = "exist" - addr = {"address": "localhost:19530"} + addr = {"address": "localhost:19530", "db_name": ""} assert connections.has_connection(alias) is False a = connections.get_connection_addr(alias) @@ -140,7 +140,7 @@ def test_connect_new_alias_with_configs_NoHostOrPort(self, no_host_or_port): connections.connect(alias, **no_host_or_port) assert connections.has_connection(alias) is True - assert connections.get_connection_addr(alias) == {"address": "localhost:19530", "user": ""} + assert connections.get_connection_addr(alias) == {"address": "localhost:19530", "user": "", "db_name": ""} with mock.patch(f"{mock_prefix}.close", return_value=None): connections.remove_connection(alias) @@ -193,6 +193,33 @@ def test_add_connection_then_connect(self, uri): with mock.patch(f"{mock_prefix}.close", return_value=None): connections.remove_connection(alias) + def test_connect_with_reuse_grpc(self): + alias = "default" + default_addr = {"address": "localhost:19530", "user": "", "db_name": ""} + + reuse_alias = "reuse" + + assert connections.has_connection(alias) is False + addr = connections.get_connection_addr(alias) + assert addr == default_addr + + with mock.patch(f"{mock_prefix}.__init__", return_value=None): + with mock.patch(f"{mock_prefix}._wait_for_channel_ready", return_value=None): + connections.connect(alias=alias, **default_addr) + connections.connect(alias=reuse_alias, **default_addr) + assert connections._connected_alias[alias] == connections._connected_alias[reuse_alias] + assert list(connections._connection_references.values())[0] == 2 + + with mock.patch(f"{mock_prefix}.close", return_value=None): + connections.disconnect(alias) + + assert list(connections._connection_references.values())[0] == 1 + + with mock.patch(f"{mock_prefix}.close", return_value=None): + connections.disconnect(reuse_alias) + + assert len(connections._connection_references) == 0 + class TestAddConnection: @pytest.fixture(scope="function", params=[ @@ -301,7 +328,6 @@ def test_add_connection_address_invalid(self, invalid_addr): {"uri": "http://127.0.0.1:19530"}, {"uri": "http://localhost:19530"}, {"uri": "http://example.com:80"}, - {"uri": "http://example.com"}, ]) def test_add_connection_uri(self, valid_uri): alias = self.test_add_connection_uri.__name__ @@ -323,6 +349,8 @@ def test_add_connection_uri(self, valid_uri): {"uri": "http://"}, {"uri": None}, {"uri": -1}, + {"uri": "http://example.com"}, + {"uri": "http://:90"}, ]) def test_add_connection_uri_invalid(self, invalid_uri): alias = self.test_add_connection_uri_invalid.__name__ @@ -359,13 +387,13 @@ def test_issue_1196(self): config = {"alias": alias, "host": "localhost", "port": "19531", "user": "root", "password": 12345, "secure": True} connections.connect(**config) config = connections.get_connection_addr(alias) - assert config == {"address": 'localhost:19531', "user": 'root'} + assert config == {"address": 'localhost:19531', "user": 'root', "db_name": ""} connections.add_connection(default={"host": "localhost", "port": 19531}) config = connections.get_connection_addr("default") - assert config == {"address": 'localhost:19531', "user": ""} + assert config == {"address": 'localhost:19531', "user": "", "db_name": ""} connections.connect("default", user="root", password="12345", secure=True) config = connections.get_connection_addr("default") - assert config == {"address": 'localhost:19531', "user": 'root'} + assert config == {"address": 'localhost:19531', "user": 'root', "db_name": ""}