Skip to content

Commit

Permalink
add /get_log_level, /reset_log_level, and /set_log_level to all…
Browse files Browse the repository at this point in the history
… rpcs (#18843)

* add `/get_log_level` and `/set_log_level` to all rpcs

* tidy

* for py <3.11

* fixup

* add tests

* a few no covers

* add `RpcClient.reset_log_level()`

* tidy

* fixup
  • Loading branch information
altendky authored Nov 14, 2024
1 parent cb9f7d2 commit f5235c2
Show file tree
Hide file tree
Showing 7 changed files with 290 additions and 14 deletions.
1 change: 1 addition & 0 deletions chia/_tests/rpc/test_rpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ async def test_client_fetch_methods(
RpcClient.open_connection: {"host": "", "port": 0},
RpcClient.close_connection: {"node_id": b""},
RpcClient.get_connections: {"node_type": NodeType.FULL_NODE},
RpcClient.set_log_level: {"level": "DEBUG"},
}

try:
Expand Down
183 changes: 183 additions & 0 deletions chia/_tests/rpc/test_rpc_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
from __future__ import annotations

import contextlib
import dataclasses
import logging
import ssl
import sys
from collections.abc import AsyncIterator
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Optional, cast

import aiohttp
import pytest

from chia.rpc.rpc_server import Endpoint, EndpointResult, RpcServer, RpcServiceProtocol
from chia.ssl.create_ssl import create_all_ssl
from chia.util.config import load_config
from chia.util.ints import uint16
from chia.util.ws_message import WsRpcMessage

root_logger = logging.getLogger()

if sys.version_info >= (3, 11): # pragma: no cover
name_to_number_level_map = logging.getLevelNamesMapping()
else:
name_to_number_level_map = logging._nameToLevel

number_to_name_level_map = {number: name for name, number in name_to_number_level_map.items()}

# just picking one for which a config is present
service_name = "full_node"


@dataclasses.dataclass
class TestRpcApi:
if TYPE_CHECKING:
from chia.rpc.rpc_server import RpcApiProtocol

_protocol_check: ClassVar[RpcApiProtocol] = cast("TestRpcApi", None)

# unused as of the initial writing of these tests
service: RpcServiceProtocol
service_name: str = service_name

async def _state_changed(self, change: str, change_data: Optional[dict[str, Any]] = None) -> list[WsRpcMessage]:
# just here to satisfy the complete protocol
return [] # pragma: no cover

def get_routes(self) -> dict[str, Endpoint]:
return {
"/log": self.log,
}

async def log(self, request: dict[str, Any]) -> EndpointResult:
message = request["message"]

level = name_to_number_level_map[request["level"]]

root_logger.log(level=level, msg=message)

return {}


@dataclasses.dataclass
class Client:
session: aiohttp.ClientSession
ssl_context: ssl.SSLContext
url: str

@classmethod
@contextlib.asynccontextmanager
async def managed(cls, ssl_context: ssl.SSLContext, url: str) -> AsyncIterator[Client]:
async with aiohttp.ClientSession() as session:
yield cls(session=session, ssl_context=ssl_context, url=url)

async def request(self, endpoint: str, json: Optional[dict[str, Any]] = None) -> dict[str, Any]:
if json is None:
json = {}

async with self.session.post(
self.url.rstrip("/") + "/" + endpoint.lstrip("/"),
json=json,
ssl=self.ssl_context,
) as response:
response.raise_for_status()
json = await response.json()

assert json is not None
assert json["success"], json

return json

async def log(self, level: str, message: str) -> None:
await self.request("log", json={"message": message, "level": level})


@pytest.fixture(name="server")
async def server_fixture(
root_path_populated_with_config: Path,
self_hostname: str,
) -> AsyncIterator[RpcServer[TestRpcApi]]:
config = load_config(root_path=root_path_populated_with_config, filename="config.yaml")
service_config = config[service_name]

create_all_ssl(root_path=root_path_populated_with_config)
rpc_server = RpcServer.create(
# the test rpc api doesn't presently need a real service for these tests
rpc_api=TestRpcApi(service=None), # type: ignore[arg-type]
service_name="test_rpc_server",
stop_cb=lambda: None,
root_path=root_path_populated_with_config,
net_config=config,
service_config=service_config,
prefer_ipv6=False,
)

try:
await rpc_server.start(
self_hostname=self_hostname,
rpc_port=uint16(0),
max_request_body_size=2**16,
)

yield rpc_server
finally:
rpc_server.close()
await rpc_server.await_closed()


@pytest.fixture(name="client")
async def client_fixture(
server: RpcServer[TestRpcApi],
) -> AsyncIterator[Client]:
assert server.webserver is not None
async with Client.managed(ssl_context=server.ssl_client_context, url=server.webserver.url()) as client:
yield client


@pytest.mark.anyio
async def test_get_log_level(
client: Client,
caplog: pytest.LogCaptureFixture,
) -> None:
level = "WARNING"
root_logger.setLevel(level)
result = await client.request("get_log_level")
assert result["level"] == number_to_name_level_map[root_logger.level]


@pytest.mark.anyio
async def test_set_log_level(
client: Client,
caplog: pytest.LogCaptureFixture,
) -> None:
message = "just a maybe unique probably message"

level = "WARNING"
await client.request("set_log_level", json={"level": level})
assert number_to_name_level_map[root_logger.level] == level

caplog.clear()
await client.log(message=message, level="WARNING")
assert caplog.messages == [message]

caplog.clear()
await client.log(message=message, level="INFO")
assert caplog.messages == []


@pytest.mark.anyio
async def test_reset_log_level(
client: Client,
server: RpcServer[TestRpcApi],
) -> None:
configured_level = server.service_config["logging"]["log_level"]
temporary_level = "INFO"
assert configured_level != temporary_level

root_logger.setLevel(temporary_level)
assert number_to_name_level_map[root_logger.level] == temporary_level

await client.request("reset_log_level")
assert number_to_name_level_map[root_logger.level] == configured_level
3 changes: 3 additions & 0 deletions chia/_tests/util/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ async def validate_get_routes(client: RpcClient, api: RpcApiProtocol) -> None:
"/get_routes",
"/get_version",
"/healthz",
"/get_log_level",
"/set_log_level",
"/reset_log_level",
]
assert len(routes_api) > 0
assert sorted(routes_client) == sorted(routes_api + routes_server)
9 changes: 9 additions & 0 deletions chia/rpc/rpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,15 @@ async def get_routes(self) -> dict:
async def get_version(self) -> dict:
return await self.fetch("get_version", {})

async def get_log_level(self) -> dict:
return await self.fetch("get_log_level", {})

async def set_log_level(self, level: str) -> dict:
return await self.fetch("set_log_level", {"level": level})

async def reset_log_level(self) -> dict:
return await self.fetch("reset_log_level", {})

def close(self) -> None:
self.closing_task = asyncio.create_task(self.session.close())

Expand Down
77 changes: 72 additions & 5 deletions chia/rpc/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import contextlib
import json
import logging
import sys
import traceback
from collections.abc import AsyncIterator, Awaitable
from dataclasses import dataclass
Expand All @@ -12,21 +13,38 @@
from types import MethodType
from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar

from aiohttp import ClientConnectorError, ClientSession, ClientWebSocketResponse, WSMsgType, web
from aiohttp import (
ClientConnectorError,
ClientSession,
ClientWebSocketResponse,
WSMsgType,
web,
)
from typing_extensions import Protocol, final

from chia import __version__
from chia.rpc.util import wrap_http_handler
from chia.server.outbound_message import NodeType
from chia.server.server import ChiaServer, ssl_context_for_client, ssl_context_for_server
from chia.server.server import (
ChiaServer,
ssl_context_for_client,
ssl_context_for_server,
)
from chia.server.ws_connection import WSChiaConnection
from chia.types.peer_info import PeerInfo
from chia.util.byte_types import hexstr_to_bytes
from chia.util.chia_logging import default_log_level, set_log_level
from chia.util.config import str2bool
from chia.util.ints import uint16
from chia.util.json_util import dict_to_json_str
from chia.util.network import WebServer, resolve
from chia.util.ws_message import WsRpcMessage, create_payload, create_payload_dict, format_response, pong
from chia.util.ws_message import (
WsRpcMessage,
create_payload,
create_payload_dict,
format_response,
pong,
)

log = logging.getLogger(__name__)
max_message_size = 50 * 1024 * 1024 # 50MB
Expand Down Expand Up @@ -134,6 +152,7 @@ class RpcServer(Generic[_T_RpcApiProtocol]):
ssl_context: SSLContext
ssl_client_context: SSLContext
net_config: dict[str, Any]
service_config: dict[str, Any]
webserver: Optional[WebServer] = None
daemon_heartbeat: int = 300
daemon_connection_task: Optional[asyncio.Task[None]] = None
Expand All @@ -150,6 +169,7 @@ def create(
stop_cb: Callable[[], None],
root_path: Path,
net_config: dict[str, Any],
service_config: dict[str, Any],
prefer_ipv6: bool,
) -> RpcServer[_T_RpcApiProtocol]:
crt_path = root_path / net_config["daemon_ssl"]["private_crt"]
Expand All @@ -166,6 +186,7 @@ def create(
ssl_context,
ssl_client_context,
net_config,
service_config=service_config,
daemon_heartbeat=daemon_heartbeat,
prefer_ipv6=prefer_ipv6,
)
Expand Down Expand Up @@ -251,7 +272,11 @@ async def get_network_info(self, _: dict[str, Any]) -> EndpointResult:
network_name = self.net_config["selected_network"]
address_prefix = self.net_config["network_overrides"]["config"][network_name]["address_prefix"]
genesis_challenge = self.net_config["network_overrides"]["constants"][network_name]["GENESIS_CHALLENGE"]
return {"network_name": network_name, "network_prefix": address_prefix, "genesis_challenge": genesis_challenge}
return {
"network_name": network_name,
"network_prefix": address_prefix,
"genesis_challenge": genesis_challenge,
}

async def get_connections(self, request: dict[str, Any]) -> EndpointResult:
request_node_type: Optional[NodeType] = None
Expand Down Expand Up @@ -303,6 +328,38 @@ async def get_version(self, request: dict[str, Any]) -> EndpointResult:
"version": __version__,
}

async def get_log_level(self, request: dict[str, Any]) -> EndpointResult:
logger = logging.getLogger()
level_number = logger.level
level_name = logging.getLevelName(level_number)

if sys.version_info >= (3, 11):
map = logging.getLevelNamesMapping()
else:
map = logging._nameToLevel

return {
"success": True,
"level": level_name,
"available_levels": list(map),
}

async def reset_log_level(self, request: dict[str, Any]) -> EndpointResult:
level_name = self.service_config.get("log_level", default_log_level)

return await self.set_log_level(request={"level": level_name})

async def set_log_level(self, request: dict[str, Any]) -> EndpointResult:
error_strings = set_log_level(log_level=request["level"], service_name=self.service_name)
status = await self.get_log_level(request={})

status["success"] &= len(error_strings) == 0

return {
**status,
"errors": error_strings,
}

async def ws_api(self, message: WsRpcMessage) -> Optional[dict[str, object]]:
"""
This function gets called when new message is received via websocket.
Expand Down Expand Up @@ -413,6 +470,9 @@ async def inner() -> None:
"/get_routes": get_routes,
"/get_version": get_version,
"/healthz": healthz,
"/get_log_level": get_log_level,
"/set_log_level": set_log_level,
"/reset_log_level": reset_log_level,
}


Expand All @@ -424,6 +484,7 @@ async def start_rpc_server(
stop_cb: Callable[[], None],
root_path: Path,
net_config: dict[str, object],
service_config: dict[str, object],
connect_to_daemon: bool = True,
max_request_body_size: Optional[int] = None,
) -> RpcServer[_T_RpcApiProtocol]:
Expand All @@ -438,7 +499,13 @@ async def start_rpc_server(
prefer_ipv6 = str2bool(str(net_config.get("prefer_ipv6", False)))

rpc_server = RpcServer.create(
rpc_api, rpc_api.service_name, stop_cb, root_path, net_config, prefer_ipv6=prefer_ipv6
rpc_api,
rpc_api.service_name,
stop_cb,
root_path,
net_config,
service_config=service_config,
prefer_ipv6=prefer_ipv6,
)
rpc_server.rpc_api.service._set_state_changed_callback(rpc_server.state_changed)
await rpc_server.start(self_hostname, rpc_port, max_request_body_size)
Expand Down
3 changes: 2 additions & 1 deletion chia/server/start_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,8 @@ async def manage(self, *, start: bool = True) -> AsyncIterator[None]:
self.stop_requested.set,
self.root_path,
self.config,
self._connect_to_daemon,
service_config=self.service_config,
connect_to_daemon=self._connect_to_daemon,
max_request_body_size=self.max_request_body_size,
)
yield
Expand Down
Loading

0 comments on commit f5235c2

Please sign in to comment.