Skip to content

Commit

Permalink
Review changes
Browse files Browse the repository at this point in the history
Review changes added and also shuffled logic for alias reuse

Signed-off-by: Filip Haltmayer <[email protected]>
  • Loading branch information
Filip Haltmayer committed May 31, 2023
1 parent 1b315f8 commit e1fb9c9
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 76 deletions.
150 changes: 90 additions & 60 deletions pymilvus/orm/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from ..client.check import is_legal_host, is_legal_port, is_legal_address
from ..client.grpc_handler import GrpcHandler
from ..client.utils import get_server_type, ZILLIZ

from ..settings import Config
from ..exceptions import ExceptionsMessage, ConnectionConfigException, ConnectionNotExistException
Expand Down Expand Up @@ -82,16 +81,25 @@ def __init__(self):
self._connected_alias = {}
self._connection_references = {}
self._con_lock = threading.RLock()

address, user, _, db_name = self.__parse_info(Config.MILVUS_URI)

default_conn_config = {
"user": user,
"address": address,
"db_name": db_name,
}

self.add_connection(**{Config.MILVUS_CONN_ALIAS: default_conn_config})
# info = self.__parse_info(
# uri=Config.MILVUS_URI,
# host=Config.DEFAULT_HOST,
# port=Config.DEFAULT_PORT,
# user = Config.MILVUS_USER,
# password = Config.MILVUS_PASSWORD,
# token = Config.MILVUS_TOKEN,
# secure=Config.DEFAULT_SECURE,
# db_name=Config.MILVUS_DB_NAME
# )

# default_conn_config = {
# "user": info["user"],
# "address": info["address"],
# "db_name": info["db_name"],
# "secure": info["secure"],
# }

# self.add_connection(**{Config.MILVUS_CONN_ALIAS: default_conn_config})

def add_connection(self, **kwargs):
""" Configures a milvus connection.
Expand Down Expand Up @@ -124,20 +132,21 @@ def add_connection(self, **kwargs):
)
"""
for alias, config in kwargs.items():
address, user, _, db_name = self.__parse_info(**config)
parsed = self.__parse_info(**config)

if alias in self._connected_alias:
if (
self._alias[alias].get("address") != address
or self._alias[alias].get("user") != user
or self._alias[alias].get("db_name") != db_name
self._alias[alias].get("address") != parsed["address"]
or self._alias[alias].get("user") != parsed["user"]
or self._alias[alias].get("db_name") != parsed["db_name"]
or self._alias[alias].get("secure") != parsed["secure"]
):
raise ConnectionConfigException(message=ExceptionsMessage.ConnDiffConf % alias)

alias_config = {
"address": address,
"user": user,
"db_name": db_name,
"address": parsed["address"],
"user": parsed["user"],
"db_name": parsed["db_name"],
"secure": parsed["secure"],
}

self._alias[alias] = alias_config
Expand Down Expand Up @@ -237,18 +246,19 @@ def connect_milvus(**kwargs):
and connection_details["address"] == kwargs["address"]
and connection_details["user"] == kwargs["user"]
and connection_details["db_name"] == kwargs["db_name"]
and connection_details["secure"] == kwargs["secure"]
):
gh = self._connected_alias[key]
break

if gh is None:
gh = GrpcHandler(**kwargs)
t = kwargs.get("timeout")
t = kwargs.get("timeout", None)
timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT
gh._wait_for_channel_ready(timeout=timeout)

kwargs.pop('password', None)
kwargs.pop('secure', None)
kwargs.pop('token', None)

self._connected_alias[alias] = gh

Expand All @@ -262,36 +272,58 @@ def connect_milvus(**kwargs):
if not isinstance(alias, str):
raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias))

# Grab the relevant info for connection
address = kwargs.pop("address", "")
uri = kwargs.pop("uri", "")
host = kwargs.pop("host", "")
port = kwargs.pop("port", "")
secure = kwargs.pop("secure", None)

# Clean the connection info
address = '' if address is None else str(address)
uri = '' if uri is None else str(uri)
host = '' if host is None else str(host)
port = '' if port is None else str(port)
user = '' if user is None else str(user)
password = '' if password is None else str(password)
token = '' if token is None else (str(token))
db_name = '' if db_name is None else str(db_name)

if set([address, uri, host, port]) != {''}:
address, user, password, db_name = self.__parse_info(address, uri, host, port, db_name, user, password)
kwargs["address"] = address

elif alias in self._alias:
# Replace empties with defaults from enviroment
uri = uri if uri != '' else Config.MILVUS_URI
host = host if host != '' else Config.DEFAULT_HOST
port = port if port != '' else Config.DEFAULT_PORT
user = user if user != '' else Config.MILVUS_USER
password = password if password != '' else Config.MILVUS_PASSWORD
token = token if token != '' else Config.MILVUS_TOKEN
db_name = db_name if db_name != '' else Config.MILVUS_DB_NAME

# If no address info is given, check if an alias exists
if alias in self._alias:
kwargs = dict(self._alias[alias].items())
# If user is passed in, use it, if not, use previous connections user.
prev_user = kwargs.pop("user")
user = user if user != "" else prev_user
kwargs["user"] = user if user != "" else prev_user

# If new secure parameter passed in, use that
prev_secure = kwargs.pop("secure")
kwargs["secure"] = secure if secure is not None else prev_secure

# If db_name is passed in, use it, if not, use previous db_name.
prev_db_name = kwargs.pop("db_name")
db_name = db_name if db_name != "" else prev_db_name
kwargs["db_name"] = db_name if db_name != "" else prev_db_name

# No params, env, and cached configs for the alias
# If at least one address info is given, parse it
elif set([address, uri, host, port]) != {''}:
secure = secure if secure is not None else Config.DEFAULT_SECURE
parsed = self.__parse_info(address, uri, host, port, db_name, user, password, token, secure)
kwargs.update(parsed)

# If no details are given and no alias exists
else:
raise ConnectionConfigException(message=ExceptionsMessage.ConnLackConf % alias)

# Set secure=True if username and password are provided
if len(user) > 0 and len(password) > 0:
kwargs["secure"] = True

connect_milvus(**kwargs, user=user, password=password, db_name=db_name)
connect_milvus(**kwargs)


def list_connections(self) -> list:
Expand Down Expand Up @@ -364,43 +396,40 @@ def __parse_info(
db_name: str = "",
user: str = "",
password: str = "",
token: str = "",
secure: bool = False,
**kwargs) -> dict:

passed_in_address = ""
passed_in_user = ""
passed_in_password = ""
passed_in_db_name = ""

# If uri
extracted_address = ""
extracted_user = ""
extracted_password = ""
extracted_db_name = ""
extracted_token = ""
extracted_secure = None
# If URI
if uri != "":
passed_in_address, passed_in_user, passed_in_password, passed_in_db_name = (
extracted_address, extracted_user, extracted_password, extracted_db_name, extracted_secure = (
self.__parse_address_from_uri(uri)
)

# If Address
elif address != "":
if not is_legal_address(address):
raise ConnectionConfigException(
message=f"Illegal address: {address}, should be in form 'localhost:19530'")
passed_in_address = address

extracted_address = address
# If Host port
else:
if host == "":
host = Config.DEFAULT_HOST
if port == "":
port = Config.DEFAULT_PORT
self.__verify_host_port(host, port)
passed_in_address = f"{host}:{port}"

passed_in_user = user if passed_in_user == "" else str(passed_in_user)
passed_in_user = Config.MILVUS_USER if passed_in_user == "" else str(passed_in_user)

passed_in_password = password if passed_in_password == "" else str(passed_in_password)
passed_in_password = Config.MILVUS_PASSWORD if passed_in_password == "" else str(passed_in_password)

passed_in_db_name = db_name if passed_in_db_name == "" else str(passed_in_db_name)
passed_in_db_name = Config.MILVUS_DB_NAME if passed_in_db_name == "" else str(passed_in_db_name)
extracted_address = f"{host}:{port}"
ret = {}
ret["address"] = extracted_address
ret["user"] = user if extracted_user == "" else str(extracted_user)
ret["password"] = password if extracted_password == "" else str(extracted_password)
ret["db_name"] = db_name if extracted_db_name == "" else str(extracted_db_name)
ret["token"] = token if extracted_token == "" else str(extracted_token)
ret["secure"] = secure if extracted_secure is None else extracted_secure

return passed_in_address, passed_in_user, passed_in_password, passed_in_db_name
return ret

def __verify_host_port(self, host, port):
if not is_legal_host(host):
Expand Down Expand Up @@ -431,6 +460,7 @@ def __parse_address_from_uri(self, uri: str) -> Tuple[str, str, str, str]:
port = parsed_uri.port if parsed_uri.port is not None else ""
user = parsed_uri.username if parsed_uri.username is not None else ""
password = parsed_uri.password if parsed_uri.password is not None else ""
secure = parsed_uri.scheme.lower() == "https:"

if host == "":
raise ConnectionConfigException(message=f"Illegal uri: URI is missing host address: {uri}")
Expand All @@ -443,7 +473,7 @@ def __parse_address_from_uri(self, uri: str) -> Tuple[str, str, str, str]:
if not is_legal_address(addr):
raise ConnectionConfigException(message=illegal_uri_msg.format(uri))

return addr, user, password, db_name
return addr, user, password, db_name, secure


def _fetch_handler(self, alias=Config.MILVUS_CONN_ALIAS) -> GrpcHandler:
Expand Down
2 changes: 2 additions & 0 deletions pymilvus/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class Config:

MILVUS_USER = env.str("MILVUS_USER", "")
MILVUS_PASSWORD = env.str("MILVUS_PASSWORD", "")
MILVUS_TOKEN = env.str("MILVUS_TOKEN", "")

MILVUS_DB_NAME = env.str("MILVUS_DB_NAME", "")

Expand All @@ -31,6 +32,7 @@ class Config:

DEFAULT_HOST = "localhost"
DEFAULT_PORT = "19530"
DEFAULT_SECURE = False

WaitTimeDurationWhenLoad = 0.5 # in seconds
MaxVarCharLengthKey = "max_length"
Expand Down
36 changes: 20 additions & 16 deletions tests/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def uri(self, request):

def test_connect_with_default_config(self):
alias = "default"
default_addr = {"address": "localhost:19530", "user": "", "db_name": ""}
default_addr = {"address": "localhost:19530", "user": "", "db_name": "", "secure": False}

assert connections.has_connection(alias) is False
addr = connections.get_connection_addr(alias)
Expand Down Expand Up @@ -123,6 +123,8 @@ def test_connect_new_alias_with_configs(self):

a = connections.get_connection_addr(alias)
a.pop("user")
print(a)
addr["secure"] = False
assert a == addr

with mock.patch(f"{mock_prefix}.close", return_value=None):
Expand All @@ -140,24 +142,24 @@ def test_connect_new_alias_with_configs_NoHostOrPort(self, no_host_or_port):
connections.connect(alias, **no_host_or_port)

assert connections.has_connection(alias) is True
assert connections.get_connection_addr(alias) == {"address": "localhost:19530", "user": "", "db_name": ""}
assert connections.get_connection_addr(alias) == {"address": "localhost:19530", "user": "", "db_name": "", "secure": False}

with mock.patch(f"{mock_prefix}.close", return_value=None):
connections.remove_connection(alias)

def test_connect_new_alias_with_no_config(self):
alias = self.test_connect_new_alias_with_no_config.__name__
# def test_connect_new_alias_with_no_config(self):
# alias = self.test_connect_new_alias_with_no_config.__name__

assert connections.has_connection(alias) is False
a = connections.get_connection_addr(alias)
assert a == {}
# assert connections.has_connection(alias) is False
# a = connections.get_connection_addr(alias)
# assert a == {}

with pytest.raises(MilvusException) as excinfo:
connections.connect(alias)
# with pytest.raises(MilvusException) as excinfo:
# connections.connect(alias)

LOGGER.info(f"Exception info: {excinfo.value}")
assert "You need to pass in the configuration" in excinfo.value.message
assert ErrorCode.UNEXPECTED_ERROR == excinfo.value.code
# LOGGER.info(f"Exception info: {excinfo.value}")
# assert "You need to pass in the configuration" in excinfo.value.message
# assert ErrorCode.UNEXPECTED_ERROR == excinfo.value.code

def test_connect_with_uri(self, uri):
alias = self.test_connect_with_uri.__name__
Expand Down Expand Up @@ -196,18 +198,20 @@ def test_add_connection_then_connect(self, uri):
def test_connect_with_reuse_grpc(self):
alias = "default"
default_addr = {"address": "localhost:19530", "user": "", "db_name": ""}
check_addr = {"address": "localhost:19530", "user": "", "db_name": "", "secure": False}

reuse_alias = "reuse"

assert connections.has_connection(alias) is False
addr = connections.get_connection_addr(alias)
assert addr == default_addr
assert addr == check_addr

with mock.patch(f"{mock_prefix}.__init__", return_value=None):
with mock.patch(f"{mock_prefix}._wait_for_channel_ready", return_value=None):
connections.connect(alias=alias, **default_addr)
connections.connect(alias=reuse_alias, **default_addr)
assert connections._connected_alias[alias] == connections._connected_alias[reuse_alias]
print(connections._connected_alias, flush=True)
assert list(connections._connection_references.values())[0] == 2

with mock.patch(f"{mock_prefix}.close", return_value=None):
Expand Down Expand Up @@ -387,13 +391,13 @@ def test_issue_1196(self):
config = {"alias": alias, "host": "localhost", "port": "19531", "user": "root", "password": 12345, "secure": True}
connections.connect(**config)
config = connections.get_connection_addr(alias)
assert config == {"address": 'localhost:19531', "user": 'root', "db_name": ""}
assert config == {"address": 'localhost:19531', "user": 'root', "db_name": "", "secure": True}

connections.add_connection(default={"host": "localhost", "port": 19531})
config = connections.get_connection_addr("default")
assert config == {"address": 'localhost:19531', "user": "", "db_name": ""}
assert config == {"address": 'localhost:19531', "user": "", "db_name": "", "secure": False}

connections.connect("default", user="root", password="12345", secure=True)

config = connections.get_connection_addr("default")
assert config == {"address": 'localhost:19531', "user": 'root', "db_name": ""}
assert config == {"address": 'localhost:19531', "user": 'root', "db_name": "", "secure": True}

0 comments on commit e1fb9c9

Please sign in to comment.