diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 560e8bc00..c71b58a23 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -11,7 +11,7 @@ jobs: matrix: os: - ubuntu-latest - python: [ 2.7, 3.7 ] + python: [ 3.7, 3.9, 3.10.7] splunk-version: - "8.1" - "8.2" diff --git a/Makefile b/Makefile index 2810c6aec..9f1bbd8b6 100644 --- a/Makefile +++ b/Makefile @@ -34,7 +34,7 @@ docs: .PHONY: test test: @echo "$(ATTN_COLOR)==> test $(NO_COLOR)" - @tox -e py27,py37 + @tox -e py37,py39 .PHONY: test_specific test_specific: diff --git a/setup.py b/setup.py index 284c50983..169d79157 100755 --- a/setup.py +++ b/setup.py @@ -24,10 +24,7 @@ failed = False def run_test_suite(): - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest def mark_failed(): global failed @@ -143,7 +140,8 @@ def run(self): packages = ["splunklib", "splunklib.modularinput", - "splunklib.searchcommands"], + "splunklib.searchcommands", + "splunklib.customrest"], url="http://github.com/splunk/splunk-sdk-python", diff --git a/splunklib/__init__.py b/splunklib/__init__.py index 774cb7576..11894e777 100644 --- a/splunklib/__init__.py +++ b/splunklib/__init__.py @@ -14,12 +14,10 @@ """Python library for Splunk.""" -from __future__ import absolute_import -from splunklib.six.moves import map import logging DEFAULT_LOG_FORMAT = '%(asctime)s, Level=%(levelname)s, Pid=%(process)s, Logger=%(name)s, File=%(filename)s, ' \ - 'Line=%(lineno)s, %(message)s' + 'Line=%(lineno)s, %(message)s' DEFAULT_DATE_FORMAT = '%Y-%m-%d %H:%M:%S %Z' @@ -31,5 +29,51 @@ def setup_logging(level, log_format=DEFAULT_LOG_FORMAT, date_format=DEFAULT_DATE format=log_format, datefmt=date_format) + +def ensure_binary(s, encoding='utf-8', errors='strict'): + """ + - `str` -> encoded to `bytes` + - `bytes` -> `bytes` + """ + if isinstance(s, str): + return s.encode(encoding, errors) + + if isinstance(s, bytes): + return s + + raise TypeError(f"not expecting type '{type(s)}'") + + +def ensure_str(s, encoding='utf-8', errors='strict'): + """ + - `str` -> `str` + - `bytes` -> decoded to `str` + """ + if isinstance(s, bytes): + return s.decode(encoding, errors) + + if isinstance(s, str): + return s + + raise TypeError(f"not expecting type '{type(s)}'") + + +def ensure_text(s, encoding='utf-8', errors='strict'): + """ + - `str` -> `str` + - `bytes` -> decoded to `str` + """ + if isinstance(s, bytes): + return s.decode(encoding, errors) + if isinstance(s, str): + return s + raise TypeError(f"not expecting type '{type(s)}'") + + +def assertRegex(self, *args, **kwargs): + return getattr(self, "assertRegex")(*args, **kwargs) + + __version_info__ = (1, 7, 2) + __version__ = ".".join(map(str, __version_info__)) diff --git a/splunklib/binding.py b/splunklib/binding.py index 370f076cf..b5f08b551 100644 --- a/splunklib/binding.py +++ b/splunklib/binding.py @@ -24,31 +24,23 @@ :mod:`splunklib.client` module. """ -from __future__ import absolute_import - import io import logging import socket import ssl -import sys import time from base64 import b64encode from contextlib import contextmanager from datetime import datetime from functools import wraps from io import BytesIO -from xml.etree.ElementTree import XML - +from urllib import parse +from http import client +from http.cookies import SimpleCookie +from xml.etree.ElementTree import XML, ParseError +from splunklib.data import record from splunklib import __version__ -from splunklib import six -from splunklib.six.moves import urllib -from .data import record - -try: - from xml.etree.ElementTree import ParseError -except ImportError as e: - from xml.parsers.expat import ExpatError as ParseError logger = logging.getLogger(__name__) @@ -57,7 +49,12 @@ "connect", "Context", "handler", - "HTTPError" + "HTTPError", + "UrlEncoded", + "_encode", + "_make_cookie_header", + "_NoAuthenticationToken", + "namespace" ] # If you change these, update the docstring @@ -66,14 +63,16 @@ DEFAULT_PORT = "8089" DEFAULT_SCHEME = "https" + def _log_duration(f): @wraps(f) def new_f(*args, **kwargs): start_time = datetime.now() val = f(*args, **kwargs) end_time = datetime.now() - logger.debug("Operation took %s", end_time-start_time) + logger.debug("Operation took %s", end_time - start_time) return val + return new_f @@ -93,8 +92,8 @@ def _parse_cookies(cookie_str, dictionary): :param dictionary: A dictionary to update with any found key-value pairs. :type dictionary: ``dict`` """ - parsed_cookie = six.moves.http_cookies.SimpleCookie(cookie_str) - for cookie in parsed_cookie.values(): + parsed_cookie = SimpleCookie(cookie_str) + for cookie in list(parsed_cookie.values()): dictionary[cookie.key] = cookie.coded_value @@ -115,10 +114,11 @@ def _make_cookie_header(cookies): :return: ``str` An HTTP header cookie string. :rtype: ``str`` """ - return "; ".join("%s=%s" % (key, value) for key, value in cookies) + return "; ".join(f"{key}={value}" for key, value in cookies) + # Singleton values to eschew None -class _NoAuthenticationToken(object): +class _NoAuthenticationToken: """The value stored in a :class:`Context` or :class:`splunklib.client.Service` class that is not logged in. @@ -130,7 +130,6 @@ class that is not logged in. Likewise, after a ``Context`` or ``Service`` object has been logged out, the token is set to this value again. """ - pass class UrlEncoded(str): @@ -156,7 +155,7 @@ class UrlEncoded(str): **Example**:: import urllib - UrlEncoded('%s://%s' % (scheme, urllib.quote(host)), skip_encode=True) + UrlEncoded(f'{scheme}://{urllib.quote(host)}', skip_encode=True) If you append ``str`` strings and ``UrlEncoded`` strings, the result is also URL encoded. @@ -166,19 +165,19 @@ class UrlEncoded(str): UrlEncoded('ab c') + 'de f' == UrlEncoded('ab cde f') 'ab c' + UrlEncoded('de f') == UrlEncoded('ab cde f') """ + def __new__(self, val='', skip_encode=False, encode_slash=False): if isinstance(val, UrlEncoded): # Don't urllib.quote something already URL encoded. return val - elif skip_encode: + if skip_encode: return str.__new__(self, val) - elif encode_slash: - return str.__new__(self, urllib.parse.quote_plus(val)) - else: - # When subclassing str, just call str's __new__ method - # with your class and the value you want to have in the - # new string. - return str.__new__(self, urllib.parse.quote(val)) + if encode_slash: + return str.__new__(self, parse.quote_plus(val)) + # When subclassing str, just call str.__new__ method + # with your class and the value you want to have in the + # new string. + return str.__new__(self, parse.quote(val)) def __add__(self, other): """self + other @@ -188,8 +187,8 @@ def __add__(self, other): """ if isinstance(other, UrlEncoded): return UrlEncoded(str.__add__(self, other), skip_encode=True) - else: - return UrlEncoded(str.__add__(self, urllib.parse.quote(other)), skip_encode=True) + + return UrlEncoded(str.__add__(self, parse.quote(other)), skip_encode=True) def __radd__(self, other): """other + self @@ -199,8 +198,8 @@ def __radd__(self, other): """ if isinstance(other, UrlEncoded): return UrlEncoded(str.__radd__(self, other), skip_encode=True) - else: - return UrlEncoded(str.__add__(urllib.parse.quote(other), self), skip_encode=True) + + return UrlEncoded(str.__add__(parse.quote(other), self), skip_encode=True) def __mod__(self, fields): """Interpolation into ``UrlEncoded``s is disabled. @@ -209,15 +208,17 @@ def __mod__(self, fields): ``TypeError``. """ raise TypeError("Cannot interpolate into a UrlEncoded object.") + def __repr__(self): - return "UrlEncoded(%s)" % repr(urllib.parse.unquote(str(self))) + return f"UrlEncoded({repr(parse.unquote(str(self)))})" + @contextmanager def _handle_auth_error(msg): - """Handle reraising HTTP authentication errors as something clearer. + """Handle re-raising HTTP authentication errors as something clearer. If an ``HTTPError`` is raised with status 401 (access denied) in - the body of this context manager, reraise it as an + the body of this context manager, re-raise it as an ``AuthenticationError`` instead, with *msg* as its message. This function adds no round trips to the server. @@ -238,6 +239,7 @@ def _handle_auth_error(msg): else: raise + def _authentication(request_fun): """Decorator to handle autologin and authentication errors. @@ -270,12 +272,12 @@ def _authentication(request_fun): def f(): c.get("/services") return 42 - print _authentication(f) + print(_authentication(f)) """ + @wraps(request_fun) def wrapper(self, *args, **kwargs): - if self.token is _NoAuthenticationToken and \ - not self.has_cookies(): + if self.token is _NoAuthenticationToken and not self.has_cookies(): # Not yet logged in. if self.autologin and self.username and self.password: # This will throw an uncaught @@ -297,8 +299,7 @@ def wrapper(self, *args, **kwargs): # an AuthenticationError and give up. with _handle_auth_error("Autologin failed."): self.login() - with _handle_auth_error( - "Authentication Failed! If session token is used, it seems to have been expired."): + with _handle_auth_error("Authentication Failed! If session token is used, it seems to have been expired."): return request_fun(self, *args, **kwargs) elif he.status == 401 and not self.autologin: raise AuthenticationError( @@ -348,10 +349,10 @@ def _authority(scheme=DEFAULT_SCHEME, host=DEFAULT_HOST, port=DEFAULT_PORT): """ if ':' in host: - # IPv6 addresses must be enclosed in [ ] in order to be well - # formed. + # IPv6 addresses must be enclosed in [ ] in order to be well-formed. host = '[' + host + ']' - return UrlEncoded("%s://%s:%s" % (scheme, host, port), skip_encode=True) + return UrlEncoded(f"{scheme}://{host}:{port}", skip_encode=True) + # kwargs: sharing, owner, app def namespace(sharing=None, owner=None, app=None, **kwargs): @@ -406,7 +407,7 @@ def namespace(sharing=None, owner=None, app=None, **kwargs): n = binding.namespace(sharing="global", app="search") """ if sharing in ["system"]: - return record({'sharing': sharing, 'owner': "nobody", 'app': "system" }) + return record({'sharing': sharing, 'owner': "nobody", 'app': "system"}) if sharing in ["global", "app"]: return record({'sharing': sharing, 'owner': "nobody", 'app': app}) if sharing in ["user", None]: @@ -414,7 +415,7 @@ def namespace(sharing=None, owner=None, app=None, **kwargs): raise ValueError("Invalid value for argument: 'sharing'") -class Context(object): +class Context: """This class represents a context that encapsulates a splunkd connection. The ``Context`` class encapsulates the details of HTTP requests, @@ -433,7 +434,7 @@ class Context(object): :type port: ``integer`` :param scheme: The scheme for accessing the service (the default is "https"). :type scheme: "https" or "http" - :param verify: Enable (True) or disable (False) SSL verrification for https connections. + :param verify: Enable (True) or disable (False) SSL verification for https connections. :type verify: ``Boolean`` :param sharing: The sharing mode for the namespace (the default is "user"). :type sharing: "global", "system", "app", or "user" @@ -476,12 +477,14 @@ class Context(object): # Or if you already have a valid cookie c = binding.Context(cookie="splunkd_8089=...") """ + def __init__(self, handler=None, **kwargs): self.http = HttpLib(handler, kwargs.get("verify", False), key_file=kwargs.get("key_file"), - cert_file=kwargs.get("cert_file"), context=kwargs.get("context"), # Default to False for backward compat + cert_file=kwargs.get("cert_file"), context=kwargs.get("context"), + # Default to False for backward compat retries=kwargs.get("retries", 0), retryDelay=kwargs.get("retryDelay", 10)) self.token = kwargs.get("token", _NoAuthenticationToken) - if self.token is None: # In case someone explicitly passes token=None + if self.token is None: # In case someone explicitly passes token=None self.token = _NoAuthenticationToken self.scheme = kwargs.get("scheme", DEFAULT_SCHEME) self.host = kwargs.get("host", DEFAULT_HOST) @@ -514,7 +517,7 @@ def has_cookies(self): :rtype: ``bool`` """ auth_token_key = "splunkd_" - return any(auth_token_key in key for key in self.get_cookies().keys()) + return any(auth_token_key in key for key in list(self.get_cookies().keys())) # Shared per-context request headers @property @@ -531,9 +534,9 @@ def _auth_headers(self): if self.has_cookies(): return [("Cookie", _make_cookie_header(list(self.get_cookies().items())))] elif self.basic and (self.username and self.password): - token = 'Basic %s' % b64encode(("%s:%s" % (self.username, self.password)).encode('utf-8')).decode('ascii') + token = f'Basic {b64encode(("%s:%s" % (self.username, self.password)).encode("utf-8")).decode("ascii")}' elif self.bearerToken: - token = 'Bearer %s' % self.bearerToken + token = f'Bearer {self.bearerToken}' elif self.token is _NoAuthenticationToken: token = [] else: @@ -541,7 +544,7 @@ def _auth_headers(self): if self.token.startswith('Splunk '): token = self.token else: - token = 'Splunk %s' % self.token + token = f'Splunk {self.token}' if token: header.append(("Authorization", token)) if self.get_cookies(): @@ -839,12 +842,12 @@ def request(self, path_segment, method="GET", headers=None, body={}, headers = [] path = self.authority \ - + self._abspath(path_segment, owner=owner, - app=app, sharing=sharing) + + self._abspath(path_segment, owner=owner, + app=app, sharing=sharing) all_headers = headers + self.additional_headers + self._auth_headers logger.debug("%s request to %s (headers: %s, body: %s)", - method, path, str(all_headers), repr(body)) + method, path, str(all_headers), repr(body)) if body: body = _encode(**body) @@ -886,14 +889,14 @@ def login(self): """ if self.has_cookies() and \ - (not self.username and not self.password): + (not self.username and not self.password): # If we were passed session cookie(s), but no username or # password, then login is a nop, since we're automatically # logged in. return if self.token is not _NoAuthenticationToken and \ - (not self.username and not self.password): + (not self.username and not self.password): # If we were passed a session token, but no username or # password, then login is a nop, since we're automatically # logged in. @@ -915,11 +918,11 @@ def login(self): username=self.username, password=self.password, headers=self.additional_headers, - cookie="1") # In Splunk 6.2+, passing "cookie=1" will return the "set-cookie" header + cookie="1") # In Splunk 6.2+, passing "cookie=1" will return the "set-cookie" header body = response.body.read() session = XML(body).findtext("./sessionKey") - self.token = "Splunk %s" % session + self.token = f"Splunk {session}" return self except HTTPError as he: if he.status == 401: @@ -934,7 +937,7 @@ def logout(self): return self def _abspath(self, path_segment, - owner=None, app=None, sharing=None): + owner=None, app=None, sharing=None): """Qualifies *path_segment* into an absolute path for a URL. If *path_segment* is already absolute, returns it unchanged. @@ -986,12 +989,11 @@ def _abspath(self, path_segment, # namespace. If only one of app and owner is specified, use # '-' for the other. if ns.app is None and ns.owner is None: - return UrlEncoded("/services/%s" % path_segment, skip_encode=skip_encode) + return UrlEncoded(f"/services/{path_segment}", skip_encode=skip_encode) oname = "nobody" if ns.owner is None else ns.owner aname = "system" if ns.app is None else ns.app - path = UrlEncoded("/servicesNS/%s/%s/%s" % (oname, aname, path_segment), - skip_encode=skip_encode) + path = UrlEncoded(f"/servicesNS/{oname}/{aname}/{path_segment}", skip_encode=skip_encode) return path @@ -1042,21 +1044,23 @@ def connect(**kwargs): c.login() return c + # Note: the error response schema supports multiple messages but we only # return the first, although we do return the body so that an exception # handler that wants to read multiple messages can do so. class HTTPError(Exception): """This exception is raised for HTTP responses that return an error.""" + def __init__(self, response, _message=None): status = response.status reason = response.reason body = response.body.read() try: detail = XML(body).findtext("./messages/msg") - except ParseError as err: + except ParseError: detail = body - message = "HTTP %d %s%s" % ( - status, reason, "" if detail is None else " -- %s" % detail) + detail_formatted = "" if detail is None else f" -- {detail}" + message = f"HTTP {status} {reason}{detail_formatted}" Exception.__init__(self, _message or message) self.status = status self.reason = reason @@ -1064,6 +1068,7 @@ def __init__(self, response, _message=None): self.body = body self._response = response + class AuthenticationError(HTTPError): """Raised when a login request to Splunk fails. @@ -1071,6 +1076,7 @@ class AuthenticationError(HTTPError): in a call to :meth:`Context.login` or :meth:`splunklib.client.Service.login`, this exception is raised. """ + def __init__(self, message, cause): # Put the body back in the response so that HTTPError's constructor can # read it again. @@ -1078,6 +1084,7 @@ def __init__(self, message, cause): HTTPError.__init__(self, cause._response, message) + # # The HTTP interface used by the Splunk binding layer abstracts the underlying # HTTP library using request & response 'messages' which are implemented as @@ -1105,16 +1112,17 @@ def __init__(self, message, cause): # 'foo=1&foo=2&foo=3'. def _encode(**kwargs): items = [] - for key, value in six.iteritems(kwargs): + for key, value in list(kwargs.items()): if isinstance(value, list): items.extend([(key, item) for item in value]) else: items.append((key, value)) - return urllib.parse.urlencode(items) + return parse.urlencode(items) + # Crack the given url into (scheme, host, port, path) def _spliturl(url): - parsed_url = urllib.parse.urlparse(url) + parsed_url = parse.urlparse(url) host = parsed_url.hostname port = parsed_url.port path = '?'.join((parsed_url.path, parsed_url.query)) if parsed_url.query else parsed_url.path @@ -1123,9 +1131,10 @@ def _spliturl(url): if port is None: port = DEFAULT_PORT return parsed_url.scheme, host, port, path + # Given an HTTP request handler, this wrapper objects provides a related # family of convenience methods built using that handler. -class HttpLib(object): +class HttpLib: """A set of convenient methods for making HTTP calls. ``HttpLib`` provides a general :meth:`request` method, and :meth:`delete`, @@ -1167,7 +1176,9 @@ class HttpLib(object): If using the default handler, SSL verification can be disabled by passing verify=False. """ - def __init__(self, custom_handler=None, verify=False, key_file=None, cert_file=None, context=None, retries=0, retryDelay=10): + + def __init__(self, custom_handler=None, verify=False, key_file=None, cert_file=None, context=None, retries=0, + retryDelay=10): if custom_handler is None: self.handler = handler(verify=verify, key_file=key_file, cert_file=cert_file, context=context) else: @@ -1228,7 +1239,7 @@ def get(self, url, headers=None, **kwargs): # the query to be encoded or it will get automatically URL # encoded by being appended to url. url = url + UrlEncoded('?' + _encode(**kwargs), skip_encode=True) - return self.request(url, { 'method': "GET", 'headers': headers }) + return self.request(url, {'method': "GET", 'headers': headers}) def post(self, url, headers=None, **kwargs): """Sends a POST request to a URL. @@ -1324,6 +1335,7 @@ class ResponseReader(io.RawIOBase): types of HTTP libraries used with this SDK. This class also provides a preview of the stream and a few useful predicates. """ + # For testing, you can use a StringIO as the argument to # ``ResponseReader`` instead of an ``httplib.HTTPResponse``. It # will work equally well. @@ -1333,10 +1345,7 @@ def __init__(self, response, connection=None): self._buffer = b'' def __str__(self): - if six.PY2: - return self.read() - else: - return str(self.read(), 'UTF-8') + return str(self.read(), 'UTF-8') @property def empty(self): @@ -1362,7 +1371,7 @@ def close(self): self._connection.close() self._response.close() - def read(self, size = None): + def read(self, size=None): """Reads a given number of characters from the response. :param size: The number of characters to read, or "None" to read the @@ -1415,7 +1424,7 @@ def connect(scheme, host, port): kwargs = {} if timeout is not None: kwargs['timeout'] = timeout if scheme == "http": - return six.moves.http_client.HTTPConnection(host, port, **kwargs) + return client.HTTPConnection(host, port, **kwargs) if scheme == "https": if key_file is not None: kwargs['key_file'] = key_file if cert_file is not None: kwargs['cert_file'] = cert_file @@ -1426,8 +1435,8 @@ def connect(scheme, host, port): # verify is True in elif branch and context is not None kwargs['context'] = context - return six.moves.http_client.HTTPSConnection(host, port, **kwargs) - raise ValueError("unsupported scheme: %s" % scheme) + return client.HTTPSConnection(host, port, **kwargs) + raise ValueError(f"unsupported scheme: {scheme}") def request(url, message, **kwargs): scheme, host, port, path = _spliturl(url) @@ -1438,7 +1447,7 @@ def request(url, message, **kwargs): "User-Agent": "splunk-sdk-python/%s" % __version__, "Accept": "*/*", "Connection": "Close", - } # defaults + } # defaults for key, value in message["headers"]: head[key] = value method = message.get("method", "GET") diff --git a/splunklib/client.py b/splunklib/client.py index 564a40f66..a8c5ac34d 100644 --- a/splunklib/client.py +++ b/splunklib/client.py @@ -54,7 +54,7 @@ are subclasses of :class:`Entity`. An ``Entity`` object has fields for its attributes, and methods that are specific to each kind of entity. For example:: - print my_app['author'] # Or: print my_app.author + print(my_app['author']) # Or: print(my_app.author) my_app.package() # Creates a compressed package of this application """ @@ -66,15 +66,13 @@ import socket from datetime import datetime, timedelta from time import sleep +from urllib import parse -from splunklib import six -from splunklib.six.moves import urllib - -from . import data -from .binding import (AuthenticationError, Context, HTTPError, UrlEncoded, - _encode, _make_cookie_header, _NoAuthenticationToken, - namespace) -from .data import record +from splunklib import data +from splunklib.data import record +from splunklib.binding import (AuthenticationError, Context, HTTPError, UrlEncoded, + _encode, _make_cookie_header, _NoAuthenticationToken, + namespace) logger = logging.getLogger(__name__) @@ -84,7 +82,8 @@ "OperationError", "IncomparableException", "Service", - "namespace" + "namespace", + "AuthenticationError" ] PATH_APPS = "apps/local/" @@ -106,7 +105,7 @@ PATH_MODULAR_INPUTS = "data/modular-inputs" PATH_ROLES = "authorization/roles/" PATH_SAVED_SEARCHES = "saved/searches/" -PATH_STANZA = "configs/conf-%s/%s" # (file, stanza) +PATH_STANZA = "configs/conf-%s/%s" # (file, stanza) PATH_USERS = "authentication/users/" PATH_RECEIVERS_STREAM = "/services/receivers/stream" PATH_RECEIVERS_SIMPLE = "/services/receivers/simple" @@ -116,45 +115,38 @@ XNAME_ENTRY = XNAMEF_ATOM % "entry" XNAME_CONTENT = XNAMEF_ATOM % "content" -MATCH_ENTRY_CONTENT = "%s/%s/*" % (XNAME_ENTRY, XNAME_CONTENT) +MATCH_ENTRY_CONTENT = f"{XNAME_ENTRY}/{XNAME_CONTENT}/*" class IllegalOperationException(Exception): """Thrown when an operation is not possible on the Splunk instance that a :class:`Service` object is connected to.""" - pass class IncomparableException(Exception): """Thrown when trying to compare objects (using ``==``, ``<``, ``>``, and so on) of a type that doesn't support it.""" - pass class AmbiguousReferenceException(ValueError): """Thrown when the name used to fetch an entity matches more than one entity.""" - pass class InvalidNameException(Exception): """Thrown when the specified name contains characters that are not allowed in Splunk entity names.""" - pass class NoSuchCapability(Exception): """Thrown when the capability that has been referred to doesn't exist.""" - pass class OperationError(Exception): - """Raised for a failed operation, such as a time out.""" - pass + """Raised for a failed operation, such as a timeout.""" class NotSupportedError(Exception): """Raised for operations that are not supported on a given object.""" - pass def _trailing(template, *targets): @@ -190,8 +182,9 @@ def _trailing(template, *targets): def _filter_content(content, *args): if len(args) > 0: return record((k, content[k]) for k in args) - return record((k, v) for k, v in six.iteritems(content) - if k not in ['eai:acl', 'eai:attributes', 'type']) + return record((k, v) for k, v in list(content.items()) + if k not in ['eai:acl', 'eai:attributes', 'type']) + # Construct a resource path from the given base path + resource name def _path(base, name): @@ -221,10 +214,9 @@ def _load_atom_entries(response): # its state wrapped in another element, but at the top level. # For example, in XML, it returns ... instead of # .... - else: - entries = r.get('entry', None) - if entries is None: return None - return entries if isinstance(entries, list) else [entries] + entries = r.get('entry', None) + if entries is None: return None + return entries if isinstance(entries, list) else [entries] # Load the sid from the body of the given response @@ -250,7 +242,7 @@ def _parse_atom_entry(entry): metadata = _parse_atom_metadata(content) # Filter some of the noise out of the content record - content = record((k, v) for k, v in six.iteritems(content) + content = record((k, v) for k, v in list(content.items()) if k not in ['eai:acl', 'eai:attributes']) if 'type' in content: @@ -289,6 +281,7 @@ def _parse_atom_metadata(content): return record({'access': access, 'fields': fields}) + # kwargs: scheme, host, port, app, owner, username, password def connect(**kwargs): """This function connects and logs in to a Splunk instance. @@ -417,8 +410,9 @@ class Service(_BaseService): # Or if you already have a valid cookie s = client.Service(cookie="splunkd_8089=...") """ + def __init__(self, **kwargs): - super(Service, self).__init__(**kwargs) + super().__init__(**kwargs) self._splunk_version = None self._kvstore_owner = None self._instance_type = None @@ -538,8 +532,7 @@ def modular_input_kinds(self): """ if self.splunk_version >= (5,): return ReadOnlyCollection(self, PATH_MODULAR_INPUTS, item=ModularInputKind) - else: - raise IllegalOperationException("Modular inputs are not supported before Splunk version 5.") + raise IllegalOperationException("Modular inputs are not supported before Splunk version 5.") @property def storage_passwords(self): @@ -589,7 +582,7 @@ def restart(self, timeout=None): :param timeout: A timeout period, in seconds. :type timeout: ``integer`` """ - msg = { "value": "Restart requested by " + self.username + "via the Splunk SDK for Python"} + msg = {"value": "Restart requested by " + self.username + "via the Splunk SDK for Python"} # This message will be deleted once the server actually restarts. self.messages.create(name="restart_required", **msg) result = self.post("/services/server/control/restart") @@ -693,7 +686,7 @@ def splunk_version(self): :return: A ``tuple`` of ``integers``. """ if self._splunk_version is None: - self._splunk_version = tuple([int(p) for p in self.info['version'].split('.')]) + self._splunk_version = tuple(int(p) for p in self.info['version'].split('.')) return self._splunk_version @property @@ -751,13 +744,14 @@ def users(self): return Users(self) -class Endpoint(object): +class Endpoint: """This class represents individual Splunk resources in the Splunk REST API. An ``Endpoint`` object represents a URI, such as ``/services/saved/searches``. This class provides the common functionality of :class:`Collection` and :class:`Entity` (essentially HTTP GET and POST methods). """ + def __init__(self, service, path): self.service = service self.path = path @@ -774,11 +768,11 @@ def get_api_version(self, path): # Default to v1 if undefined in the path # For example, "/services/search/jobs" is using API v1 api_version = 1 - + versionSearch = re.search('(?:servicesNS\/[^/]+\/[^/]+|services)\/[^/]+\/v(\d+)\/', path) if versionSearch: api_version = int(versionSearch.group(1)) - + return api_version def get(self, path_segment="", owner=None, app=None, sharing=None, **query): @@ -852,7 +846,7 @@ def get(self, path_segment="", owner=None, app=None, sharing=None, **query): # - Fallback from v2+ to v1 if Splunk Version is < 9. # if api_version >= 2 and ('search' in query and path.endswith(tuple(["results_preview", "events", "results"])) or self.service.splunk_version < (9,)): # path = path.replace(PATH_JOBS_V2, PATH_JOBS) - + if api_version == 1: if isinstance(path, UrlEncoded): path = UrlEncoded(path.replace(PATH_JOBS_V2, PATH_JOBS), skip_encode=True) @@ -911,14 +905,14 @@ def post(self, path_segment="", owner=None, app=None, sharing=None, **query): apps.get('nonexistant/path') # raises HTTPError s.logout() apps.get() # raises AuthenticationError - """ + """ if path_segment.startswith('/'): path = path_segment else: if not self.path.endswith('/') and path_segment != "": self.path = self.path + '/' path = self.service._abspath(self.path + path_segment, owner=owner, app=app, sharing=sharing) - + # Get the API version from the path api_version = self.get_api_version(path) @@ -927,7 +921,7 @@ def post(self, path_segment="", owner=None, app=None, sharing=None, **query): # - Fallback from v2+ to v1 if Splunk Version is < 9. # if api_version >= 2 and ('search' in query and path.endswith(tuple(["results_preview", "events", "results"])) or self.service.splunk_version < (9,)): # path = path.replace(PATH_JOBS_V2, PATH_JOBS) - + if api_version == 1: if isinstance(path, UrlEncoded): path = UrlEncoded(path.replace(PATH_JOBS_V2, PATH_JOBS), skip_encode=True) @@ -1003,7 +997,6 @@ def __init__(self, service, path, **kwargs): self._state = None if not kwargs.get('skip_refresh', False): self.refresh(kwargs.get('state', None)) # "Prefresh" - return def __contains__(self, item): try: @@ -1028,24 +1021,21 @@ def __eq__(self, other): but then ``x != saved_searches['asearch']``. whether or not there was a change on the server. Rather than - try to do something fancy, we simple declare that equality is + try to do something fancy, we simply declare that equality is undefined for Entities. Makes no roundtrips to the server. """ - raise IncomparableException( - "Equality is undefined for objects of class %s" % \ - self.__class__.__name__) + raise IncomparableException(f"Equality is undefined for objects of class {self.__class__.__name__}") def __getattr__(self, key): # Called when an attribute was not found by the normal method. In this # case we try to find it in self.content and then self.defaults. if key in self.state.content: return self.state.content[key] - elif key in self.defaults: + if key in self.defaults: return self.defaults[key] - else: - raise AttributeError(key) + raise AttributeError(key) def __getitem__(self, key): # getattr attempts to find a field on the object in the normal way, @@ -1061,9 +1051,8 @@ def _load_atom_entry(self, response): apps = [ele.entry.content.get('eai:appName') for ele in elem] raise AmbiguousReferenceException( - "Fetch from server returned multiple entries for name '%s' in apps %s." % (elem[0].entry.title, apps)) - else: - return elem.entry + f"Fetch from server returned multiple entries for name '{elem[0].entry.title}' in apps {apps}.") + return elem.entry # Load the entity state record from the given response def _load_state(self, response): @@ -1096,17 +1085,15 @@ def _proper_namespace(self, owner=None, app=None, sharing=None): :param sharing: :return: """ - if owner is None and app is None and sharing is None: # No namespace provided + if owner is None and app is None and sharing is None: # No namespace provided if self._state is not None and 'access' in self._state: return (self._state.access.owner, self._state.access.app, self._state.access.sharing) - else: - return (self.service.namespace['owner'], + return (self.service.namespace['owner'], self.service.namespace['app'], self.service.namespace['sharing']) - else: - return (owner,app,sharing) + return owner, app, sharing def delete(self): owner, app, sharing = self._proper_namespace() @@ -1114,11 +1101,11 @@ def delete(self): def get(self, path_segment="", owner=None, app=None, sharing=None, **query): owner, app, sharing = self._proper_namespace(owner, app, sharing) - return super(Entity, self).get(path_segment, owner=owner, app=app, sharing=sharing, **query) + return super().get(path_segment, owner=owner, app=app, sharing=sharing, **query) def post(self, path_segment="", owner=None, app=None, sharing=None, **query): owner, app, sharing = self._proper_namespace(owner, app, sharing) - return super(Entity, self).post(path_segment, owner=owner, app=app, sharing=sharing, **query) + return super().post(path_segment, owner=owner, app=app, sharing=sharing, **query) def refresh(self, state=None): """Refreshes the state of this entity. @@ -1206,8 +1193,8 @@ def read(self, response): # In lower layers of the SDK, we end up trying to URL encode # text to be dispatched via HTTP. However, these links are already # URL encoded when they arrive, and we need to mark them as such. - unquoted_links = dict([(k, UrlEncoded(v, skip_encode=True)) - for k,v in six.iteritems(results['links'])]) + unquoted_links = dict((k, UrlEncoded(v, skip_encode=True)) + for k, v in list(results['links'].items())) results['links'] = unquoted_links return results @@ -1282,7 +1269,7 @@ def update(self, **kwargs): """ # The peculiarity in question: the REST API creates a new # Entity if we pass name in the dictionary, instead of the - # expected behavior of updating this Entity. Therefore we + # expected behavior of updating this Entity. Therefore, we # check for 'name' in kwargs and throw an error if it is # there. if 'name' in kwargs: @@ -1295,9 +1282,10 @@ class ReadOnlyCollection(Endpoint): """This class represents a read-only collection of entities in the Splunk instance. """ + def __init__(self, service, path, item=Entity): Endpoint.__init__(self, service, path) - self.item = item # Item accessor + self.item = item # Item accessor self.null_count = -1 def __contains__(self, name): @@ -1329,7 +1317,7 @@ def __getitem__(self, key): name. Where there is no conflict, ``__getitem__`` will fetch the - entity given just the name. If there is a conflict and you + entity given just the name. If there is a conflict, and you pass just a name, it will raise a ``ValueError``. In that case, add the namespace as a second argument. @@ -1376,13 +1364,13 @@ def __getitem__(self, key): response = self.get(key) entries = self._load_list(response) if len(entries) > 1: - raise AmbiguousReferenceException("Found multiple entities named '%s'; please specify a namespace." % key) - elif len(entries) == 0: + raise AmbiguousReferenceException( + f"Found multiple entities named '{key}'; please specify a namespace.") + if len(entries) == 0: raise KeyError(key) - else: - return entries[0] + return entries[0] except HTTPError as he: - if he.status == 404: # No entity matching key and namespace. + if he.status == 404: # No entity matching key and namespace. raise KeyError(key) else: raise @@ -1405,7 +1393,7 @@ def __iter__(self, **kwargs): c = client.connect(...) saved_searches = c.saved_searches for entity in saved_searches: - print "Saved search named %s" % entity.name + print(f"Saved search named {entity.name}") """ for item in self.iter(**kwargs): @@ -1446,13 +1434,12 @@ def _entity_path(self, state): # This has been factored out so that it can be easily # overloaded by Configurations, which has to switch its # entities' endpoints from its own properties/ to configs/. - raw_path = urllib.parse.unquote(state.links.alternate) + raw_path = parse.unquote(state.links.alternate) if 'servicesNS/' in raw_path: return _trailing(raw_path, 'servicesNS/', '/', '/') - elif 'services/' in raw_path: + if 'services/' in raw_path: return _trailing(raw_path, 'services/') - else: - return raw_path + return raw_path def _load_list(self, response): """Converts *response* to a list of entities. @@ -1615,8 +1602,6 @@ def list(self, count=None, **kwargs): return list(self.iter(count=count, **kwargs)) - - class Collection(ReadOnlyCollection): """A collection of entities. @@ -1690,8 +1675,8 @@ def create(self, name, **params): applications = s.apps new_app = applications.create("my_fake_app") """ - if not isinstance(name, six.string_types): - raise InvalidNameException("%s is not a valid name for an entity." % name) + if not isinstance(name, str): + raise InvalidNameException(f"{name} is not a valid name for an entity.") if 'namespace' in params: namespace = params.pop('namespace') params['owner'] = namespace.owner @@ -1703,14 +1688,13 @@ def create(self, name, **params): # This endpoint doesn't return the content of the new # item. We have to go fetch it ourselves. return self[name] - else: - entry = atom.entry - state = _parse_atom_entry(entry) - entity = self.item( - self.service, - self._entity_path(state), - state=state) - return entity + entry = atom.entry + state = _parse_atom_entry(entry) + entity = self.item( + self.service, + self._entity_path(state), + state=state) + return entity def delete(self, name, **params): """Deletes a specified entity from the collection. @@ -1750,7 +1734,7 @@ def delete(self, name, **params): # has already been deleted, and we reraise it as a # KeyError. if he.status == 404: - raise KeyError("No such entity %s" % name) + raise KeyError(f"No such entity {name}") else: raise return self @@ -1801,14 +1785,13 @@ def get(self, name="", owner=None, app=None, sharing=None, **query): """ name = UrlEncoded(name, encode_slash=True) - return super(Collection, self).get(name, owner, app, sharing, **query) - - + return super().get(name, owner, app, sharing, **query) class ConfigurationFile(Collection): """This class contains all of the stanzas from one configuration file. """ + # __init__'s arguments must match those of an Entity, not a # Collection, since it is being created as the elements of a # Configurations, which is a Collection subclass. @@ -1825,6 +1808,7 @@ class Configurations(Collection): stanzas. This collection is unusual in that the values in it are themselves collections of :class:`ConfigurationFile` objects. """ + def __init__(self, service): Collection.__init__(self, service, PATH_PROPERTIES, item=ConfigurationFile) if self.service.namespace.owner == '-' or self.service.namespace.app == '-': @@ -1842,7 +1826,7 @@ def __getitem__(self, key): response = self.get(key) return ConfigurationFile(self.service, PATH_CONF % key, state={'title': key}) except HTTPError as he: - if he.status == 404: # No entity matching key + if he.status == 404: # No entity matching key raise KeyError(key) else: raise @@ -1854,10 +1838,9 @@ def __contains__(self, key): response = self.get(key) return True except HTTPError as he: - if he.status == 404: # No entity matching key + if he.status == 404: # No entity matching key return False - else: - raise + raise def create(self, name): """ Creates a configuration file named *name*. @@ -1873,15 +1856,14 @@ def create(self, name): # This has to be overridden to handle the plumbing of creating # a ConfigurationFile (which is a Collection) instead of some # Entity. - if not isinstance(name, six.string_types): - raise ValueError("Invalid name: %s" % repr(name)) + if not isinstance(name, str): + raise ValueError(f"Invalid name: {repr(name)}") response = self.post(__conf=name) if response.status == 303: return self[name] - elif response.status == 201: + if response.status == 201: return ConfigurationFile(self.service, PATH_CONF % name, item=Stanza, state={'title': name}) - else: - raise ValueError("Unexpected status code %s returned from creating a stanza" % response.status) + raise ValueError(f"Unexpected status code {response.status} returned from creating a stanza") def delete(self, key): """Raises `IllegalOperationException`.""" @@ -1913,17 +1895,18 @@ def __len__(self): # The stanza endpoint returns all the keys at the same level in the XML as the eai information # and 'disabled', so to get an accurate length, we have to filter those out and have just # the stanza keys. - return len([x for x in self._state.content.keys() + return len([x for x in list(self._state.content.keys()) if not x.startswith('eai') and x != 'disabled']) class StoragePassword(Entity): """This class contains a storage password. """ + def __init__(self, service, path, **kwargs): state = kwargs.get('state', None) kwargs['skip_refresh'] = kwargs.get('skip_refresh', state is not None) - super(StoragePassword, self).__init__(service, path, **kwargs) + super().__init__(service, path, **kwargs) self._state = state @property @@ -1947,8 +1930,11 @@ class StoragePasswords(Collection): """This class provides access to the storage passwords from this Splunk instance. Retrieve this collection using :meth:`Service.storage_passwords`. """ + def __init__(self, service): - super(StoragePasswords, self).__init__(service, PATH_STORAGE_PASSWORDS, item=StoragePassword) + if service.namespace.owner == '-' or service.namespace.app == '-': + raise ValueError("StoragePasswords cannot have wildcards in namespace.") + super().__init__(service, PATH_STORAGE_PASSWORDS, item=StoragePassword) def create(self, password, username, realm=None): """ Creates a storage password. @@ -1965,11 +1951,8 @@ def create(self, password, username, realm=None): :return: The :class:`StoragePassword` object created. """ - if self.service.namespace.owner == '-' or self.service.namespace.app == '-': - raise ValueError("While creating StoragePasswords, namespace cannot have wildcards.") - - if not isinstance(username, six.string_types): - raise ValueError("Invalid name: %s" % repr(username)) + if not isinstance(username, str): + raise ValueError(f"Invalid name: {repr(username)}") if realm is None: response = self.post(password=password, name=username) @@ -1977,7 +1960,7 @@ def create(self, password, username, realm=None): response = self.post(password=password, realm=realm, name=username) if response.status != 201: - raise ValueError("Unexpected status code %s returned from creating a stanza" % response.status) + raise ValueError(f"Unexpected status code {response.status} returned from creating a stanza") entries = _load_atom_entries(response) state = _parse_atom_entry(entries[0]) @@ -2020,6 +2003,7 @@ def delete(self, username, realm=None): class AlertGroup(Entity): """This class represents a group of fired alerts for a saved search. Access it using the :meth:`alerts` property.""" + def __init__(self, service, path, **kwargs): Entity.__init__(self, service, path, **kwargs) @@ -2048,6 +2032,7 @@ class Indexes(Collection): """This class contains the collection of indexes in this Splunk instance. Retrieve this collection using :meth:`Service.indexes`. """ + def get_default(self): """ Returns the name of the default index. @@ -2075,6 +2060,7 @@ def delete(self, name): class Index(Entity): """This class represents an index and provides different operations, such as cleaning the index, writing to the index, and so forth.""" + def __init__(self, service, path, **kwargs): Entity.__init__(self, service, path, **kwargs) @@ -2091,26 +2077,26 @@ def attach(self, host=None, source=None, sourcetype=None): :return: A writable socket. """ - args = { 'index': self.name } + args = {'index': self.name} if host is not None: args['host'] = host if source is not None: args['source'] = source if sourcetype is not None: args['sourcetype'] = sourcetype - path = UrlEncoded(PATH_RECEIVERS_STREAM + "?" + urllib.parse.urlencode(args), skip_encode=True) + path = UrlEncoded(PATH_RECEIVERS_STREAM + "?" + parse.urlencode(args), skip_encode=True) - cookie_or_auth_header = "Authorization: Splunk %s\r\n" % \ - (self.service.token if self.service.token is _NoAuthenticationToken - else self.service.token.replace("Splunk ", "")) + cookie_header = self.service.token if self.service.token is _NoAuthenticationToken else self.service.token.replace("Splunk ", "") + cookie_or_auth_header = f"Authorization: Splunk {cookie_header}\r\n" # If we have cookie(s), use them instead of "Authorization: ..." if self.service.has_cookies(): - cookie_or_auth_header = "Cookie: %s\r\n" % _make_cookie_header(self.service.get_cookies().items()) + cookie_header = _make_cookie_header(list(self.service.get_cookies().items())) + cookie_or_auth_header = f"Cookie: {cookie_header}\r\n" # Since we need to stream to the index connection, we have to keep # the connection open and use the Splunk extension headers to note # the input mode sock = self.service.connect() - headers = [("POST %s HTTP/1.1\r\n" % str(self.service._abspath(path))).encode('utf-8'), - ("Host: %s:%s\r\n" % (self.service.host, int(self.service.port))).encode('utf-8'), + headers = [f"POST {str(self.service._abspath(path))} HTTP/1.1\r\n".encode('utf-8'), + f"Host: {self.service.host}:{int(self.service.port)}\r\n".encode('utf-8'), b"Accept-Encoding: identity\r\n", cookie_or_auth_header.encode('utf-8'), b"X-Splunk-Input-Mode: Streaming\r\n", @@ -2172,8 +2158,7 @@ def clean(self, timeout=60): ftp = self['frozenTimePeriodInSecs'] was_disabled_initially = self.disabled try: - if (not was_disabled_initially and \ - self.service.splunk_version < (5,)): + if not was_disabled_initially and self.service.splunk_version < (5,): # Need to disable the index first on Splunk 4.x, # but it doesn't work to disable it on 5.0. self.disable() @@ -2183,17 +2168,17 @@ def clean(self, timeout=60): # Wait until event count goes to 0. start = datetime.now() diff = timedelta(seconds=timeout) - while self.content.totalEventCount != '0' and datetime.now() < start+diff: + while self.content.totalEventCount != '0' and datetime.now() < start + diff: sleep(1) self.refresh() if self.content.totalEventCount != '0': - raise OperationError("Cleaning index %s took longer than %s seconds; timing out." % (self.name, timeout)) + raise OperationError( + f"Cleaning index {self.name} took longer than {timeout} seconds; timing out.") finally: # Restore original values self.update(maxTotalDataSizeMB=tds, frozenTimePeriodInSecs=ftp) - if (not was_disabled_initially and \ - self.service.splunk_version < (5,)): + if not was_disabled_initially and self.service.splunk_version < (5,): # Re-enable the index if it was originally enabled and we messed with it. self.enable() @@ -2221,7 +2206,7 @@ def submit(self, event, host=None, source=None, sourcetype=None): :return: The :class:`Index`. """ - args = { 'index': self.name } + args = {'index': self.name} if host is not None: args['host'] = host if source is not None: args['source'] = source if sourcetype is not None: args['sourcetype'] = sourcetype @@ -2255,6 +2240,7 @@ class Input(Entity): typed input classes and is also used when the client does not recognize an input kind. """ + def __init__(self, service, path, kind=None, **kwargs): # kind can be omitted (in which case it is inferred from the path) # Otherwise, valid values are the paths from data/inputs ("udp", @@ -2265,7 +2251,7 @@ def __init__(self, service, path, kind=None, **kwargs): path_segments = path.split('/') i = path_segments.index('inputs') + 1 if path_segments[i] == 'tcp': - self.kind = path_segments[i] + '/' + path_segments[i+1] + self.kind = path_segments[i] + '/' + path_segments[i + 1] else: self.kind = path_segments[i] else: @@ -2291,7 +2277,7 @@ def update(self, **kwargs): # UDP and TCP inputs require special handling due to their restrictToHost # field. For all other inputs kinds, we can dispatch to the superclass method. if self.kind not in ['tcp', 'splunktcp', 'tcp/raw', 'tcp/cooked', 'udp']: - return super(Input, self).update(**kwargs) + return super().update(**kwargs) else: # The behavior of restrictToHost is inconsistent across input kinds and versions of Splunk. # In Splunk 4.x, the name of the entity is only the port, independent of the value of @@ -2309,11 +2295,11 @@ def update(self, **kwargs): if 'restrictToHost' in kwargs: raise IllegalOperationException("Cannot set restrictToHost on an existing input with the SDK.") - elif 'restrictToHost' in self._state.content and self.kind != 'udp': + if 'restrictToHost' in self._state.content and self.kind != 'udp': to_update['restrictToHost'] = self._state.content['restrictToHost'] # Do the actual update operation. - return super(Input, self).update(**to_update) + return super().update(**to_update) # Inputs is a "kinded" collection, which is a heterogenous collection where @@ -2340,13 +2326,12 @@ def __getitem__(self, key): response = self.get(self.kindpath(kind) + "/" + key) entries = self._load_list(response) if len(entries) > 1: - raise AmbiguousReferenceException("Found multiple inputs of kind %s named %s." % (kind, key)) - elif len(entries) == 0: + raise AmbiguousReferenceException(f"Found multiple inputs of kind {kind} named {key}.") + if len(entries) == 0: raise KeyError((key, kind)) - else: - return entries[0] + return entries[0] except HTTPError as he: - if he.status == 404: # No entity matching kind and key + if he.status == 404: # No entity matching kind and key raise KeyError((key, kind)) else: raise @@ -2360,22 +2345,21 @@ def __getitem__(self, key): response = self.get(kind + "/" + key) entries = self._load_list(response) if len(entries) > 1: - raise AmbiguousReferenceException("Found multiple inputs of kind %s named %s." % (kind, key)) - elif len(entries) == 0: + raise AmbiguousReferenceException(f"Found multiple inputs of kind {kind} named {key}.") + if len(entries) == 0: pass - else: - if candidate is not None: # Already found at least one candidate - raise AmbiguousReferenceException("Found multiple inputs named %s, please specify a kind" % key) - candidate = entries[0] + if candidate is not None: # Already found at least one candidate + raise AmbiguousReferenceException( + f"Found multiple inputs named {key}, please specify a kind") + candidate = entries[0] except HTTPError as he: if he.status == 404: - pass # Just carry on to the next kind. + pass # Just carry on to the next kind. else: raise if candidate is None: - raise KeyError(key) # Never found a match. - else: - return candidate + raise KeyError(key) # Never found a match. + return candidate def __contains__(self, key): if isinstance(key, tuple) and len(key) == 2: @@ -2395,11 +2379,10 @@ def __contains__(self, key): entries = self._load_list(response) if len(entries) > 0: return True - else: - pass + pass except HTTPError as he: if he.status == 404: - pass # Just carry on to the next kind. + pass # Just carry on to the next kind. else: raise return False @@ -2451,9 +2434,8 @@ def create(self, name, kind, **kwargs): name = UrlEncoded(name, encode_slash=True) path = _path( self.path + kindpath, - '%s:%s' % (kwargs['restrictToHost'], name) \ - if 'restrictToHost' in kwargs else name - ) + f"{kwargs['restrictToHost']}:{name}" if 'restrictToHost' in kwargs else name + ) return Input(self.service, path, kind) def delete(self, name, kind=None): @@ -2523,7 +2505,7 @@ def itemmeta(self, kind): :return: The metadata. :rtype: class:``splunklib.data.Record`` """ - response = self.get("%s/_new" % self._kindmap[kind]) + response = self.get(f"{self._kindmap[kind]}/_new") content = _load_atom(response, MATCH_ENTRY_CONTENT) return _parse_atom_metadata(content) @@ -2538,9 +2520,9 @@ def _get_kind_list(self, subpath=None): this_subpath = subpath + [entry.title] # The "all" endpoint doesn't work yet. # The "tcp/ssl" endpoint is not a real input collection. - if entry.title == 'all' or this_subpath == ['tcp','ssl']: + if entry.title == 'all' or this_subpath == ['tcp', 'ssl']: continue - elif 'create' in [x.rel for x in entry.link]: + if 'create' in [x.rel for x in entry.link]: path = '/'.join(subpath + [entry.title]) kinds.append(path) else: @@ -2589,10 +2571,9 @@ def kindpath(self, kind): """ if kind == 'tcp': return UrlEncoded('tcp/raw', skip_encode=True) - elif kind == 'splunktcp': + if kind == 'splunktcp': return UrlEncoded('tcp/cooked', skip_encode=True) - else: - return UrlEncoded(kind, skip_encode=True) + return UrlEncoded(kind, skip_encode=True) def list(self, *kinds, **kwargs): """Returns a list of inputs that are in the :class:`Inputs` collection. @@ -2660,18 +2641,18 @@ def list(self, *kinds, **kwargs): path = UrlEncoded(path, skip_encode=True) response = self.get(path, **kwargs) except HTTPError as he: - if he.status == 404: # No inputs of this kind + if he.status == 404: # No inputs of this kind return [] entities = [] entries = _load_atom_entries(response) if entries is None: - return [] # No inputs in a collection comes back with no feed or entry in the XML + return [] # No inputs in a collection comes back with no feed or entry in the XML for entry in entries: state = _parse_atom_entry(entry) # Unquote the URL, since all URL encoded in the SDK # should be of type UrlEncoded, and all str should not # be URL encoded. - path = urllib.parse.unquote(state.links.alternate) + path = parse.unquote(state.links.alternate) entity = Input(self.service, path, kind, state=state) entities.append(entity) return entities @@ -2686,18 +2667,18 @@ def list(self, *kinds, **kwargs): response = self.get(self.kindpath(kind), search=search) except HTTPError as e: if e.status == 404: - continue # No inputs of this kind + continue # No inputs of this kind else: raise entries = _load_atom_entries(response) - if entries is None: continue # No inputs to process + if entries is None: continue # No inputs to process for entry in entries: state = _parse_atom_entry(entry) # Unquote the URL, since all URL encoded in the SDK # should be of type UrlEncoded, and all str should not # be URL encoded. - path = urllib.parse.unquote(state.links.alternate) + path = parse.unquote(state.links.alternate) entity = Input(self.service, path, kind, state=state) entities.append(entity) if 'offset' in kwargs: @@ -2765,6 +2746,7 @@ def oneshot(self, path, **kwargs): class Job(Entity): """This class represents a search job.""" + def __init__(self, service, sid, **kwargs): # Default to v2 in Splunk Version 9+ path = "{path}{sid}" @@ -2827,7 +2809,7 @@ def events(self, **kwargs): :return: The ``InputStream`` IO handle to this job's events. """ kwargs['segmentation'] = kwargs.get('segmentation', 'none') - + # Search API v1(GET) and v2(POST) if self.service.disable_v2_api: return self.get("events", **kwargs).body @@ -2898,10 +2880,10 @@ def results(self, **query_params): for result in rr: if isinstance(result, results.Message): # Diagnostic messages may be returned in the results - print '%s: %s' % (result.type, result.message) + print(f'{result.type}: {result.message}') elif isinstance(result, dict): # Normal events are returned as dicts - print result + print(result) assert rr.is_preview == False Results are not available until the job has finished. If called on @@ -2919,7 +2901,7 @@ def results(self, **query_params): :return: The ``InputStream`` IO handle to this job's results. """ query_params['segmentation'] = query_params.get('segmentation', 'none') - + # Search API v1(GET) and v2(POST) if self.service.disable_v2_api: return self.get("results", **query_params).body @@ -2942,14 +2924,14 @@ def preview(self, **query_params): for result in rr: if isinstance(result, results.Message): # Diagnostic messages may be returned in the results - print '%s: %s' % (result.type, result.message) + print(f'{result.type}: {result.message}') elif isinstance(result, dict): # Normal events are returned as dicts - print result + print(result) if rr.is_preview: - print "Preview of a running search job." + print("Preview of a running search job.") else: - print "Job is finished. Results are final." + print("Job is finished. Results are final.") This method makes one roundtrip to the server, plus at most two more if @@ -2964,7 +2946,7 @@ def preview(self, **query_params): :return: The ``InputStream`` IO handle to this job's preview results. """ query_params['segmentation'] = query_params.get('segmentation', 'none') - + # Search API v1(GET) and v2(POST) if self.service.disable_v2_api: return self.get("results_preview", **query_params).body @@ -3056,6 +3038,7 @@ def unpause(self): class Jobs(Collection): """This class represents a collection of search jobs. Retrieve this collection using :meth:`Service.jobs`.""" + def __init__(self, service): # Splunk 9 introduces the v2 endpoint if not service.disable_v2_api: @@ -3114,10 +3097,10 @@ def export(self, query, **params): for result in rr: if isinstance(result, results.Message): # Diagnostic messages may be returned in the results - print '%s: %s' % (result.type, result.message) + print(f'{result.type}: {result.message}') elif isinstance(result, dict): # Normal events are returned as dicts - print result + print(result) assert rr.is_preview == False Running an export search is more efficient as it streams the results @@ -3170,10 +3153,10 @@ def oneshot(self, query, **params): for result in rr: if isinstance(result, results.Message): # Diagnostic messages may be returned in the results - print '%s: %s' % (result.type, result.message) + print(f'{result.type}: {result.message}') elif isinstance(result, dict): # Normal events are returned as dicts - print result + print(result) assert rr.is_preview == False The ``oneshot`` method makes a single roundtrip to the server (as opposed @@ -3214,6 +3197,7 @@ def oneshot(self, query, **params): class Loggers(Collection): """This class represents a collection of service logging categories. Retrieve this collection using :meth:`Service.loggers`.""" + def __init__(self, service): Collection.__init__(self, service, PATH_LOGGER) @@ -3245,19 +3229,18 @@ class ModularInputKind(Entity): """This class contains the different types of modular inputs. Retrieve this collection using :meth:`Service.modular_input_kinds`. """ + def __contains__(self, name): args = self.state.content['endpoints']['args'] if name in args: return True - else: - return Entity.__contains__(self, name) + return Entity.__contains__(self, name) def __getitem__(self, name): args = self.state.content['endpoint']['args'] if name in args: return args['item'] - else: - return Entity.__getitem__(self, name) + return Entity.__getitem__(self, name) @property def arguments(self): @@ -3282,6 +3265,7 @@ def update(self, **kwargs): class SavedSearch(Entity): """This class represents a saved search.""" + def __init__(self, service, path, **kwargs): Entity.__init__(self, service, path, **kwargs) @@ -3423,8 +3407,7 @@ def suppressed(self): r = self._run_action("suppress") if r.suppressed == "1": return int(r.expiration) - else: - return 0 + return 0 def unsuppress(self): """Cancels suppression and makes this search run as scheduled. @@ -3438,6 +3421,7 @@ def unsuppress(self): class SavedSearches(Collection): """This class represents a collection of saved searches. Retrieve this collection using :meth:`Service.saved_searches`.""" + def __init__(self, service): Collection.__init__( self, service, PATH_SAVED_SEARCHES, item=SavedSearch) @@ -3462,6 +3446,7 @@ def create(self, name, search, **kwargs): class Settings(Entity): """This class represents configuration settings for a Splunk service. Retrieve this collection using :meth:`Service.settings`.""" + def __init__(self, service, **kwargs): Entity.__init__(self, service, "/services/server/settings", **kwargs) @@ -3483,6 +3468,7 @@ def update(self, **kwargs): class User(Entity): """This class represents a Splunk user. """ + @property def role_entities(self): """Returns a list of roles assigned to this user. @@ -3499,6 +3485,7 @@ class Users(Collection): """This class represents the collection of Splunk users for this instance of Splunk. Retrieve this collection using :meth:`Service.users`. """ + def __init__(self, service): Collection.__init__(self, service, PATH_USERS, item=User) @@ -3538,8 +3525,8 @@ def create(self, username, password, roles, **params): boris = users.create("boris", "securepassword", roles="user") hilda = users.create("hilda", "anotherpassword", roles=["user","power"]) """ - if not isinstance(username, six.string_types): - raise ValueError("Invalid username: %s" % str(username)) + if not isinstance(username, str): + raise ValueError(f"Invalid username: {str(username)}") username = username.lower() self.post(name=username, password=password, roles=roles, **params) # splunkd doesn't return the user in the POST response body, @@ -3549,7 +3536,7 @@ def create(self, username, password, roles, **params): state = _parse_atom_entry(entry) entity = self.item( self.service, - urllib.parse.unquote(state.links.alternate), + parse.unquote(state.links.alternate), state=state) return entity @@ -3568,6 +3555,7 @@ def delete(self, name): class Role(Entity): """This class represents a user role. """ + def grant(self, *capabilities_to_grant): """Grants additional capabilities to this role. @@ -3618,8 +3606,8 @@ def revoke(self, *capabilities_to_revoke): for c in old_capabilities: if c not in capabilities_to_revoke: new_capabilities.append(c) - if new_capabilities == []: - new_capabilities = '' # Empty lists don't get passed in the body, so we have to force an empty argument. + if not new_capabilities: + new_capabilities = '' # Empty lists don't get passed in the body, so we have to force an empty argument. self.post(capabilities=new_capabilities) return self @@ -3627,6 +3615,7 @@ def revoke(self, *capabilities_to_revoke): class Roles(Collection): """This class represents the collection of roles in the Splunk instance. Retrieve this collection using :meth:`Service.roles`.""" + def __init__(self, service): return Collection.__init__(self, service, PATH_ROLES, item=Role) @@ -3661,8 +3650,8 @@ def create(self, name, **params): roles = c.roles paltry = roles.create("paltry", imported_roles="user", defaultApp="search") """ - if not isinstance(name, six.string_types): - raise ValueError("Invalid role name: %s" % str(name)) + if not isinstance(name, str): + raise ValueError(f"Invalid role name: {str(name)}") name = name.lower() self.post(name=name, **params) # splunkd doesn't return the user in the POST response body, @@ -3672,7 +3661,7 @@ def create(self, name, **params): state = _parse_atom_entry(entry) entity = self.item( self.service, - urllib.parse.unquote(state.links.alternate), + parse.unquote(state.links.alternate), state=state) return entity @@ -3689,6 +3678,7 @@ def delete(self, name): class Application(Entity): """Represents a locally-installed Splunk app.""" + @property def setupInfo(self): """Returns the setup information for the app. @@ -3705,11 +3695,12 @@ def updateInfo(self): """Returns any update information that is available for the app.""" return self._run_action("update") + class KVStoreCollections(Collection): def __init__(self, service): Collection.__init__(self, service, 'storage/collections/config', item=KVStoreCollection) - def create(self, name, indexes = {}, fields = {}, **kwargs): + def create(self, name, indexes={}, fields={}, **kwargs): """Creates a KV Store Collection. :param name: name of collection to create @@ -3723,14 +3714,15 @@ def create(self, name, indexes = {}, fields = {}, **kwargs): :return: Result of POST request """ - for k, v in six.iteritems(indexes): + for k, v in list(indexes.items()): if isinstance(v, dict): v = json.dumps(v) kwargs['index.' + k] = v - for k, v in six.iteritems(fields): + for k, v in list(fields.items()): kwargs['field.' + k] = v return self.post(name=name, **kwargs) + class KVStoreCollection(Entity): @property def data(self): @@ -3751,7 +3743,7 @@ def update_index(self, name, value): :return: Result of POST request """ kwargs = {} - kwargs['index.' + name] = value if isinstance(value, six.string_types) else json.dumps(value) + kwargs['index.' + name] = value if isinstance(value, str) else json.dumps(value) return self.post(**kwargs) def update_field(self, name, value): @@ -3768,7 +3760,8 @@ def update_field(self, name, value): kwargs['field.' + name] = value return self.post(**kwargs) -class KVStoreCollectionData(object): + +class KVStoreCollectionData: """This class represents the data endpoint for a KVStoreCollection. Retrieve using :meth:`KVStoreCollection.data` @@ -3801,7 +3794,7 @@ def query(self, **query): :rtype: ``array`` """ - for key, value in query.items(): + for key, value in list(query.items()): if isinstance(query[key], dict): query[key] = json.dumps(value) @@ -3831,7 +3824,8 @@ def insert(self, data): """ if isinstance(data, dict): data = json.dumps(data) - return json.loads(self._post('', headers=KVStoreCollectionData.JSON_HEADER, body=data).body.read().decode('utf-8')) + return json.loads( + self._post('', headers=KVStoreCollectionData.JSON_HEADER, body=data).body.read().decode('utf-8')) def delete(self, query=None): """ @@ -3869,7 +3863,8 @@ def update(self, id, data): """ if isinstance(data, dict): data = json.dumps(data) - return json.loads(self._post(UrlEncoded(str(id), encode_slash=True), headers=KVStoreCollectionData.JSON_HEADER, body=data).body.read().decode('utf-8')) + return json.loads(self._post(UrlEncoded(str(id), encode_slash=True), headers=KVStoreCollectionData.JSON_HEADER, + body=data).body.read().decode('utf-8')) def batch_find(self, *dbqueries): """ @@ -3886,7 +3881,8 @@ def batch_find(self, *dbqueries): data = json.dumps(dbqueries) - return json.loads(self._post('batch_find', headers=KVStoreCollectionData.JSON_HEADER, body=data).body.read().decode('utf-8')) + return json.loads( + self._post('batch_find', headers=KVStoreCollectionData.JSON_HEADER, body=data).body.read().decode('utf-8')) def batch_save(self, *documents): """ @@ -3903,4 +3899,5 @@ def batch_save(self, *documents): data = json.dumps(documents) - return json.loads(self._post('batch_save', headers=KVStoreCollectionData.JSON_HEADER, body=data).body.read().decode('utf-8')) \ No newline at end of file + return json.loads( + self._post('batch_save', headers=KVStoreCollectionData.JSON_HEADER, body=data).body.read().decode('utf-8')) diff --git a/splunklib/customrest/__init__.py b/splunklib/customrest/__init__.py new file mode 100644 index 000000000..4bee5eaae --- /dev/null +++ b/splunklib/customrest/__init__.py @@ -0,0 +1,6 @@ +"""The following imports allow these classes to be imported via +the splunklib.customrest package like so: + +from splunklib.customrest import * +""" +from .json import json_handler \ No newline at end of file diff --git a/splunklib/customrest/json.py b/splunklib/customrest/json.py new file mode 100644 index 000000000..c7b868267 --- /dev/null +++ b/splunklib/customrest/json.py @@ -0,0 +1,51 @@ +import logging +import traceback +import json + +from functools import wraps + +def json_handler(func): + @wraps(func) + def wrapper(*args, **kwargs): + decorated = json_exception_handler(json_payload_extractor(func)) + return decorated(*args, **kwargs) + return wrapper + + +def json_payload_extractor(func): + @wraps(func) + def wrapper(self, in_string): + try: + request = json.loads(in_string) + kwargs = {'request': request, 'in_string': in_string} + if 'payload' in request: + # if request contains payload, parse it and add it as payload parameter + kwargs['payload'] = json.loads(request['payload']) + if 'query' in request: + # if request contains query, parse it and add it as query parameter + kwargs['query'] = _convert_tuples_to_dict(request['query']) + return func(self, **kwargs) + except ValueError as e: + return {'payload': {'success': 'false', 'result': f'Error parsing JSON: {e}'}, + 'status': 400 + } + return wrapper + + +def json_exception_handler(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + logging.error( + f'error={repr(e)} traceback={traceback.format_exc()}') + return {'payload': {'success': 'false', 'message': f'Error: {repr(e)}'}, + 'status': 500 + } + return wrapper + + +def _convert_tuples_to_dict(tuples): + return {t[0]: t[1] for t in tuples} + diff --git a/splunklib/data.py b/splunklib/data.py index f9ffb8692..c889ff9bc 100644 --- a/splunklib/data.py +++ b/splunklib/data.py @@ -12,16 +12,13 @@ # License for the specific language governing permissions and limitations # under the License. -"""The **splunklib.data** module reads the responses from splunkd in Atom Feed +"""The **splunklib.data** module reads the responses from splunkd in Atom Feed format, which is the format used by most of the REST API. """ -from __future__ import absolute_import -import sys from xml.etree.ElementTree import XML -from splunklib import six -__all__ = ["load"] +__all__ = ["load", "record"] # LNAME refers to element names without namespaces; XNAME is the same # name, but with an XML namespace. @@ -36,33 +33,41 @@ XNAME_KEY = XNAMEF_REST % LNAME_KEY XNAME_LIST = XNAMEF_REST % LNAME_LIST + # Some responses don't use namespaces (eg: search/parse) so we look for # both the extended and local versions of the following names. + def isdict(name): - return name == XNAME_DICT or name == LNAME_DICT + return name in (XNAME_DICT, LNAME_DICT) + def isitem(name): - return name == XNAME_ITEM or name == LNAME_ITEM + return name in (XNAME_ITEM, LNAME_ITEM) + def iskey(name): - return name == XNAME_KEY or name == LNAME_KEY + return name in (XNAME_KEY, LNAME_KEY) + def islist(name): - return name == XNAME_LIST or name == LNAME_LIST + return name in (XNAME_LIST, LNAME_LIST) + def hasattrs(element): return len(element.attrib) > 0 + def localname(xname): rcurly = xname.find('}') - return xname if rcurly == -1 else xname[rcurly+1:] + return xname if rcurly == -1 else xname[rcurly + 1:] + def load(text, match=None): - """This function reads a string that contains the XML of an Atom Feed, then - returns the - data in a native Python structure (a ``dict`` or ``list``). If you also - provide a tag name or path to match, only the matching sub-elements are + """This function reads a string that contains the XML of an Atom Feed, then + returns the + data in a native Python structure (a ``dict`` or ``list``). If you also + provide a tag name or path to match, only the matching sub-elements are loaded. :param text: The XML text to load. @@ -78,30 +83,27 @@ def load(text, match=None): 'names': {} } - # Convert to unicode encoding in only python 2 for xml parser - if(sys.version_info < (3, 0, 0) and isinstance(text, unicode)): - text = text.encode('utf-8') - root = XML(text) items = [root] if match is None else root.findall(match) count = len(items) - if count == 0: + if count == 0: return None - elif count == 1: + if count == 1: return load_root(items[0], nametable) - else: - return [load_root(item, nametable) for item in items] + return [load_root(item, nametable) for item in items] + # Load the attributes of the given element. def load_attrs(element): if not hasattrs(element): return None attrs = record() - for key, value in six.iteritems(element.attrib): + for key, value in list(element.attrib.items()): attrs[key] = value return attrs + # Parse a element and return a Python dict -def load_dict(element, nametable = None): +def load_dict(element, nametable=None): value = record() children = list(element) for child in children: @@ -110,6 +112,7 @@ def load_dict(element, nametable = None): value[name] = load_value(child, nametable) return value + # Loads the given elements attrs & value into single merged dict. def load_elem(element, nametable=None): name = localname(element.tag) @@ -118,12 +121,12 @@ def load_elem(element, nametable=None): if attrs is None: return name, value if value is None: return name, attrs # If value is simple, merge into attrs dict using special key - if isinstance(value, six.string_types): + if isinstance(value, str): attrs["$text"] = value return name, attrs # Both attrs & value are complex, so merge the two dicts, resolving collisions. collision_keys = [] - for key, val in six.iteritems(attrs): + for key, val in list(attrs.items()): if key in value and key in collision_keys: value[key].append(val) elif key in value and key not in collision_keys: @@ -133,6 +136,7 @@ def load_elem(element, nametable=None): value[key] = val return name, value + # Parse a element and return a Python list def load_list(element, nametable=None): assert islist(element.tag) @@ -143,6 +147,7 @@ def load_list(element, nametable=None): value.append(load_value(child, nametable)) return value + # Load the given root element. def load_root(element, nametable=None): tag = element.tag @@ -151,6 +156,7 @@ def load_root(element, nametable=None): k, v = load_elem(element, nametable) return Record.fromkv(k, v) + # Load the children of the given element. def load_value(element, nametable=None): children = list(element) @@ -159,7 +165,7 @@ def load_value(element, nametable=None): # No children, assume a simple text value if count == 0: text = element.text - if text is None: + if text is None: return None if len(text.strip()) == 0: @@ -179,7 +185,7 @@ def load_value(element, nametable=None): # If we have seen this name before, promote the value to a list if name in value: current = value[name] - if not isinstance(current, list): + if not isinstance(current, list): value[name] = [current] value[name].append(item) else: @@ -187,23 +193,24 @@ def load_value(element, nametable=None): return value + # A generic utility that enables "dot" access to dicts class Record(dict): - """This generic utility class enables dot access to members of a Python + """This generic utility class enables dot access to members of a Python dictionary. - Any key that is also a valid Python identifier can be retrieved as a field. - So, for an instance of ``Record`` called ``r``, ``r.key`` is equivalent to - ``r['key']``. A key such as ``invalid-key`` or ``invalid.key`` cannot be - retrieved as a field, because ``-`` and ``.`` are not allowed in + Any key that is also a valid Python identifier can be retrieved as a field. + So, for an instance of ``Record`` called ``r``, ``r.key`` is equivalent to + ``r['key']``. A key such as ``invalid-key`` or ``invalid.key`` cannot be + retrieved as a field, because ``-`` and ``.`` are not allowed in identifiers. - Keys of the form ``a.b.c`` are very natural to write in Python as fields. If - a group of keys shares a prefix ending in ``.``, you can retrieve keys as a + Keys of the form ``a.b.c`` are very natural to write in Python as fields. If + a group of keys shares a prefix ending in ``.``, you can retrieve keys as a nested dictionary by calling only the prefix. For example, if ``r`` contains keys ``'foo'``, ``'bar.baz'``, and ``'bar.qux'``, ``r.bar`` returns a record - with the keys ``baz`` and ``qux``. If a key contains multiple ``.``, each - one is placed into a nested dictionary, so you can write ``r.bar.qux`` or + with the keys ``baz`` and ``qux``. If a key contains multiple ``.``, each + one is placed into a nested dictionary, so you can write ``r.bar.qux`` or ``r['bar.qux']`` interchangeably. """ sep = '.' @@ -215,7 +222,7 @@ def __call__(self, *args): def __getattr__(self, name): try: return self[name] - except KeyError: + except KeyError: raise AttributeError(name) def __delattr__(self, name): @@ -235,7 +242,7 @@ def __getitem__(self, key): return dict.__getitem__(self, key) key += self.sep result = record() - for k,v in six.iteritems(self): + for k, v in list(self.items()): if not k.startswith(key): continue suffix = k[len(key):] @@ -250,17 +257,16 @@ def __getitem__(self, key): else: result[suffix] = v if len(result) == 0: - raise KeyError("No key or prefix: %s" % key) + raise KeyError(f"No key or prefix: {key}") return result - -def record(value=None): - """This function returns a :class:`Record` instance constructed with an + +def record(value=None): + """This function returns a :class:`Record` instance constructed with an initial value that you provide. - - :param `value`: An initial record value. - :type `value`: ``dict`` + + :param value: An initial record value. + :type value: ``dict`` """ if value is None: value = {} return Record(value) - diff --git a/splunklib/modularinput/argument.py b/splunklib/modularinput/argument.py index 04214d16d..f16ea99e3 100644 --- a/splunklib/modularinput/argument.py +++ b/splunklib/modularinput/argument.py @@ -12,16 +12,13 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import -try: - import xml.etree.ElementTree as ET -except ImportError: - import xml.etree.cElementTree as ET +import xml.etree.ElementTree as ET + +class Argument: -class Argument(object): """Class representing an argument to a modular input kind. - ``Argument`` is meant to be used with ``Scheme`` to generate an XML + ``Argument`` is meant to be used with ``Scheme`` to generate an XML definition of the modular input kind that Splunk understands. ``name`` is the only required parameter for the constructor. @@ -100,4 +97,4 @@ def add_to_document(self, parent): for name, value in subelements: ET.SubElement(arg, name).text = str(value).lower() - return arg \ No newline at end of file + return arg diff --git a/splunklib/modularinput/event.py b/splunklib/modularinput/event.py index 9cd6cf3ae..6a9fba939 100644 --- a/splunklib/modularinput/event.py +++ b/splunklib/modularinput/event.py @@ -12,16 +12,13 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from io import TextIOBase -from splunklib.six import ensure_text +import xml.etree.ElementTree as ET -try: - import xml.etree.cElementTree as ET -except ImportError as ie: - import xml.etree.ElementTree as ET +from splunklib import ensure_text -class Event(object): + +class Event: """Represents an event or fragment of an event to be written by this modular input to Splunk. To write an input to a stream, call the ``write_to`` function, passing in a stream. @@ -111,4 +108,4 @@ def write_to(self, stream): stream.write(ensure_text(ET.tostring(event))) else: stream.write(ET.tostring(event)) - stream.flush() \ No newline at end of file + stream.flush() diff --git a/splunklib/modularinput/event_writer.py b/splunklib/modularinput/event_writer.py index 5f8c5aa8b..5aa83d963 100644 --- a/splunklib/modularinput/event_writer.py +++ b/splunklib/modularinput/event_writer.py @@ -12,18 +12,13 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import import sys -from splunklib.six import ensure_str +from splunklib import ensure_str from .event import ET -try: - from splunklib.six.moves import cStringIO as StringIO -except ImportError: - from splunklib.six import StringIO -class EventWriter(object): +class EventWriter: """``EventWriter`` writes events and error messages to Splunk from a modular input. Its two important methods are ``writeEvent``, which takes an ``Event`` object, and ``log``, which takes a severity and an error message. @@ -68,7 +63,7 @@ def log(self, severity, message): :param message: ``string``, message to log. """ - self._err.write("%s %s\n" % (severity, message)) + self._err.write(f"{severity} {message}\n") self._err.flush() def write_xml_document(self, document): @@ -83,5 +78,5 @@ def write_xml_document(self, document): def close(self): """Write the closing tag to make this XML well formed.""" if self.header_written: - self._out.write("") + self._out.write("") self._out.flush() diff --git a/splunklib/modularinput/input_definition.py b/splunklib/modularinput/input_definition.py index fdc7cbb3f..c0e8e1ac5 100644 --- a/splunklib/modularinput/input_definition.py +++ b/splunklib/modularinput/input_definition.py @@ -12,12 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import -try: - import xml.etree.cElementTree as ET -except ImportError as ie: - import xml.etree.ElementTree as ET - +import xml.etree.ElementTree as ET from .utils import parse_xml_data class InputDefinition: @@ -57,4 +52,4 @@ def parse(stream): else: definition.metadata[node.tag] = node.text - return definition \ No newline at end of file + return definition diff --git a/splunklib/modularinput/scheme.py b/splunklib/modularinput/scheme.py index 4104e4a3f..e84ce00dc 100644 --- a/splunklib/modularinput/scheme.py +++ b/splunklib/modularinput/scheme.py @@ -12,13 +12,10 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import -try: - import xml.etree.cElementTree as ET -except ImportError: - import xml.etree.ElementTree as ET +import xml.etree.ElementTree as ET -class Scheme(object): + +class Scheme: """Class representing the metadata for a modular input kind. A ``Scheme`` specifies a title, description, several options of how Splunk should run modular inputs of this @@ -82,4 +79,4 @@ def to_xml(self): for arg in self.arguments: arg.add_to_document(args) - return root \ No newline at end of file + return root diff --git a/splunklib/modularinput/script.py b/splunklib/modularinput/script.py index 8595dc4bd..5df6d0fce 100644 --- a/splunklib/modularinput/script.py +++ b/splunklib/modularinput/script.py @@ -12,24 +12,18 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from abc import ABCMeta, abstractmethod -from splunklib.six.moves.urllib.parse import urlsplit import sys +import xml.etree.ElementTree as ET +from urllib.parse import urlsplit from ..client import Service from .event_writer import EventWriter from .input_definition import InputDefinition from .validation_definition import ValidationDefinition -from splunklib import six -try: - import xml.etree.cElementTree as ET -except ImportError: - import xml.etree.ElementTree as ET - -class Script(six.with_metaclass(ABCMeta, object)): +class Script(metaclass=ABCMeta): """An abstract base class for implementing modular inputs. Subclasses should override ``get_scheme``, ``stream_events``, @@ -74,7 +68,7 @@ def run_script(self, args, event_writer, input_stream): event_writer.close() return 0 - elif str(args[1]).lower() == "--scheme": + if str(args[1]).lower() == "--scheme": # Splunk has requested XML specifying the scheme for this # modular input Return it and exit. scheme = self.get_scheme() @@ -83,11 +77,10 @@ def run_script(self, args, event_writer, input_stream): EventWriter.FATAL, "Modular input script returned a null scheme.") return 1 - else: - event_writer.write_xml_document(scheme.to_xml()) - return 0 + event_writer.write_xml_document(scheme.to_xml()) + return 0 - elif args[1].lower() == "--validate-arguments": + if args[1].lower() == "--validate-arguments": validation_definition = ValidationDefinition.parse(input_stream) try: self.validate_input(validation_definition) @@ -98,11 +91,10 @@ def run_script(self, args, event_writer, input_stream): event_writer.write_xml_document(root) return 1 - else: - err_string = "ERROR Invalid arguments to modular input script:" + ' '.join( - args) - event_writer._err.write(err_string) - return 1 + err_string = "ERROR Invalid arguments to modular input script:" + ' '.join( + args) + event_writer._err.write(err_string) + return 1 except Exception as e: event_writer.log(EventWriter.ERROR, str(e)) @@ -165,7 +157,6 @@ def validate_input(self, definition): :param definition: The parameters for the proposed input passed by splunkd. """ - pass @abstractmethod def stream_events(self, inputs, ew): diff --git a/splunklib/modularinput/utils.py b/splunklib/modularinput/utils.py index 3d42b6326..6429c0a71 100644 --- a/splunklib/modularinput/utils.py +++ b/splunklib/modularinput/utils.py @@ -14,8 +14,7 @@ # File for utility functions -from __future__ import absolute_import -from splunklib.six.moves import zip + def xml_compare(expected, found): """Checks equality of two ``ElementTree`` objects. @@ -39,27 +38,25 @@ def xml_compare(expected, found): return False # compare children - if not all([xml_compare(a, b) for a, b in zip(expected_children, found_children)]): + if not all(xml_compare(a, b) for a, b in zip(expected_children, found_children)): return False # compare elements, if there is no text node, return True if (expected.text is None or expected.text.strip() == "") \ and (found.text is None or found.text.strip() == ""): return True - else: - return expected.tag == found.tag and expected.text == found.text \ + return expected.tag == found.tag and expected.text == found.text \ and expected.attrib == found.attrib def parse_parameters(param_node): if param_node.tag == "param": return param_node.text - elif param_node.tag == "param_list": + if param_node.tag == "param_list": parameters = [] for mvp in param_node: parameters.append(mvp.text) return parameters - else: - raise ValueError("Invalid configuration scheme, %s tag unexpected." % param_node.tag) + raise ValueError(f"Invalid configuration scheme, {param_node.tag} tag unexpected.") def parse_xml_data(parent_node, child_node_tag): data = {} diff --git a/splunklib/modularinput/validation_definition.py b/splunklib/modularinput/validation_definition.py index 3bbe9760e..0ad40e9ed 100644 --- a/splunklib/modularinput/validation_definition.py +++ b/splunklib/modularinput/validation_definition.py @@ -13,16 +13,12 @@ # under the License. -from __future__ import absolute_import -try: - import xml.etree.cElementTree as ET -except ImportError as ie: - import xml.etree.ElementTree as ET +import xml.etree.ElementTree as ET from .utils import parse_xml_data -class ValidationDefinition(object): +class ValidationDefinition: """This class represents the XML sent by Splunk for external validation of a new modular input. @@ -83,4 +79,4 @@ def parse(stream): # Store anything else in metadata definition.metadata[node.tag] = node.text - return definition \ No newline at end of file + return definition diff --git a/splunklib/results.py b/splunklib/results.py index 8543ab0df..8420cf3d1 100644 --- a/splunklib/results.py +++ b/splunklib/results.py @@ -29,38 +29,27 @@ reader = ResultsReader(result_stream) for item in reader: print(item) - print "Results are a preview: %s" % reader.is_preview + print(f"Results are a preview: {reader.is_preview}") """ -from __future__ import absolute_import - from io import BufferedReader, BytesIO -from splunklib import six - -from splunklib.six import deprecated -try: - import xml.etree.cElementTree as et -except: - import xml.etree.ElementTree as et +import xml.etree.ElementTree as et from collections import OrderedDict from json import loads as json_loads -try: - from splunklib.six.moves import cStringIO as StringIO -except: - from splunklib.six import StringIO - __all__ = [ "ResultsReader", "Message", "JSONResultsReader" ] +import deprecation -class Message(object): + +class Message: """This class represents informational messages that Splunk interleaves in the results stream. ``Message`` takes two arguments: a string giving the message type (e.g., "DEBUG"), and @@ -76,7 +65,7 @@ def __init__(self, type_, message): self.message = message def __repr__(self): - return "%s: %s" % (self.type, self.message) + return f"{self.type}: {self.message}" def __eq__(self, other): return (self.type, self.message) == (other.type, other.message) @@ -85,7 +74,7 @@ def __hash__(self): return hash((self.type, self.message)) -class _ConcatenatedStream(object): +class _ConcatenatedStream: """Lazily concatenate zero or more streams into a stream. As you read from the concatenated stream, you get characters from @@ -117,7 +106,7 @@ def read(self, n=None): return response -class _XMLDTDFilter(object): +class _XMLDTDFilter: """Lazily remove all XML DTDs from a stream. All substrings matching the regular expression ]*> are @@ -144,7 +133,7 @@ def read(self, n=None): c = self.stream.read(1) if c == b"": break - elif c == b"<": + if c == b"<": c += self.stream.read(1) if c == b"`_ 3. `Configure seach assistant with searchbnf.conf `_ - + 4. `Control search distribution with distsearch.conf `_ """ -from __future__ import absolute_import, division, print_function, unicode_literals - from .environment import * from .decorators import * from .validators import * diff --git a/splunklib/searchcommands/decorators.py b/splunklib/searchcommands/decorators.py index d8b3f48cc..479029698 100644 --- a/splunklib/searchcommands/decorators.py +++ b/splunklib/searchcommands/decorators.py @@ -14,19 +14,16 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals -from splunklib import six from collections import OrderedDict # must be python 2.7 - from inspect import getmembers, isclass, isfunction -from splunklib.six.moves import map as imap + from .internals import ConfigurationSettingsType, json_encode_string from .validators import OptionName -class Configuration(object): +class Configuration: """ Defines the configuration settings for a search command. Documents, validates, and ensures that only relevant configuration settings are applied. Adds a :code:`name` class @@ -69,7 +66,7 @@ def __call__(self, o): name = o.__name__ if name.endswith('Command'): name = name[:-len('Command')] - o.name = six.text_type(name.lower()) + o.name = str(name.lower()) # Construct ConfigurationSettings instance for the command class @@ -82,7 +79,7 @@ def __call__(self, o): o.ConfigurationSettings.fix_up(o) Option.fix_up(o) else: - raise TypeError('Incorrect usage: Configuration decorator applied to {0}'.format(type(o), o.__name__)) + raise TypeError(f'Incorrect usage: Configuration decorator applied to {type(o)}') return o @@ -136,7 +133,7 @@ def fix_up(cls, values): for name, setting in definitions: if setting._name is None: - setting._name = name = six.text_type(name) + setting._name = name = str(name) else: name = setting._name @@ -187,14 +184,14 @@ def is_supported_by_protocol(version): continue if setting.fset is None: - raise ValueError('The value of configuration setting {} is fixed'.format(name)) + raise ValueError(f'The value of configuration setting {name} is fixed') setattr(cls, backing_field_name, validate(specification, name, value)) del values[name] if len(values) > 0: - settings = sorted(list(six.iteritems(values))) - settings = imap(lambda n_v: '{}={}'.format(n_v[0], repr(n_v[1])), settings) + settings = sorted(list(values.items())) + settings = [f'{n_v[0]}={n_v[1]}' for n_v in settings] raise AttributeError('Inapplicable configuration settings: ' + ', '.join(settings)) cls.configuration_setting_definitions = definitions @@ -212,7 +209,7 @@ def _get_specification(self): try: specification = ConfigurationSettingsType.specification_matrix[name] except KeyError: - raise AttributeError('Unknown configuration setting: {}={}'.format(name, repr(self._value))) + raise AttributeError(f'Unknown configuration setting: {name}={repr(self._value)}') return ConfigurationSettingsType.validate_configuration_setting, specification @@ -346,7 +343,7 @@ def _copy_extra_attributes(self, other): # region Types - class Item(object): + class Item: """ Presents an instance/class view over a search command `Option`. This class is used by SearchCommand.process to parse and report on option values. @@ -357,7 +354,7 @@ def __init__(self, command, option): self._option = option self._is_set = False validator = self.validator - self._format = six.text_type if validator is None else validator.format + self._format = str if validator is None else validator.format def __repr__(self): return '(' + repr(self.name) + ', ' + repr(self._format(self.value)) + ')' @@ -405,7 +402,6 @@ def reset(self): self._option.__set__(self._command, self._option.default) self._is_set = False - pass # endregion class View(OrderedDict): @@ -420,27 +416,26 @@ def __init__(self, command): OrderedDict.__init__(self, ((option.name, item_class(command, option)) for (name, option) in definitions)) def __repr__(self): - text = 'Option.View([' + ','.join(imap(lambda item: repr(item), six.itervalues(self))) + '])' + text = 'Option.View([' + ','.join([repr(item) for item in list(self.values())]) + '])' return text def __str__(self): - text = ' '.join([str(item) for item in six.itervalues(self) if item.is_set]) + text = ' '.join([str(item) for item in list(self.values()) if item.is_set]) return text # region Methods def get_missing(self): - missing = [item.name for item in six.itervalues(self) if item.is_required and not item.is_set] + missing = [item.name for item in list(self.values()) if item.is_required and not item.is_set] return missing if len(missing) > 0 else None def reset(self): - for value in six.itervalues(self): + for value in list(self.values()): value.reset() - pass # endregion - pass + # endregion diff --git a/splunklib/searchcommands/environment.py b/splunklib/searchcommands/environment.py index e92018f6a..2896df7b6 100644 --- a/splunklib/searchcommands/environment.py +++ b/splunklib/searchcommands/environment.py @@ -14,16 +14,15 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals + from logging import getLogger, root, StreamHandler from logging.config import fileConfig -from os import chdir, environ, path -from splunklib.six.moves import getcwd - +from os import chdir, environ, path, getcwd import sys + def configure_logging(logger_name, filename=None): """ Configure logging and return the named logger and the location of the logging configuration file loaded. @@ -88,9 +87,9 @@ def configure_logging(logger_name, filename=None): found = True break if not found: - raise ValueError('Logging configuration file "{}" not found in local or default directory'.format(filename)) + raise ValueError(f'Logging configuration file "{filename}" not found in local or default directory') elif not path.exists(filename): - raise ValueError('Logging configuration file "{}" not found'.format(filename)) + raise ValueError(f'Logging configuration file "{filename}" not found') if filename is not None: global _current_logging_configuration_file diff --git a/splunklib/searchcommands/eventing_command.py b/splunklib/searchcommands/eventing_command.py index 27dc13a3a..ab27d32e1 100644 --- a/splunklib/searchcommands/eventing_command.py +++ b/splunklib/searchcommands/eventing_command.py @@ -14,10 +14,7 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals -from splunklib import six -from splunklib.six.moves import map as imap from .decorators import ConfigurationSetting from .search_command import SearchCommand @@ -140,10 +137,10 @@ def fix_up(cls, command): # N.B.: Does not use Python 2 dict copy semantics def iteritems(self): iteritems = SearchCommand.ConfigurationSettings.iteritems(self) - return imap(lambda name_value: (name_value[0], 'events' if name_value[0] == 'type' else name_value[1]), iteritems) + return [(name_value[0], 'events' if name_value[0] == 'type' else name_value[1]) for name_value in iteritems] # N.B.: Does not use Python 3 dict view semantics - if not six.PY2: - items = iteritems + + items = iteritems # endregion diff --git a/splunklib/searchcommands/external_search_command.py b/splunklib/searchcommands/external_search_command.py index c2306241e..18fc2643e 100644 --- a/splunklib/searchcommands/external_search_command.py +++ b/splunklib/searchcommands/external_search_command.py @@ -14,34 +14,31 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals - from logging import getLogger import os import sys import traceback -from splunklib import six +from . import splunklib_logger as logger + if sys.platform == 'win32': from signal import signal, CTRL_BREAK_EVENT, SIGBREAK, SIGINT, SIGTERM from subprocess import Popen import atexit -from . import splunklib_logger as logger + # P1 [ ] TODO: Add ExternalSearchCommand class documentation -class ExternalSearchCommand(object): - """ - """ +class ExternalSearchCommand: def __init__(self, path, argv=None, environ=None): - if not isinstance(path, (bytes, six.text_type)): - raise ValueError('Expected a string value for path, not {}'.format(repr(path))) + if not isinstance(path, (bytes,str)): + raise ValueError(f'Expected a string value for path, not {repr(path)}') self._logger = getLogger(self.__class__.__name__) - self._path = six.text_type(path) + self._path = str(path) self._argv = None self._environ = None @@ -57,7 +54,7 @@ def argv(self): @argv.setter def argv(self, value): if not (value is None or isinstance(value, (list, tuple))): - raise ValueError('Expected a list, tuple or value of None for argv, not {}'.format(repr(value))) + raise ValueError(f'Expected a list, tuple or value of None for argv, not {repr(value)}') self._argv = value @property @@ -67,7 +64,7 @@ def environ(self): @environ.setter def environ(self, value): if not (value is None or isinstance(value, dict)): - raise ValueError('Expected a dictionary value for environ, not {}'.format(repr(value))) + raise ValueError(f'Expected a dictionary value for environ, not {repr(value)}') self._environ = value @property @@ -90,7 +87,7 @@ def execute(self): self._execute(self._path, self._argv, self._environ) except: error_type, error, tb = sys.exc_info() - message = 'Command execution failed: ' + six.text_type(error) + message = 'Command execution failed: ' + str(error) self._logger.error(message + '\nTraceback:\n' + ''.join(traceback.format_tb(tb))) sys.exit(1) @@ -120,13 +117,13 @@ def _execute(path, argv=None, environ=None): found = ExternalSearchCommand._search_path(path, search_path) if found is None: - raise ValueError('Cannot find command on path: {}'.format(path)) + raise ValueError(f'Cannot find command on path: {path}') path = found - logger.debug('starting command="%s", arguments=%s', path, argv) + logger.debug(f'starting command="{path}", arguments={path}') - def terminate(signal_number, frame): - sys.exit('External search command is terminating on receipt of signal={}.'.format(signal_number)) + def terminate(signal_number): + sys.exit(f'External search command is terminating on receipt of signal={signal_number}.') def terminate_child(): if p.pid is not None and p.returncode is None: @@ -206,7 +203,6 @@ def _execute(path, argv, environ): os.execvp(path, argv) else: os.execvpe(path, argv, environ) - return # endregion diff --git a/splunklib/searchcommands/generating_command.py b/splunklib/searchcommands/generating_command.py index 6a75d2c27..139935b88 100644 --- a/splunklib/searchcommands/generating_command.py +++ b/splunklib/searchcommands/generating_command.py @@ -14,14 +14,11 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals import sys from .decorators import ConfigurationSetting from .search_command import SearchCommand -from splunklib import six -from splunklib.six.moves import map as imap, filter as ifilter # P1 [O] TODO: Discuss generates_timeorder in the class-level documentation for GeneratingCommand @@ -254,8 +251,7 @@ def process(self, argv=sys.argv, ifile=sys.stdin, ofile=sys.stdout, allow_empty_ if not allow_empty_input: raise ValueError("allow_empty_input cannot be False for Generating Commands") - else: - return super(GeneratingCommand, self).process(argv=argv, ifile=ifile, ofile=ofile, allow_empty_input=True) + return super().process(argv=argv, ifile=ifile, ofile=ofile, allow_empty_input=True) # endregion @@ -370,18 +366,14 @@ def iteritems(self): iteritems = SearchCommand.ConfigurationSettings.iteritems(self) version = self.command.protocol_version if version == 2: - iteritems = ifilter(lambda name_value1: name_value1[0] != 'distributed', iteritems) + iteritems = [name_value1 for name_value1 in iteritems if name_value1[0] != 'distributed'] if not self.distributed and self.type == 'streaming': - iteritems = imap( - lambda name_value: (name_value[0], 'stateful') if name_value[0] == 'type' else (name_value[0], name_value[1]), iteritems) + iteritems = [(name_value[0], 'stateful') if name_value[0] == 'type' else (name_value[0], name_value[1]) for name_value in iteritems] return iteritems # N.B.: Does not use Python 3 dict view semantics - if not six.PY2: - items = iteritems + items = iteritems - pass # endregion - pass # endregion diff --git a/splunklib/searchcommands/internals.py b/splunklib/searchcommands/internals.py index 1ea2833db..6bbec4b39 100644 --- a/splunklib/searchcommands/internals.py +++ b/splunklib/searchcommands/internals.py @@ -14,25 +14,22 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function - -from io import TextIOWrapper -from collections import deque, namedtuple -from splunklib import six -from collections import OrderedDict -from splunklib.six.moves import StringIO -from itertools import chain -from splunklib.six.moves import map as imap -from json import JSONDecoder, JSONEncoder -from json.encoder import encode_basestring_ascii as json_encode_string -from splunklib.six.moves import urllib - import csv import gzip import os import re import sys import warnings +import urllib.parse +from io import TextIOWrapper, StringIO +from collections import deque, namedtuple +from collections import OrderedDict +from itertools import chain +from json import JSONDecoder, JSONEncoder +from json.encoder import encode_basestring_ascii as json_encode_string + + + from . import environment @@ -54,7 +51,7 @@ def set_binary_mode(fh): if sys.version_info >= (3, 0) and hasattr(fh, 'buffer'): return fh.buffer # check for python3 - elif sys.version_info >= (3, 0): + if sys.version_info >= (3, 0): pass # check for windows python2. SPL-175233 -- python3 stdout is already binary elif sys.platform == 'win32': @@ -65,13 +62,12 @@ def set_binary_mode(fh): implementation = python_implementation() if implementation == 'PyPy': return os.fdopen(fh.fileno(), 'wb', 0) - else: - import msvcrt - msvcrt.setmode(fh.fileno(), os.O_BINARY) + import msvcrt + msvcrt.setmode(fh.fileno(), os.O_BINARY) return fh -class CommandLineParser(object): +class CommandLineParser: r""" Parses the arguments to a search command. A search command line is described by the following syntax. @@ -144,7 +140,7 @@ def parse(cls, command, argv): command_args = cls._arguments_re.match(argv) if command_args is None: - raise SyntaxError('Syntax error: {}'.format(argv)) + raise SyntaxError(f'Syntax error: {argv}') # Parse options @@ -152,7 +148,7 @@ def parse(cls, command, argv): name, value = option.group('name'), option.group('value') if name not in command.options: raise ValueError( - 'Unrecognized {} command option: {}={}'.format(command.name, name, json_encode_string(value))) + f'Unrecognized {command.name} command option: {name}={json_encode_string(value)}') command.options[name].value = cls.unquote(value) missing = command.options.get_missing() @@ -160,8 +156,8 @@ def parse(cls, command, argv): if missing is not None: if len(missing) > 1: raise ValueError( - 'Values for these {} command options are required: {}'.format(command.name, ', '.join(missing))) - raise ValueError('A value for {} command option {} is required'.format(command.name, missing[0])) + f'Values for these {command.name} command options are required: {", ".join(missing)}') + raise ValueError(f'A value for {command.name} command option {missing[0]} is required') # Parse field names @@ -277,10 +273,10 @@ def validate_configuration_setting(specification, name, value): if isinstance(specification.type, type): type_names = specification.type.__name__ else: - type_names = ', '.join(imap(lambda t: t.__name__, specification.type)) - raise ValueError('Expected {} value, not {}={}'.format(type_names, name, repr(value))) + type_names = ', '.join(map(lambda t: t.__name__, specification.type)) + raise ValueError(f'Expected {type_names} value, not {name}={repr(value)}') if specification.constraint and not specification.constraint(value): - raise ValueError('Illegal value: {}={}'.format(name, repr(value))) + raise ValueError(f'Illegal value: {name}={ repr(value)}') return value specification = namedtuple( @@ -314,7 +310,7 @@ def validate_configuration_setting(specification, name, value): supporting_protocols=[1]), 'maxinputs': specification( type=int, - constraint=lambda value: 0 <= value <= six.MAXSIZE, + constraint=lambda value: 0 <= value <= sys.maxsize, supporting_protocols=[2]), 'overrides_timeorder': specification( type=bool, @@ -341,11 +337,11 @@ def validate_configuration_setting(specification, name, value): constraint=None, supporting_protocols=[1]), 'streaming_preop': specification( - type=(bytes, six.text_type), + type=(bytes, str), constraint=None, supporting_protocols=[1, 2]), 'type': specification( - type=(bytes, six.text_type), + type=(bytes, str), constraint=lambda value: value in ('events', 'reporting', 'streaming'), supporting_protocols=[2])} @@ -368,7 +364,7 @@ class InputHeader(dict): """ def __str__(self): - return '\n'.join([name + ':' + value for name, value in six.iteritems(self)]) + return '\n'.join([name + ':' + value for name, value in self.items()]) def read(self, ifile): """ Reads an input header from an input file. @@ -416,7 +412,7 @@ def _object_hook(dictionary): while len(stack): instance, member_name, dictionary = stack.popleft() - for name, value in six.iteritems(dictionary): + for name, value in dictionary.items(): if isinstance(value, dict): stack.append((dictionary, name, value)) @@ -437,11 +433,14 @@ def default(self, o): _separators = (',', ':') -class ObjectView(object): +class ObjectView: def __init__(self, dictionary): self.__dict__ = dictionary + def update(self, obj): + self.__dict__.update(obj.__dict__) + def __repr__(self): return repr(self.__dict__) @@ -449,7 +448,7 @@ def __str__(self): return str(self.__dict__) -class Recorder(object): +class Recorder: def __init__(self, path, f): self._recording = gzip.open(path + '.gz', 'wb') @@ -487,7 +486,7 @@ def write(self, text): self._recording.flush() -class RecordWriter(object): +class RecordWriter: def __init__(self, ofile, maxresultrows=None): self._maxresultrows = 50000 if maxresultrows is None else maxresultrows @@ -513,7 +512,7 @@ def is_flushed(self): @is_flushed.setter def is_flushed(self, value): - self._flushed = True if value else False + self._flushed = bool(value) @property def ofile(self): @@ -593,7 +592,7 @@ def _write_record(self, record): if fieldnames is None: self._fieldnames = fieldnames = list(record.keys()) self._fieldnames.extend([i for i in self.custom_fields if i not in self._fieldnames]) - value_list = imap(lambda fn: (str(fn), str('__mv_') + str(fn)), fieldnames) + value_list = map(lambda fn: (str(fn), str('__mv_') + str(fn)), fieldnames) self._writerow(list(chain.from_iterable(value_list))) get_value = record.get @@ -632,9 +631,9 @@ def _write_record(self, record): if value_t is bool: value = str(value.real) - elif value_t is six.text_type: - value = value - elif isinstance(value, six.integer_types) or value_t is float or value_t is complex: + elif value_t is str: + value = str(value) + elif isinstance(value, int) or value_t is float or value_t is complex: value = str(value) elif issubclass(value_t, (dict, list, tuple)): value = str(''.join(RecordWriter._iterencode_json(value, 0))) @@ -658,13 +657,11 @@ def _write_record(self, record): values += (value, None) continue - if value_t is six.text_type: - if six.PY2: - value = value.encode('utf-8') + if value_t is str: values += (value, None) continue - if isinstance(value, six.integer_types) or value_t is float or value_t is complex: + if isinstance(value, int) or value_t is float or value_t is complex: values += (str(value), None) continue @@ -799,16 +796,17 @@ def write_chunk(self, finished=None): if len(inspector) == 0: inspector = None - metadata = [item for item in (('inspector', inspector), ('finished', finished))] + metadata = [('inspector', inspector), ('finished', finished)] self._write_chunk(metadata, self._buffer.getvalue()) self._clear() def write_metadata(self, configuration): self._ensure_validity() - metadata = chain(six.iteritems(configuration), (('inspector', self._inspector if self._inspector else None),)) + metadata = chain(configuration.items(), (('inspector', self._inspector if self._inspector else None),)) self._write_chunk(metadata, '') - self.write('\n') + # Removed additional new line + # self.write('\n') self._clear() def write_metric(self, name, value): @@ -816,13 +814,13 @@ def write_metric(self, name, value): self._inspector['metric.' + name] = value def _clear(self): - super(RecordWriterV2, self)._clear() + super()._clear() self._fieldnames = None def _write_chunk(self, metadata, body): if metadata: - metadata = str(''.join(self._iterencode_json(dict([(n, v) for n, v in metadata if v is not None]), 0))) + metadata = str(''.join(self._iterencode_json(dict((n, v) for n, v in metadata if v is not None), 0))) if sys.version_info >= (3, 0): metadata = metadata.encode('utf-8') metadata_length = len(metadata) @@ -836,7 +834,7 @@ def _write_chunk(self, metadata, body): if not (metadata_length > 0 or body_length > 0): return - start_line = 'chunked 1.0,%s,%s\n' % (metadata_length, body_length) + start_line = f'chunked 1.0,{metadata_length},{body_length}\n' self.write(start_line) self.write(metadata) self.write(body) diff --git a/splunklib/searchcommands/reporting_command.py b/splunklib/searchcommands/reporting_command.py index 947086197..3551f4cd3 100644 --- a/splunklib/searchcommands/reporting_command.py +++ b/splunklib/searchcommands/reporting_command.py @@ -14,8 +14,6 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals - from itertools import chain from .internals import ConfigurationSettingsType, json_encode_string @@ -23,7 +21,6 @@ from .streaming_command import StreamingCommand from .search_command import SearchCommand from .validators import Set -from splunklib import six class ReportingCommand(SearchCommand): @@ -94,7 +91,7 @@ def prepare(self): self._configuration.streaming_preop = ' '.join(streaming_preop) return - raise RuntimeError('Unrecognized reporting command phase: {}'.format(json_encode_string(six.text_type(phase)))) + raise RuntimeError(f'Unrecognized reporting command phase: {json_encode_string(str(phase))}') def reduce(self, records): """ Override this method to produce a reporting data structure. @@ -244,7 +241,7 @@ def fix_up(cls, command): """ if not issubclass(command, ReportingCommand): - raise TypeError('{} is not a ReportingCommand'.format( command)) + raise TypeError(f'{command} is not a ReportingCommand') if command.reduce == ReportingCommand.reduce: raise AttributeError('No ReportingCommand.reduce override') @@ -274,8 +271,7 @@ def fix_up(cls, command): ConfigurationSetting.fix_up(f.ConfigurationSettings, settings) del f._settings - pass + # endregion - pass # endregion diff --git a/splunklib/searchcommands/search_command.py b/splunklib/searchcommands/search_command.py index dd11391d6..084ebb4b1 100644 --- a/splunklib/searchcommands/search_command.py +++ b/splunklib/searchcommands/search_command.py @@ -14,44 +14,30 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals - # Absolute imports -from collections import namedtuple - +import csv import io - -from collections import OrderedDict +import os +import re +import sys +import tempfile +import traceback +from collections import namedtuple, OrderedDict from copy import deepcopy -from splunklib.six.moves import StringIO +from io import StringIO from itertools import chain, islice -from splunklib.six.moves import filter as ifilter, map as imap, zip as izip -from splunklib import six -if six.PY2: - from logging import _levelNames, getLevelName, getLogger -else: - from logging import _nameToLevel as _levelNames, getLevelName, getLogger -try: - from shutil import make_archive -except ImportError: - # Used for recording, skip on python 2.6 - pass +from logging import _nameToLevel as _levelNames, getLevelName, getLogger +from shutil import make_archive from time import time -from splunklib.six.moves.urllib.parse import unquote -from splunklib.six.moves.urllib.parse import urlsplit +from urllib.parse import unquote +from urllib.parse import urlsplit from warnings import warn from xml.etree import ElementTree -import os -import sys -import re -import csv -import tempfile -import traceback - # Relative imports - +import splunklib +from . import Boolean, Option, environment from .internals import ( CommandLineParser, CsvDialect, @@ -64,8 +50,6 @@ RecordWriterV1, RecordWriterV2, json_encode_string) - -from . import Boolean, Option, environment from ..client import Service @@ -91,7 +75,7 @@ # P2 [ ] TODO: Consider bumping None formatting up to Option.Item.__str__ -class SearchCommand(object): +class SearchCommand: """ Represents a custom search command. """ @@ -158,16 +142,16 @@ def logging_level(self): def logging_level(self, value): if value is None: value = self._default_logging_level - if isinstance(value, (bytes, six.text_type)): + if isinstance(value, (bytes, str)): try: level = _levelNames[value.upper()] except KeyError: - raise ValueError('Unrecognized logging level: {}'.format(value)) + raise ValueError(f'Unrecognized logging level: {value}') else: try: level = int(value) except ValueError: - raise ValueError('Unrecognized logging level: {}'.format(value)) + raise ValueError(f'Unrecognized logging level: {value}') self._logger.setLevel(level) def add_field(self, current_record, field_name, field_value): @@ -291,7 +275,7 @@ def search_results_info(self): values = next(reader) except IOError as error: if error.errno == 2: - self.logger.error('Search results info file {} does not exist.'.format(json_encode_string(path))) + self.logger.error(f'Search results info file {json_encode_string(path)} does not exist.') return raise @@ -306,7 +290,7 @@ def convert_value(value): except ValueError: return value - info = ObjectView(dict(imap(lambda f_v: (convert_field(f_v[0]), convert_value(f_v[1])), izip(fields, values)))) + info = ObjectView(dict((convert_field(f_v[0]), convert_value(f_v[1])) for f_v in zip(fields, values))) try: count_map = info.countMap @@ -315,7 +299,7 @@ def convert_value(value): else: count_map = count_map.split(';') n = len(count_map) - info.countMap = dict(izip(islice(count_map, 0, n, 2), islice(count_map, 1, n, 2))) + info.countMap = dict(list(zip(islice(count_map, 0, n, 2), islice(count_map, 1, n, 2)))) try: msg_type = info.msgType @@ -323,7 +307,7 @@ def convert_value(value): except AttributeError: pass else: - messages = ifilter(lambda t_m: t_m[0] or t_m[1], izip(msg_type.split('\n'), msg_text.split('\n'))) + messages = [t_m for t_m in zip(msg_type.split('\n'), msg_text.split('\n')) if t_m[0] or t_m[1]] info.msg = [Message(message) for message in messages] del info.msgType @@ -417,7 +401,6 @@ def prepare(self): :rtype: NoneType """ - pass def process(self, argv=sys.argv, ifile=sys.stdin, ofile=sys.stdout, allow_empty_input=True): """ Process data. @@ -466,7 +449,7 @@ def _map_metadata(self, argv): def _map(metadata_map): metadata = {} - for name, value in six.iteritems(metadata_map): + for name, value in list(metadata_map.items()): if isinstance(value, dict): value = _map(value) else: @@ -485,7 +468,8 @@ def _map(metadata_map): _metadata_map = { 'action': - (lambda v: 'getinfo' if v == '__GETINFO__' else 'execute' if v == '__EXECUTE__' else None, lambda s: s.argv[1]), + (lambda v: 'getinfo' if v == '__GETINFO__' else 'execute' if v == '__EXECUTE__' else None, + lambda s: s.argv[1]), 'preview': (bool, lambda s: s.input_header.get('preview')), 'searchinfo': { @@ -533,7 +517,7 @@ def _prepare_protocol_v1(self, argv, ifile, ofile): try: tempfile.tempdir = self._metadata.searchinfo.dispatch_dir except AttributeError: - raise RuntimeError('{}.metadata.searchinfo.dispatch_dir is undefined'.format(self.__class__.__name__)) + raise RuntimeError(f'{self.__class__.__name__}.metadata.searchinfo.dispatch_dir is undefined') debug(' tempfile.tempdir=%r', tempfile.tempdir) @@ -603,7 +587,8 @@ def _process_protocol_v1(self, argv, ifile, ofile): ifile = self._prepare_protocol_v1(argv, ifile, ofile) self._record_writer.write_record(dict( - (n, ','.join(v) if isinstance(v, (list, tuple)) else v) for n, v in six.iteritems(self._configuration))) + (n, ','.join(v) if isinstance(v, (list, tuple)) else v) for n, v in + list(self._configuration.items()))) self.finish() elif argv[1] == '__EXECUTE__': @@ -617,21 +602,21 @@ def _process_protocol_v1(self, argv, ifile, ofile): else: message = ( - 'Command {0} appears to be statically configured for search command protocol version 1 and static ' + f'Command {self.name} appears to be statically configured for search command protocol version 1 and static ' 'configuration is unsupported by splunklib.searchcommands. Please ensure that ' 'default/commands.conf contains this stanza:\n' - '[{0}]\n' - 'filename = {1}\n' + f'[{self.name}]\n' + f'filename = {os.path.basename(argv[0])}\n' 'enableheader = true\n' 'outputheader = true\n' 'requires_srinfo = true\n' 'supports_getinfo = true\n' 'supports_multivalues = true\n' - 'supports_rawargs = true'.format(self.name, os.path.basename(argv[0]))) + 'supports_rawargs = true') raise RuntimeError(message) except (SyntaxError, ValueError) as error: - self.write_error(six.text_type(error)) + self.write_error(str(error)) self.flush() exit(0) @@ -686,7 +671,7 @@ def _process_protocol_v2(self, argv, ifile, ofile): action = getattr(metadata, 'action', None) if action != 'getinfo': - raise RuntimeError('Expected getinfo action, not {}'.format(action)) + raise RuntimeError(f'Expected getinfo action, not {action}') if len(body) > 0: raise RuntimeError('Did not expect data for getinfo action') @@ -706,7 +691,7 @@ def _process_protocol_v2(self, argv, ifile, ofile): try: tempfile.tempdir = self._metadata.searchinfo.dispatch_dir except AttributeError: - raise RuntimeError('%s.metadata.searchinfo.dispatch_dir is undefined'.format(class_name)) + raise RuntimeError(f'{class_name}.metadata.searchinfo.dispatch_dir is undefined') debug(' tempfile.tempdir=%r', tempfile.tempdir) except: @@ -727,7 +712,7 @@ def _process_protocol_v2(self, argv, ifile, ofile): debug('Parsing arguments') - if args and type(args) == list: + if args and isinstance(args, list): for arg in args: result = self._protocol_v2_option_parser(arg) if len(result) == 1: @@ -738,13 +723,13 @@ def _process_protocol_v2(self, argv, ifile, ofile): try: option = self.options[name] except KeyError: - self.write_error('Unrecognized option: {}={}'.format(name, value)) + self.write_error(f'Unrecognized option: {name}={value}') error_count += 1 continue try: option.value = value except ValueError: - self.write_error('Illegal value: {}={}'.format(name, value)) + self.write_error(f'Illegal value: {name}={value}') error_count += 1 continue @@ -752,15 +737,15 @@ def _process_protocol_v2(self, argv, ifile, ofile): if missing is not None: if len(missing) == 1: - self.write_error('A value for "{}" is required'.format(missing[0])) + self.write_error(f'A value for "{missing[0]}" is required'.format()) else: - self.write_error('Values for these required options are missing: {}'.format(', '.join(missing))) + self.write_error(f'Values for these required options are missing: {", ".join(missing)}') error_count += 1 if error_count > 0: exit(1) - debug(' command: %s', six.text_type(self)) + debug(' command: %s', str(self)) debug('Preparing for execution') self.prepare() @@ -778,7 +763,7 @@ def _process_protocol_v2(self, argv, ifile, ofile): setattr(info, attr, [arg for arg in getattr(info, attr) if not arg.startswith('record=')]) metadata = MetadataEncoder().encode(self._metadata) - ifile.record('chunked 1.0,', six.text_type(len(metadata)), ',0\n', metadata) + ifile.record('chunked 1.0,', str(len(metadata)), ',0\n', metadata) if self.show_configuration: self.write_info(self.name + ' command configuration: ' + str(self._configuration)) @@ -888,25 +873,25 @@ def _as_binary_stream(ifile): try: return ifile.buffer except AttributeError as error: - raise RuntimeError('Failed to get underlying buffer: {}'.format(error)) + raise RuntimeError(f'Failed to get underlying buffer: {error}') @staticmethod def _read_chunk(istream): # noinspection PyBroadException - assert isinstance(istream.read(0), six.binary_type), 'Stream must be binary' + assert isinstance(istream.read(0), bytes), 'Stream must be binary' try: header = istream.readline() except Exception as error: - raise RuntimeError('Failed to read transport header: {}'.format(error)) + raise RuntimeError(f'Failed to read transport header: {error}') if not header: return None - match = SearchCommand._header.match(six.ensure_str(header)) + match = SearchCommand._header.match(splunklib.ensure_str(header)) if match is None: - raise RuntimeError('Failed to parse transport header: {}'.format(header)) + raise RuntimeError(f'Failed to parse transport header: {header}') metadata_length, body_length = match.groups() metadata_length = int(metadata_length) @@ -915,14 +900,14 @@ def _read_chunk(istream): try: metadata = istream.read(metadata_length) except Exception as error: - raise RuntimeError('Failed to read metadata of length {}: {}'.format(metadata_length, error)) + raise RuntimeError(f'Failed to read metadata of length {metadata_length}: {error}') decoder = MetadataDecoder() try: - metadata = decoder.decode(six.ensure_str(metadata)) + metadata = decoder.decode(splunklib.ensure_str(metadata)) except Exception as error: - raise RuntimeError('Failed to parse metadata of length {}: {}'.format(metadata_length, error)) + raise RuntimeError(f'Failed to parse metadata of length {metadata_length}: {error}') # if body_length <= 0: # return metadata, '' @@ -932,9 +917,9 @@ def _read_chunk(istream): if body_length > 0: body = istream.read(body_length) except Exception as error: - raise RuntimeError('Failed to read body of length {}: {}'.format(body_length, error)) + raise RuntimeError(f'Failed to read body of length {body_length}: {error}') - return metadata, six.ensure_str(body) + return metadata, splunklib.ensure_str(body) _header = re.compile(r'chunked\s+1.0\s*,\s*(\d+)\s*,\s*(\d+)\s*\n') @@ -949,16 +934,16 @@ def _read_csv_records(self, ifile): except StopIteration: return - mv_fieldnames = dict([(name, name[len('__mv_'):]) for name in fieldnames if name.startswith('__mv_')]) + mv_fieldnames = dict((name, name[len('__mv_'):]) for name in fieldnames if name.startswith('__mv_')) if len(mv_fieldnames) == 0: for values in reader: - yield OrderedDict(izip(fieldnames, values)) + yield OrderedDict(list(zip(fieldnames, values))) return for values in reader: record = OrderedDict() - for fieldname, value in izip(fieldnames, values): + for fieldname, value in zip(fieldnames, values): if fieldname.startswith('__mv_'): if len(value) > 0: record[mv_fieldnames[fieldname]] = self._decode_list(value) @@ -978,25 +963,27 @@ def _execute_v2(self, ifile, process): metadata, body = result action = getattr(metadata, 'action', None) if action != 'execute': - raise RuntimeError('Expected execute action, not {}'.format(action)) + raise RuntimeError(f'Expected execute action, not {action}') self._finished = getattr(metadata, 'finished', False) self._record_writer.is_flushed = False - + # metadata.update(self._metadata) + # self._metadata = metadata + self._metadata.update(metadata) self._execute_chunk_v2(process, result) self._record_writer.write_chunk(finished=self._finished) def _execute_chunk_v2(self, process, chunk): - metadata, body = chunk + metadata, body = chunk - if len(body) <= 0 and not self._allow_empty_input: - raise ValueError( - "No records found to process. Set allow_empty_input=True in dispatch function to move forward " - "with empty records.") + if len(body) <= 0 and not self._allow_empty_input: + raise ValueError( + "No records found to process. Set allow_empty_input=True in dispatch function to move forward " + "with empty records.") - records = self._read_csv_records(StringIO(body)) - self._record_writer.write_records(process(records)) + records = self._read_csv_records(StringIO(body)) + self._record_writer.write_records(process(records)) def _report_unexpected_error(self): @@ -1008,7 +995,7 @@ def _report_unexpected_error(self): filename = origin.tb_frame.f_code.co_filename lineno = origin.tb_lineno - message = '{0} at "{1}", line {2:d} : {3}'.format(error_type.__name__, filename, lineno, error) + message = f'{error_type.__name__} at "{filename}", line {str(lineno)} : {error}' environment.splunklib_logger.error(message + '\nTraceback:\n' + ''.join(traceback.format_tb(tb))) self.write_error(message) @@ -1017,10 +1004,11 @@ def _report_unexpected_error(self): # region Types - class ConfigurationSettings(object): + class ConfigurationSettings: """ Represents the configuration settings common to all :class:`SearchCommand` classes. """ + def __init__(self, command): self.command = command @@ -1034,8 +1022,8 @@ def __repr__(self): """ definitions = type(self).configuration_setting_definitions - settings = imap( - lambda setting: repr((setting.name, setting.__get__(self), setting.supporting_protocols)), definitions) + settings = [repr((setting.name, setting.__get__(self), setting.supporting_protocols)) for setting in + definitions] return '[' + ', '.join(settings) + ']' def __str__(self): @@ -1047,8 +1035,8 @@ def __str__(self): :return: String representation of this instance """ - #text = ', '.join(imap(lambda (name, value): name + '=' + json_encode_string(unicode(value)), self.iteritems())) - text = ', '.join(['{}={}'.format(name, json_encode_string(six.text_type(value))) for (name, value) in six.iteritems(self)]) + # text = ', '.join(imap(lambda (name, value): name + '=' + json_encode_string(unicode(value)), self.iteritems())) + text = ', '.join([f'{name}={json_encode_string(str(value))}' for (name, value) in list(self.items())]) return text # region Methods @@ -1072,24 +1060,25 @@ def fix_up(cls, command_class): def iteritems(self): definitions = type(self).configuration_setting_definitions version = self.command.protocol_version - return ifilter( - lambda name_value1: name_value1[1] is not None, imap( - lambda setting: (setting.name, setting.__get__(self)), ifilter( - lambda setting: setting.is_supported_by_protocol(version), definitions))) + return [name_value1 for name_value1 in [(setting.name, setting.__get__(self)) for setting in + [setting for setting in definitions if + setting.is_supported_by_protocol(version)]] if + name_value1[1] is not None] # N.B.: Does not use Python 3 dict view semantics - if not six.PY2: - items = iteritems - pass # endregion + items = iteritems + + # endregion - pass # endregion + # endregion SearchMetric = namedtuple('SearchMetric', ('elapsed_seconds', 'invocation_count', 'input_count', 'output_count')) -def dispatch(command_class, argv=sys.argv, input_file=sys.stdin, output_file=sys.stdout, module_name=None, allow_empty_input=True): +def dispatch(command_class, argv=sys.argv, input_file=sys.stdin, output_file=sys.stdout, module_name=None, + allow_empty_input=True): """ Instantiates and executes a search command class This function implements a `conditional script stanza `_ based on the value of diff --git a/splunklib/searchcommands/streaming_command.py b/splunklib/searchcommands/streaming_command.py index fa075edb1..b3eb43756 100644 --- a/splunklib/searchcommands/streaming_command.py +++ b/splunklib/searchcommands/streaming_command.py @@ -14,10 +14,6 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals - -from splunklib import six -from splunklib.six.moves import map as imap, filter as ifilter from .decorators import ConfigurationSetting from .search_command import SearchCommand @@ -171,7 +167,6 @@ def fix_up(cls, command): """ if command.stream == StreamingCommand.stream: raise AttributeError('No StreamingCommand.stream override') - return # TODO: Stop looking like a dictionary because we don't obey the semantics # N.B.: Does not use Python 2 dict copy semantics @@ -180,16 +175,14 @@ def iteritems(self): version = self.command.protocol_version if version == 1: if self.required_fields is None: - iteritems = ifilter(lambda name_value: name_value[0] != 'clear_required_fields', iteritems) + iteritems = [name_value for name_value in iteritems if name_value[0] != 'clear_required_fields'] else: - iteritems = ifilter(lambda name_value2: name_value2[0] != 'distributed', iteritems) + iteritems = [name_value2 for name_value2 in iteritems if name_value2[0] != 'distributed'] if not self.distributed: - iteritems = imap( - lambda name_value1: (name_value1[0], 'stateful') if name_value1[0] == 'type' else (name_value1[0], name_value1[1]), iteritems) + iteritems = [(name_value1[0], 'stateful') if name_value1[0] == 'type' else (name_value1[0], name_value1[1]) for name_value1 in iteritems] return iteritems # N.B.: Does not use Python 3 dict view semantics - if not six.PY2: - items = iteritems + items = iteritems # endregion diff --git a/splunklib/searchcommands/validators.py b/splunklib/searchcommands/validators.py index 22f0e16b2..ef460a4b1 100644 --- a/splunklib/searchcommands/validators.py +++ b/splunklib/searchcommands/validators.py @@ -14,20 +14,17 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals - -from json.encoder import encode_basestring_ascii as json_encode_string -from collections import namedtuple -from splunklib.six.moves import StringIO -from io import open import csv import os import re -from splunklib import six -from splunklib.six.moves import getcwd +from io import open, StringIO +from os import getcwd +from json.encoder import encode_basestring_ascii as json_encode_string +from collections import namedtuple -class Validator(object): + +class Validator: """ Base class for validators that check and format search command options. You must inherit from this class and override :code:`Validator.__call__` and @@ -60,14 +57,16 @@ class Boolean(Validator): def __call__(self, value): if not (value is None or isinstance(value, bool)): - value = six.text_type(value).lower() + value = str(value).lower() if value not in Boolean.truth_values: - raise ValueError('Unrecognized truth value: {0}'.format(value)) + raise ValueError(f'Unrecognized truth value: {value}') value = Boolean.truth_values[value] return value def format(self, value): - return None if value is None else 't' if value else 'f' + if value is None: + return None + return 't' if value else 'f' class Code(Validator): @@ -93,11 +92,11 @@ def __call__(self, value): if value is None: return None try: - return Code.object(compile(value, 'string', self._mode), six.text_type(value)) + return Code.object(compile(value, 'string', self._mode), str(value)) except (SyntaxError, TypeError) as error: message = str(error) - six.raise_from(ValueError(message), error) + raise ValueError(message) from error def format(self, value): return None if value is None else value.source @@ -113,9 +112,9 @@ class Fieldname(Validator): def __call__(self, value): if value is not None: - value = six.text_type(value) + value = str(value) if Fieldname.pattern.match(value) is None: - raise ValueError('Illegal characters in fieldname: {}'.format(value)) + raise ValueError(f'Illegal characters in fieldname: {value}') return value def format(self, value): @@ -136,7 +135,7 @@ def __call__(self, value): if value is None: return value - path = six.text_type(value) + path = str(value) if not os.path.isabs(path): path = os.path.join(self.directory, path) @@ -144,8 +143,7 @@ def __call__(self, value): try: value = open(path, self.mode) if self.buffering is None else open(path, self.mode, self.buffering) except IOError as error: - raise ValueError('Cannot open {0} with mode={1} and buffering={2}: {3}'.format( - value, self.mode, self.buffering, error)) + raise ValueError(f'Cannot open {value} with mode={self.mode} and buffering={self.buffering}: {error}') return value @@ -163,42 +161,38 @@ class Integer(Validator): def __init__(self, minimum=None, maximum=None): if minimum is not None and maximum is not None: def check_range(value): - if not (minimum <= value <= maximum): - raise ValueError('Expected integer in the range [{0},{1}], not {2}'.format(minimum, maximum, value)) - return + if not minimum <= value <= maximum: + raise ValueError(f'Expected integer in the range [{minimum},{maximum}], not {value}') + elif minimum is not None: def check_range(value): if value < minimum: - raise ValueError('Expected integer in the range [{0},+∞], not {1}'.format(minimum, value)) - return + raise ValueError(f'Expected integer in the range [{minimum},+∞], not {value}') elif maximum is not None: def check_range(value): if value > maximum: - raise ValueError('Expected integer in the range [-∞,{0}], not {1}'.format(maximum, value)) - return + raise ValueError(f'Expected integer in the range [-∞,{maximum}], not {value}') + else: def check_range(value): return self.check_range = check_range - return + def __call__(self, value): if value is None: return None try: - if six.PY2: - value = long(value) - else: - value = int(value) + value = int(value) except ValueError: - raise ValueError('Expected integer value, not {}'.format(json_encode_string(value))) + raise ValueError(f'Expected integer value, not {json_encode_string(value)}') self.check_range(value) return value def format(self, value): - return None if value is None else six.text_type(int(value)) + return None if value is None else str(int(value)) class Float(Validator): @@ -208,25 +202,21 @@ class Float(Validator): def __init__(self, minimum=None, maximum=None): if minimum is not None and maximum is not None: def check_range(value): - if not (minimum <= value <= maximum): - raise ValueError('Expected float in the range [{0},{1}], not {2}'.format(minimum, maximum, value)) - return + if not minimum <= value <= maximum: + raise ValueError(f'Expected float in the range [{minimum},{maximum}], not {value}') elif minimum is not None: def check_range(value): if value < minimum: - raise ValueError('Expected float in the range [{0},+∞], not {1}'.format(minimum, value)) - return + raise ValueError(f'Expected float in the range [{minimum},+∞], not {value}') elif maximum is not None: def check_range(value): if value > maximum: - raise ValueError('Expected float in the range [-∞,{0}], not {1}'.format(maximum, value)) - return + raise ValueError(f'Expected float in the range [-∞,{maximum}], not {value}') else: def check_range(value): return - self.check_range = check_range - return + def __call__(self, value): if value is None: @@ -234,13 +224,13 @@ def __call__(self, value): try: value = float(value) except ValueError: - raise ValueError('Expected float value, not {}'.format(json_encode_string(value))) + raise ValueError(f'Expected float value, not {json_encode_string(value)}') self.check_range(value) return value def format(self, value): - return None if value is None else six.text_type(float(value)) + return None if value is None else str(float(value)) class Duration(Validator): @@ -265,7 +255,7 @@ def __call__(self, value): if len(p) == 3: result = 3600 * _unsigned(p[0]) + 60 * _60(p[1]) + _60(p[2]) except ValueError: - raise ValueError('Invalid duration value: {0}'.format(value)) + raise ValueError(f'Invalid duration value: {value}') return result @@ -302,7 +292,7 @@ class Dialect(csv.Dialect): def __init__(self, validator=None): if not (validator is None or isinstance(validator, Validator)): - raise ValueError('Expected a Validator instance or None for validator, not {}', repr(validator)) + raise ValueError(f'Expected a Validator instance or None for validator, not {repr(validator)}') self._validator = validator def __call__(self, value): @@ -322,7 +312,7 @@ def __call__(self, value): for index, item in enumerate(value): value[index] = self._validator(item) except ValueError as error: - raise ValueError('Could not convert item {}: {}'.format(index, error)) + raise ValueError(f'Could not convert item {index}: {error}') return value @@ -346,10 +336,10 @@ def __call__(self, value): if value is None: return None - value = six.text_type(value) + value = str(value) if value not in self.membership: - raise ValueError('Unrecognized value: {0}'.format(value)) + raise ValueError(f'Unrecognized value: {value}') return self.membership[value] @@ -362,19 +352,19 @@ class Match(Validator): """ def __init__(self, name, pattern, flags=0): - self.name = six.text_type(name) + self.name = str(name) self.pattern = re.compile(pattern, flags) def __call__(self, value): if value is None: return None - value = six.text_type(value) + value = str(value) if self.pattern.match(value) is None: - raise ValueError('Expected {}, not {}'.format(self.name, json_encode_string(value))) + raise ValueError(f'Expected {self.name}, not {json_encode_string(value)}') return value def format(self, value): - return None if value is None else six.text_type(value) + return None if value is None else str(value) class OptionName(Validator): @@ -385,13 +375,13 @@ class OptionName(Validator): def __call__(self, value): if value is not None: - value = six.text_type(value) + value = str(value) if OptionName.pattern.match(value) is None: - raise ValueError('Illegal characters in option name: {}'.format(value)) + raise ValueError(f'Illegal characters in option name: {value}') return value def format(self, value): - return None if value is None else six.text_type(value) + return None if value is None else str(value) class RegularExpression(Validator): @@ -402,9 +392,9 @@ def __call__(self, value): if value is None: return None try: - value = re.compile(six.text_type(value)) + value = re.compile(str(value)) except re.error as error: - raise ValueError('{}: {}'.format(six.text_type(error).capitalize(), value)) + raise ValueError(f'{str(error).capitalize()}: {value}') return value def format(self, value): @@ -421,9 +411,9 @@ def __init__(self, *args): def __call__(self, value): if value is None: return None - value = six.text_type(value) + value = str(value) if value not in self.membership: - raise ValueError('Unrecognized value: {}'.format(value)) + raise ValueError(f'Unrecognized value: {value}') return value def format(self, value): diff --git a/splunklib/six.py b/splunklib/six.py deleted file mode 100644 index d13e50c93..000000000 --- a/splunklib/six.py +++ /dev/null @@ -1,993 +0,0 @@ -# Copyright (c) 2010-2020 Benjamin Peterson -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -"""Utilities for writing code that runs on Python 2 and 3""" - -from __future__ import absolute_import - -import functools -import itertools -import operator -import sys -import types - -__author__ = "Benjamin Peterson " -__version__ = "1.14.0" - - -# Useful for very coarse version differentiation. -PY2 = sys.version_info[0] == 2 -PY3 = sys.version_info[0] == 3 -PY34 = sys.version_info[0:2] >= (3, 4) - -if PY3: - string_types = str, - integer_types = int, - class_types = type, - text_type = str - binary_type = bytes - - MAXSIZE = sys.maxsize -else: - string_types = basestring, - integer_types = (int, long) - class_types = (type, types.ClassType) - text_type = unicode - binary_type = str - - if sys.platform.startswith("java"): - # Jython always uses 32 bits. - MAXSIZE = int((1 << 31) - 1) - else: - # It's possible to have sizeof(long) != sizeof(Py_ssize_t). - class X(object): - - def __len__(self): - return 1 << 31 - try: - len(X()) - except OverflowError: - # 32-bit - MAXSIZE = int((1 << 31) - 1) - else: - # 64-bit - MAXSIZE = int((1 << 63) - 1) - del X - - -def _add_doc(func, doc): - """Add documentation to a function.""" - func.__doc__ = doc - - -def _import_module(name): - """Import module, returning the module after the last dot.""" - __import__(name) - return sys.modules[name] - - -class _LazyDescr(object): - - def __init__(self, name): - self.name = name - - def __get__(self, obj, tp): - result = self._resolve() - setattr(obj, self.name, result) # Invokes __set__. - try: - # This is a bit ugly, but it avoids running this again by - # removing this descriptor. - delattr(obj.__class__, self.name) - except AttributeError: - pass - return result - - -class MovedModule(_LazyDescr): - - def __init__(self, name, old, new=None): - super(MovedModule, self).__init__(name) - if PY3: - if new is None: - new = name - self.mod = new - else: - self.mod = old - - def _resolve(self): - return _import_module(self.mod) - - def __getattr__(self, attr): - _module = self._resolve() - value = getattr(_module, attr) - setattr(self, attr, value) - return value - - -class _LazyModule(types.ModuleType): - - def __init__(self, name): - super(_LazyModule, self).__init__(name) - self.__doc__ = self.__class__.__doc__ - - def __dir__(self): - attrs = ["__doc__", "__name__"] - attrs += [attr.name for attr in self._moved_attributes] - return attrs - - # Subclasses should override this - _moved_attributes = [] - - -class MovedAttribute(_LazyDescr): - - def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None): - super(MovedAttribute, self).__init__(name) - if PY3: - if new_mod is None: - new_mod = name - self.mod = new_mod - if new_attr is None: - if old_attr is None: - new_attr = name - else: - new_attr = old_attr - self.attr = new_attr - else: - self.mod = old_mod - if old_attr is None: - old_attr = name - self.attr = old_attr - - def _resolve(self): - module = _import_module(self.mod) - return getattr(module, self.attr) - - -class _SixMetaPathImporter(object): - - """ - A meta path importer to import six.moves and its submodules. - - This class implements a PEP302 finder and loader. It should be compatible - with Python 2.5 and all existing versions of Python3 - """ - - def __init__(self, six_module_name): - self.name = six_module_name - self.known_modules = {} - - def _add_module(self, mod, *fullnames): - for fullname in fullnames: - self.known_modules[self.name + "." + fullname] = mod - - def _get_module(self, fullname): - return self.known_modules[self.name + "." + fullname] - - def find_module(self, fullname, path=None): - if fullname in self.known_modules: - return self - return None - - def __get_module(self, fullname): - try: - return self.known_modules[fullname] - except KeyError: - raise ImportError("This loader does not know module " + fullname) - - def load_module(self, fullname): - try: - # in case of a reload - return sys.modules[fullname] - except KeyError: - pass - mod = self.__get_module(fullname) - if isinstance(mod, MovedModule): - mod = mod._resolve() - else: - mod.__loader__ = self - sys.modules[fullname] = mod - return mod - - def is_package(self, fullname): - """ - Return true, if the named module is a package. - - We need this method to get correct spec objects with - Python 3.4 (see PEP451) - """ - return hasattr(self.__get_module(fullname), "__path__") - - def get_code(self, fullname): - """Return None - - Required, if is_package is implemented""" - self.__get_module(fullname) # eventually raises ImportError - return None - get_source = get_code # same as get_code - -_importer = _SixMetaPathImporter(__name__) - - -class _MovedItems(_LazyModule): - - """Lazy loading of moved objects""" - __path__ = [] # mark as package - - -_moved_attributes = [ - MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"), - MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"), - MovedAttribute("filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse"), - MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"), - MovedAttribute("intern", "__builtin__", "sys"), - MovedAttribute("map", "itertools", "builtins", "imap", "map"), - MovedAttribute("getcwd", "os", "os", "getcwdu", "getcwd"), - MovedAttribute("getcwdb", "os", "os", "getcwd", "getcwdb"), - MovedAttribute("getoutput", "commands", "subprocess"), - MovedAttribute("range", "__builtin__", "builtins", "xrange", "range"), - MovedAttribute("reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload"), - MovedAttribute("reduce", "__builtin__", "functools"), - MovedAttribute("shlex_quote", "pipes", "shlex", "quote"), - MovedAttribute("StringIO", "StringIO", "io"), - MovedAttribute("UserDict", "UserDict", "collections"), - MovedAttribute("UserList", "UserList", "collections"), - MovedAttribute("UserString", "UserString", "collections"), - MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"), - MovedAttribute("zip", "itertools", "builtins", "izip", "zip"), - MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"), - MovedModule("builtins", "__builtin__"), - MovedModule("configparser", "ConfigParser"), - MovedModule("collections_abc", "collections", "collections.abc" if sys.version_info >= (3, 3) else "collections"), - MovedModule("copyreg", "copy_reg"), - MovedModule("dbm_gnu", "gdbm", "dbm.gnu"), - MovedModule("dbm_ndbm", "dbm", "dbm.ndbm"), - MovedModule("_dummy_thread", "dummy_thread", "_dummy_thread" if sys.version_info < (3, 9) else "_thread"), - MovedModule("http_cookiejar", "cookielib", "http.cookiejar"), - MovedModule("http_cookies", "Cookie", "http.cookies"), - MovedModule("html_entities", "htmlentitydefs", "html.entities"), - MovedModule("html_parser", "HTMLParser", "html.parser"), - MovedModule("http_client", "httplib", "http.client"), - MovedModule("email_mime_base", "email.MIMEBase", "email.mime.base"), - MovedModule("email_mime_image", "email.MIMEImage", "email.mime.image"), - MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"), - MovedModule("email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart"), - MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"), - MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"), - MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"), - MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"), - MovedModule("cPickle", "cPickle", "pickle"), - MovedModule("queue", "Queue"), - MovedModule("reprlib", "repr"), - MovedModule("socketserver", "SocketServer"), - MovedModule("_thread", "thread", "_thread"), - MovedModule("tkinter", "Tkinter"), - MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"), - MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"), - MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"), - MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"), - MovedModule("tkinter_tix", "Tix", "tkinter.tix"), - MovedModule("tkinter_ttk", "ttk", "tkinter.ttk"), - MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"), - MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"), - MovedModule("tkinter_colorchooser", "tkColorChooser", - "tkinter.colorchooser"), - MovedModule("tkinter_commondialog", "tkCommonDialog", - "tkinter.commondialog"), - MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"), - MovedModule("tkinter_font", "tkFont", "tkinter.font"), - MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"), - MovedModule("tkinter_tksimpledialog", "tkSimpleDialog", - "tkinter.simpledialog"), - MovedModule("urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"), - MovedModule("urllib_error", __name__ + ".moves.urllib_error", "urllib.error"), - MovedModule("urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"), - MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"), - MovedModule("xmlrpc_client", "xmlrpclib", "xmlrpc.client"), - MovedModule("xmlrpc_server", "SimpleXMLRPCServer", "xmlrpc.server"), -] -# Add windows specific modules. -if sys.platform == "win32": - _moved_attributes += [ - MovedModule("winreg", "_winreg"), - ] - -for attr in _moved_attributes: - setattr(_MovedItems, attr.name, attr) - if isinstance(attr, MovedModule): - _importer._add_module(attr, "moves." + attr.name) -del attr - -_MovedItems._moved_attributes = _moved_attributes - -moves = _MovedItems(__name__ + ".moves") -_importer._add_module(moves, "moves") - - -class Module_six_moves_urllib_parse(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_parse""" - - -_urllib_parse_moved_attributes = [ - MovedAttribute("ParseResult", "urlparse", "urllib.parse"), - MovedAttribute("SplitResult", "urlparse", "urllib.parse"), - MovedAttribute("parse_qs", "urlparse", "urllib.parse"), - MovedAttribute("parse_qsl", "urlparse", "urllib.parse"), - MovedAttribute("urldefrag", "urlparse", "urllib.parse"), - MovedAttribute("urljoin", "urlparse", "urllib.parse"), - MovedAttribute("urlparse", "urlparse", "urllib.parse"), - MovedAttribute("urlsplit", "urlparse", "urllib.parse"), - MovedAttribute("urlunparse", "urlparse", "urllib.parse"), - MovedAttribute("urlunsplit", "urlparse", "urllib.parse"), - MovedAttribute("quote", "urllib", "urllib.parse"), - MovedAttribute("quote_plus", "urllib", "urllib.parse"), - MovedAttribute("unquote", "urllib", "urllib.parse"), - MovedAttribute("unquote_plus", "urllib", "urllib.parse"), - MovedAttribute("unquote_to_bytes", "urllib", "urllib.parse", "unquote", "unquote_to_bytes"), - MovedAttribute("urlencode", "urllib", "urllib.parse"), - MovedAttribute("splitquery", "urllib", "urllib.parse"), - MovedAttribute("splittag", "urllib", "urllib.parse"), - MovedAttribute("splituser", "urllib", "urllib.parse"), - MovedAttribute("splitvalue", "urllib", "urllib.parse"), - MovedAttribute("uses_fragment", "urlparse", "urllib.parse"), - MovedAttribute("uses_netloc", "urlparse", "urllib.parse"), - MovedAttribute("uses_params", "urlparse", "urllib.parse"), - MovedAttribute("uses_query", "urlparse", "urllib.parse"), - MovedAttribute("uses_relative", "urlparse", "urllib.parse"), -] -for attr in _urllib_parse_moved_attributes: - setattr(Module_six_moves_urllib_parse, attr.name, attr) -del attr - -Module_six_moves_urllib_parse._moved_attributes = _urllib_parse_moved_attributes - -_importer._add_module(Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"), - "moves.urllib_parse", "moves.urllib.parse") - - -class Module_six_moves_urllib_error(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_error""" - - -_urllib_error_moved_attributes = [ - MovedAttribute("URLError", "urllib2", "urllib.error"), - MovedAttribute("HTTPError", "urllib2", "urllib.error"), - MovedAttribute("ContentTooShortError", "urllib", "urllib.error"), -] -for attr in _urllib_error_moved_attributes: - setattr(Module_six_moves_urllib_error, attr.name, attr) -del attr - -Module_six_moves_urllib_error._moved_attributes = _urllib_error_moved_attributes - -_importer._add_module(Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"), - "moves.urllib_error", "moves.urllib.error") - - -class Module_six_moves_urllib_request(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_request""" - - -_urllib_request_moved_attributes = [ - MovedAttribute("urlopen", "urllib2", "urllib.request"), - MovedAttribute("install_opener", "urllib2", "urllib.request"), - MovedAttribute("build_opener", "urllib2", "urllib.request"), - MovedAttribute("pathname2url", "urllib", "urllib.request"), - MovedAttribute("url2pathname", "urllib", "urllib.request"), - MovedAttribute("getproxies", "urllib", "urllib.request"), - MovedAttribute("Request", "urllib2", "urllib.request"), - MovedAttribute("OpenerDirector", "urllib2", "urllib.request"), - MovedAttribute("HTTPDefaultErrorHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPRedirectHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPCookieProcessor", "urllib2", "urllib.request"), - MovedAttribute("ProxyHandler", "urllib2", "urllib.request"), - MovedAttribute("BaseHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPPasswordMgr", "urllib2", "urllib.request"), - MovedAttribute("HTTPPasswordMgrWithDefaultRealm", "urllib2", "urllib.request"), - MovedAttribute("AbstractBasicAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPBasicAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("ProxyBasicAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("AbstractDigestAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPDigestAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("ProxyDigestAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPSHandler", "urllib2", "urllib.request"), - MovedAttribute("FileHandler", "urllib2", "urllib.request"), - MovedAttribute("FTPHandler", "urllib2", "urllib.request"), - MovedAttribute("CacheFTPHandler", "urllib2", "urllib.request"), - MovedAttribute("UnknownHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPErrorProcessor", "urllib2", "urllib.request"), - MovedAttribute("urlretrieve", "urllib", "urllib.request"), - MovedAttribute("urlcleanup", "urllib", "urllib.request"), - MovedAttribute("URLopener", "urllib", "urllib.request"), - MovedAttribute("FancyURLopener", "urllib", "urllib.request"), - MovedAttribute("proxy_bypass", "urllib", "urllib.request"), - MovedAttribute("parse_http_list", "urllib2", "urllib.request"), - MovedAttribute("parse_keqv_list", "urllib2", "urllib.request"), -] -for attr in _urllib_request_moved_attributes: - setattr(Module_six_moves_urllib_request, attr.name, attr) -del attr - -Module_six_moves_urllib_request._moved_attributes = _urllib_request_moved_attributes - -_importer._add_module(Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"), - "moves.urllib_request", "moves.urllib.request") - - -class Module_six_moves_urllib_response(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_response""" - - -_urllib_response_moved_attributes = [ - MovedAttribute("addbase", "urllib", "urllib.response"), - MovedAttribute("addclosehook", "urllib", "urllib.response"), - MovedAttribute("addinfo", "urllib", "urllib.response"), - MovedAttribute("addinfourl", "urllib", "urllib.response"), -] -for attr in _urllib_response_moved_attributes: - setattr(Module_six_moves_urllib_response, attr.name, attr) -del attr - -Module_six_moves_urllib_response._moved_attributes = _urllib_response_moved_attributes - -_importer._add_module(Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"), - "moves.urllib_response", "moves.urllib.response") - - -class Module_six_moves_urllib_robotparser(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_robotparser""" - - -_urllib_robotparser_moved_attributes = [ - MovedAttribute("RobotFileParser", "robotparser", "urllib.robotparser"), -] -for attr in _urllib_robotparser_moved_attributes: - setattr(Module_six_moves_urllib_robotparser, attr.name, attr) -del attr - -Module_six_moves_urllib_robotparser._moved_attributes = _urllib_robotparser_moved_attributes - -_importer._add_module(Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"), - "moves.urllib_robotparser", "moves.urllib.robotparser") - - -class Module_six_moves_urllib(types.ModuleType): - - """Create a six.moves.urllib namespace that resembles the Python 3 namespace""" - __path__ = [] # mark as package - parse = _importer._get_module("moves.urllib_parse") - error = _importer._get_module("moves.urllib_error") - request = _importer._get_module("moves.urllib_request") - response = _importer._get_module("moves.urllib_response") - robotparser = _importer._get_module("moves.urllib_robotparser") - - def __dir__(self): - return ['parse', 'error', 'request', 'response', 'robotparser'] - -_importer._add_module(Module_six_moves_urllib(__name__ + ".moves.urllib"), - "moves.urllib") - - -def add_move(move): - """Add an item to six.moves.""" - setattr(_MovedItems, move.name, move) - - -def remove_move(name): - """Remove item from six.moves.""" - try: - delattr(_MovedItems, name) - except AttributeError: - try: - del moves.__dict__[name] - except KeyError: - raise AttributeError("no such move, %r" % (name,)) - - -if PY3: - _meth_func = "__func__" - _meth_self = "__self__" - - _func_closure = "__closure__" - _func_code = "__code__" - _func_defaults = "__defaults__" - _func_globals = "__globals__" -else: - _meth_func = "im_func" - _meth_self = "im_self" - - _func_closure = "func_closure" - _func_code = "func_code" - _func_defaults = "func_defaults" - _func_globals = "func_globals" - - -try: - advance_iterator = next -except NameError: - def advance_iterator(it): - return it.next() -next = advance_iterator - - -try: - callable = callable -except NameError: - def callable(obj): - return any("__call__" in klass.__dict__ for klass in type(obj).__mro__) - - -if PY3: - def get_unbound_function(unbound): - return unbound - - create_bound_method = types.MethodType - - def create_unbound_method(func, cls): - return func - - Iterator = object -else: - def get_unbound_function(unbound): - return unbound.im_func - - def create_bound_method(func, obj): - return types.MethodType(func, obj, obj.__class__) - - def create_unbound_method(func, cls): - return types.MethodType(func, None, cls) - - class Iterator(object): - - def next(self): - return type(self).__next__(self) - - callable = callable -_add_doc(get_unbound_function, - """Get the function out of a possibly unbound function""") - - -get_method_function = operator.attrgetter(_meth_func) -get_method_self = operator.attrgetter(_meth_self) -get_function_closure = operator.attrgetter(_func_closure) -get_function_code = operator.attrgetter(_func_code) -get_function_defaults = operator.attrgetter(_func_defaults) -get_function_globals = operator.attrgetter(_func_globals) - - -if PY3: - def iterkeys(d, **kw): - return iter(d.keys(**kw)) - - def itervalues(d, **kw): - return iter(d.values(**kw)) - - def iteritems(d, **kw): - return iter(d.items(**kw)) - - def iterlists(d, **kw): - return iter(d.lists(**kw)) - - viewkeys = operator.methodcaller("keys") - - viewvalues = operator.methodcaller("values") - - viewitems = operator.methodcaller("items") -else: - def iterkeys(d, **kw): - return d.iterkeys(**kw) - - def itervalues(d, **kw): - return d.itervalues(**kw) - - def iteritems(d, **kw): - return d.iteritems(**kw) - - def iterlists(d, **kw): - return d.iterlists(**kw) - - viewkeys = operator.methodcaller("viewkeys") - - viewvalues = operator.methodcaller("viewvalues") - - viewitems = operator.methodcaller("viewitems") - -_add_doc(iterkeys, "Return an iterator over the keys of a dictionary.") -_add_doc(itervalues, "Return an iterator over the values of a dictionary.") -_add_doc(iteritems, - "Return an iterator over the (key, value) pairs of a dictionary.") -_add_doc(iterlists, - "Return an iterator over the (key, [values]) pairs of a dictionary.") - - -if PY3: - def b(s): - return s.encode("latin-1") - - def u(s): - return s - unichr = chr - import struct - int2byte = struct.Struct(">B").pack - del struct - byte2int = operator.itemgetter(0) - indexbytes = operator.getitem - iterbytes = iter - import io - StringIO = io.StringIO - BytesIO = io.BytesIO - del io - _assertCountEqual = "assertCountEqual" - if sys.version_info[1] <= 1: - _assertRaisesRegex = "assertRaisesRegexp" - _assertRegex = "assertRegexpMatches" - _assertNotRegex = "assertNotRegexpMatches" - else: - _assertRaisesRegex = "assertRaisesRegex" - _assertRegex = "assertRegex" - _assertNotRegex = "assertNotRegex" -else: - def b(s): - return s - # Workaround for standalone backslash - - def u(s): - return unicode(s.replace(r'\\', r'\\\\'), "unicode_escape") - unichr = unichr - int2byte = chr - - def byte2int(bs): - return ord(bs[0]) - - def indexbytes(buf, i): - return ord(buf[i]) - iterbytes = functools.partial(itertools.imap, ord) - import StringIO - StringIO = BytesIO = StringIO.StringIO - _assertCountEqual = "assertItemsEqual" - _assertRaisesRegex = "assertRaisesRegexp" - _assertRegex = "assertRegexpMatches" - _assertNotRegex = "assertNotRegexpMatches" -_add_doc(b, """Byte literal""") -_add_doc(u, """Text literal""") - - -def assertCountEqual(self, *args, **kwargs): - return getattr(self, _assertCountEqual)(*args, **kwargs) - - -def assertRaisesRegex(self, *args, **kwargs): - return getattr(self, _assertRaisesRegex)(*args, **kwargs) - - -def assertRegex(self, *args, **kwargs): - return getattr(self, _assertRegex)(*args, **kwargs) - - -def assertNotRegex(self, *args, **kwargs): - return getattr(self, _assertNotRegex)(*args, **kwargs) - - -if PY3: - exec_ = getattr(moves.builtins, "exec") - - def reraise(tp, value, tb=None): - try: - if value is None: - value = tp() - if value.__traceback__ is not tb: - raise value.with_traceback(tb) - raise value - finally: - value = None - tb = None - -else: - def exec_(_code_, _globs_=None, _locs_=None): - """Execute code in a namespace.""" - if _globs_ is None: - frame = sys._getframe(1) - _globs_ = frame.f_globals - if _locs_ is None: - _locs_ = frame.f_locals - del frame - elif _locs_ is None: - _locs_ = _globs_ - exec("""exec _code_ in _globs_, _locs_""") - - exec_("""def reraise(tp, value, tb=None): - try: - raise tp, value, tb - finally: - tb = None -""") - - -if sys.version_info[:2] > (3,): - exec_("""def raise_from(value, from_value): - try: - raise value from from_value - finally: - value = None -""") -else: - def raise_from(value, from_value): - raise value - - -print_ = getattr(moves.builtins, "print", None) -if print_ is None: - def print_(*args, **kwargs): - """The new-style print function for Python 2.4 and 2.5.""" - fp = kwargs.pop("file", sys.stdout) - if fp is None: - return - - def write(data): - if not isinstance(data, basestring): - data = str(data) - # If the file has an encoding, encode unicode with it. - if (isinstance(fp, file) and - isinstance(data, unicode) and - fp.encoding is not None): - errors = getattr(fp, "errors", None) - if errors is None: - errors = "strict" - data = data.encode(fp.encoding, errors) - fp.write(data) - want_unicode = False - sep = kwargs.pop("sep", None) - if sep is not None: - if isinstance(sep, unicode): - want_unicode = True - elif not isinstance(sep, str): - raise TypeError("sep must be None or a string") - end = kwargs.pop("end", None) - if end is not None: - if isinstance(end, unicode): - want_unicode = True - elif not isinstance(end, str): - raise TypeError("end must be None or a string") - if kwargs: - raise TypeError("invalid keyword arguments to print()") - if not want_unicode: - for arg in args: - if isinstance(arg, unicode): - want_unicode = True - break - if want_unicode: - newline = unicode("\n") - space = unicode(" ") - else: - newline = "\n" - space = " " - if sep is None: - sep = space - if end is None: - end = newline - for i, arg in enumerate(args): - if i: - write(sep) - write(arg) - write(end) -if sys.version_info[:2] < (3, 3): - _print = print_ - - def print_(*args, **kwargs): - fp = kwargs.get("file", sys.stdout) - flush = kwargs.pop("flush", False) - _print(*args, **kwargs) - if flush and fp is not None: - fp.flush() - -_add_doc(reraise, """Reraise an exception.""") - -if sys.version_info[0:2] < (3, 4): - # This does exactly the same what the :func:`py3:functools.update_wrapper` - # function does on Python versions after 3.2. It sets the ``__wrapped__`` - # attribute on ``wrapper`` object and it doesn't raise an error if any of - # the attributes mentioned in ``assigned`` and ``updated`` are missing on - # ``wrapped`` object. - def _update_wrapper(wrapper, wrapped, - assigned=functools.WRAPPER_ASSIGNMENTS, - updated=functools.WRAPPER_UPDATES): - for attr in assigned: - try: - value = getattr(wrapped, attr) - except AttributeError: - continue - else: - setattr(wrapper, attr, value) - for attr in updated: - getattr(wrapper, attr).update(getattr(wrapped, attr, {})) - wrapper.__wrapped__ = wrapped - return wrapper - _update_wrapper.__doc__ = functools.update_wrapper.__doc__ - - def wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS, - updated=functools.WRAPPER_UPDATES): - return functools.partial(_update_wrapper, wrapped=wrapped, - assigned=assigned, updated=updated) - wraps.__doc__ = functools.wraps.__doc__ - -else: - wraps = functools.wraps - - -def with_metaclass(meta, *bases): - """Create a base class with a metaclass.""" - # This requires a bit of explanation: the basic idea is to make a dummy - # metaclass for one level of class instantiation that replaces itself with - # the actual metaclass. - class metaclass(type): - - def __new__(cls, name, this_bases, d): - if sys.version_info[:2] >= (3, 7): - # This version introduced PEP 560 that requires a bit - # of extra care (we mimic what is done by __build_class__). - resolved_bases = types.resolve_bases(bases) - if resolved_bases is not bases: - d['__orig_bases__'] = bases - else: - resolved_bases = bases - return meta(name, resolved_bases, d) - - @classmethod - def __prepare__(cls, name, this_bases): - return meta.__prepare__(name, bases) - return type.__new__(metaclass, 'temporary_class', (), {}) - - -def add_metaclass(metaclass): - """Class decorator for creating a class with a metaclass.""" - def wrapper(cls): - orig_vars = cls.__dict__.copy() - slots = orig_vars.get('__slots__') - if slots is not None: - if isinstance(slots, str): - slots = [slots] - for slots_var in slots: - orig_vars.pop(slots_var) - orig_vars.pop('__dict__', None) - orig_vars.pop('__weakref__', None) - if hasattr(cls, '__qualname__'): - orig_vars['__qualname__'] = cls.__qualname__ - return metaclass(cls.__name__, cls.__bases__, orig_vars) - return wrapper - - -def ensure_binary(s, encoding='utf-8', errors='strict'): - """Coerce **s** to six.binary_type. - - For Python 2: - - `unicode` -> encoded to `str` - - `str` -> `str` - - For Python 3: - - `str` -> encoded to `bytes` - - `bytes` -> `bytes` - """ - if isinstance(s, text_type): - return s.encode(encoding, errors) - elif isinstance(s, binary_type): - return s - else: - raise TypeError("not expecting type '%s'" % type(s)) - - -def ensure_str(s, encoding='utf-8', errors='strict'): - """Coerce *s* to `str`. - - For Python 2: - - `unicode` -> encoded to `str` - - `str` -> `str` - - For Python 3: - - `str` -> `str` - - `bytes` -> decoded to `str` - """ - if not isinstance(s, (text_type, binary_type)): - raise TypeError("not expecting type '%s'" % type(s)) - if PY2 and isinstance(s, text_type): - s = s.encode(encoding, errors) - elif PY3 and isinstance(s, binary_type): - s = s.decode(encoding, errors) - return s - - -def ensure_text(s, encoding='utf-8', errors='strict'): - """Coerce *s* to six.text_type. - - For Python 2: - - `unicode` -> `unicode` - - `str` -> `unicode` - - For Python 3: - - `str` -> `str` - - `bytes` -> decoded to `str` - """ - if isinstance(s, binary_type): - return s.decode(encoding, errors) - elif isinstance(s, text_type): - return s - else: - raise TypeError("not expecting type '%s'" % type(s)) - - -def python_2_unicode_compatible(klass): - """ - A class decorator that defines __unicode__ and __str__ methods under Python 2. - Under Python 3 it does nothing. - - To support Python 2 and 3 with a single code base, define a __str__ method - returning text and apply this decorator to the class. - """ - if PY2: - if '__str__' not in klass.__dict__: - raise ValueError("@python_2_unicode_compatible cannot be applied " - "to %s because it doesn't define __str__()." % - klass.__name__) - klass.__unicode__ = klass.__str__ - klass.__str__ = lambda self: self.__unicode__().encode('utf-8') - return klass - - -# Complete the moves implementation. -# This code is at the end of this module to speed up module loading. -# Turn this module into a package. -__path__ = [] # required for PEP 302 and PEP 451 -__package__ = __name__ # see PEP 366 @ReservedAssignment -if globals().get("__spec__") is not None: - __spec__.submodule_search_locations = [] # PEP 451 @UndefinedVariable -# Remove other six meta path importers, since they cause problems. This can -# happen if six is removed from sys.modules and then reloaded. (Setuptools does -# this for some reason.) -if sys.meta_path: - for i, importer in enumerate(sys.meta_path): - # Here's some real nastiness: Another "instance" of the six module might - # be floating around. Therefore, we can't use isinstance() to check for - # the six meta path importer, since the other six instance will have - # inserted an importer with different class. - if (type(importer).__name__ == "_SixMetaPathImporter" and - importer.name == __name__): - del sys.meta_path[i] - break - del i, importer -# Finally, add the importer to the meta path import hook. -sys.meta_path.append(_importer) - -import warnings - -def deprecated(message): - def deprecated_decorator(func): - def deprecated_func(*args, **kwargs): - warnings.warn("{} is a deprecated function. {}".format(func.__name__, message), - category=DeprecationWarning, - stacklevel=2) - warnings.simplefilter('default', DeprecationWarning) - return func(*args, **kwargs) - return deprecated_func - return deprecated_decorator \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index 2ae28399f..000000000 --- a/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -pass diff --git a/tests/modularinput/__init__.py b/tests/modularinput/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/modularinput/modularinput_testlib.py b/tests/modularinput/modularinput_testlib.py index 819301736..d4846a408 100644 --- a/tests/modularinput/modularinput_testlib.py +++ b/tests/modularinput/modularinput_testlib.py @@ -15,11 +15,7 @@ # under the License. # Utility file for unit tests, import common functions and modules -from __future__ import absolute_import -try: - import unittest2 as unittest -except ImportError: - import unittest +import unittest import sys, os import io diff --git a/tests/modularinput/test_event.py b/tests/modularinput/test_event.py index 865656031..278abb81f 100644 --- a/tests/modularinput/test_event.py +++ b/tests/modularinput/test_event.py @@ -14,13 +14,12 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import import sys import pytest -from tests.modularinput.modularinput_testlib import unittest, xml_compare, data_open +from tests.modularinput.modularinput_testlib import xml_compare, data_open from splunklib.modularinput.event import Event, ET from splunklib.modularinput.event_writer import EventWriter @@ -99,7 +98,7 @@ def test_writing_events_on_event_writer(capsys): first_out_part = captured.out with data_open("data/stream_with_one_event.xml") as data: - found = ET.fromstring("%s" % first_out_part) + found = ET.fromstring(f"{first_out_part}") expected = ET.parse(data).getroot() assert xml_compare(expected, found) diff --git a/tests/modularinput/test_input_definition.py b/tests/modularinput/test_input_definition.py index d0f59a04e..520eafbcf 100644 --- a/tests/modularinput/test_input_definition.py +++ b/tests/modularinput/test_input_definition.py @@ -14,10 +14,10 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from tests.modularinput.modularinput_testlib import unittest, data_open from splunklib.modularinput.input_definition import InputDefinition + class InputDefinitionTestCase(unittest.TestCase): def test_parse_inputdef_with_zero_inputs(self): @@ -72,5 +72,6 @@ def test_attempt_to_parse_malformed_input_definition_will_throw_exception(self): with self.assertRaises(ValueError): found = InputDefinition.parse(data_open("data/conf_with_invalid_inputs.xml")) + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/modularinput/test_scheme.py b/tests/modularinput/test_scheme.py index e1b3463a3..e38d81a5d 100644 --- a/tests/modularinput/test_scheme.py +++ b/tests/modularinput/test_scheme.py @@ -13,15 +13,11 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import +import xml.etree.ElementTree as ET from tests.modularinput.modularinput_testlib import unittest, xml_compare, data_open from splunklib.modularinput.scheme import Scheme from splunklib.modularinput.argument import Argument -try: - import xml.etree.cElementTree as ET -except ImportError: - import xml.etree.ElementTree as ET class SchemeTest(unittest.TestCase): def test_generate_xml_from_scheme_with_default_values(self): @@ -40,7 +36,7 @@ def test_generate_xml_from_scheme(self): some arguments added matches what we expect.""" scheme = Scheme("abcd") - scheme.description = u"쎼 and 쎶 and <&> für" + scheme.description = "쎼 and 쎶 and <&> für" scheme.streaming_mode = Scheme.streaming_mode_simple scheme.use_external_validation = "false" scheme.use_single_instance = "true" @@ -50,7 +46,7 @@ def test_generate_xml_from_scheme(self): arg2 = Argument( name="arg2", - description=u"쎼 and 쎶 and <&> für", + description="쎼 and 쎶 and <&> für", validation="is_pos_int('some_name')", data_type=Argument.data_type_number, required_on_edit=True, @@ -69,7 +65,7 @@ def test_generate_xml_from_scheme_with_arg_title(self): some arguments added matches what we expect. Also sets the title on an argument.""" scheme = Scheme("abcd") - scheme.description = u"쎼 and 쎶 and <&> für" + scheme.description = "쎼 and 쎶 and <&> für" scheme.streaming_mode = Scheme.streaming_mode_simple scheme.use_external_validation = "false" scheme.use_single_instance = "true" @@ -79,7 +75,7 @@ def test_generate_xml_from_scheme_with_arg_title(self): arg2 = Argument( name="arg2", - description=u"쎼 and 쎶 and <&> für", + description="쎼 and 쎶 and <&> für", validation="is_pos_int('some_name')", data_type=Argument.data_type_number, required_on_edit=True, @@ -113,7 +109,7 @@ def test_generate_xml_from_argument(self): argument = Argument( name="some_name", - description=u"쎼 and 쎶 and <&> für", + description="쎼 and 쎶 and <&> für", validation="is_pos_int('some_name')", data_type=Argument.data_type_boolean, required_on_edit="true", diff --git a/tests/modularinput/test_script.py b/tests/modularinput/test_script.py index b15885dc7..48be8826b 100644 --- a/tests/modularinput/test_script.py +++ b/tests/modularinput/test_script.py @@ -1,16 +1,13 @@ import sys +import io +import xml.etree.ElementTree as ET from splunklib.client import Service from splunklib.modularinput import Script, EventWriter, Scheme, Argument, Event -import io from splunklib.modularinput.utils import xml_compare from tests.modularinput.modularinput_testlib import data_open -try: - import xml.etree.cElementTree as ET -except ImportError: - import xml.etree.ElementTree as ET TEST_SCRIPT_PATH = "__IGNORED_SCRIPT_PATH__" @@ -51,7 +48,7 @@ def test_scheme_properly_generated_by_script(capsys): class NewScript(Script): def get_scheme(self): scheme = Scheme("abcd") - scheme.description = u"\uC3BC and \uC3B6 and <&> f\u00FCr" + scheme.description = "\uC3BC and \uC3B6 and <&> f\u00FCr" scheme.streaming_mode = scheme.streaming_mode_simple scheme.use_external_validation = False scheme.use_single_instance = True @@ -60,7 +57,7 @@ def get_scheme(self): scheme.add_argument(arg1) arg2 = Argument("arg2") - arg2.description = u"\uC3BC and \uC3B6 and <&> f\u00FCr" + arg2.description = "\uC3BC and \uC3B6 and <&> f\u00FCr" arg2.data_type = Argument.data_type_number arg2.required_on_create = True arg2.required_on_edit = True @@ -208,7 +205,7 @@ def test_service_property(capsys): # Override abstract methods class NewScript(Script): def __init__(self): - super(NewScript, self).__init__() + super().__init__() self.authority_uri = None def get_scheme(self): diff --git a/tests/modularinput/test_validation_definition.py b/tests/modularinput/test_validation_definition.py index c8046f3b3..43871c51a 100644 --- a/tests/modularinput/test_validation_definition.py +++ b/tests/modularinput/test_validation_definition.py @@ -14,10 +14,11 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import + from tests.modularinput.modularinput_testlib import unittest, data_open from splunklib.modularinput.validation_definition import ValidationDefinition + class ValidationDefinitionTestCase(unittest.TestCase): def test_validation_definition_parse(self): """Check that parsing produces expected result""" @@ -42,5 +43,6 @@ def test_validation_definition_parse(self): self.assertEqual(expected, found) + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/searchcommands/__init__.py b/tests/searchcommands/__init__.py index 2f282889a..0f260b58f 100644 --- a/tests/searchcommands/__init__.py +++ b/tests/searchcommands/__init__.py @@ -15,10 +15,6 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals - -from sys import version_info as python_version - from os import path import logging diff --git a/tests/searchcommands/chunked_data_stream.py b/tests/searchcommands/chunked_data_stream.py index ae5363eff..39782c444 100644 --- a/tests/searchcommands/chunked_data_stream.py +++ b/tests/searchcommands/chunked_data_stream.py @@ -4,33 +4,32 @@ import json import splunklib.searchcommands.internals -from splunklib import six -class Chunk(object): +class Chunk: def __init__(self, version, meta, data): - self.version = six.ensure_str(version) + self.version = version self.meta = json.loads(meta) dialect = splunklib.searchcommands.internals.CsvDialect self.data = csv.DictReader(io.StringIO(data.decode("utf-8")), dialect=dialect) -class ChunkedDataStreamIter(collections.Iterator): +class ChunkedDataStreamIter(collections.abc.Iterator): def __init__(self, chunk_stream): self.chunk_stream = chunk_stream def __next__(self): - return self.next() + return next(self) - def next(self): + def __next__(self): try: return self.chunk_stream.read_chunk() except EOFError: raise StopIteration -class ChunkedDataStream(collections.Iterable): +class ChunkedDataStream(collections.abc.Iterable): def __iter__(self): return ChunkedDataStreamIter(self) @@ -54,7 +53,7 @@ def read_chunk(self): def build_chunk(keyval, data=None): - metadata = six.ensure_binary(json.dumps(keyval), 'utf-8') + metadata = json.dumps(keyval).encode('utf-8') data_output = _build_data_csv(data) return b"chunked 1.0,%d,%d\n%s%s" % (len(metadata), len(data_output), metadata, data_output) @@ -87,14 +86,14 @@ def _build_data_csv(data): return b'' if isinstance(data, bytes): return data - csvout = splunklib.six.StringIO() + csvout = io.StringIO() headers = set() for datum in data: - headers.update(datum.keys()) + headers.update(list(datum.keys())) writer = csv.DictWriter(csvout, headers, dialect=splunklib.searchcommands.internals.CsvDialect) writer.writeheader() for datum in data: writer.writerow(datum) - return six.ensure_binary(csvout.getvalue()) + return csvout.getvalue().encode('utf-8') diff --git a/tests/searchcommands/test_builtin_options.py b/tests/searchcommands/test_builtin_options.py index e5c2dd8dd..07a343eff 100644 --- a/tests/searchcommands/test_builtin_options.py +++ b/tests/searchcommands/test_builtin_options.py @@ -15,19 +15,15 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals - -from splunklib.six.moves import cStringIO as StringIO -try: - from unittest2 import main, TestCase -except ImportError: - from unittest import main, TestCase import os import sys import logging +from unittest import main, TestCase import pytest +from io import StringIO + from splunklib.searchcommands import environment from splunklib.searchcommands.decorators import Configuration @@ -117,18 +113,18 @@ def test_logging_configuration(self): except ValueError: pass except BaseException as e: - self.fail('Expected ValueError, but {} was raised'.format(type(e))) + self.fail(f'Expected ValueError, but {type(e)} was raised') else: - self.fail('Expected ValueError, but logging_configuration={}'.format(command.logging_configuration)) + self.fail(f'Expected ValueError, but logging_configuration={command.logging_configuration}') try: command.logging_configuration = os.path.join(package_directory, 'non-existent.logging.conf') except ValueError: pass except BaseException as e: - self.fail('Expected ValueError, but {} was raised'.format(type(e))) + self.fail(f'Expected ValueError, but {type(e)} was raised') else: - self.fail('Expected ValueError, but logging_configuration={}'.format(command.logging_configuration)) + self.fail(f'Expected ValueError, but logging_configuration={command.logging_configuration}') def test_logging_level(self): @@ -146,7 +142,7 @@ def test_logging_level(self): self.assertEqual(warning, command.logging_level) for level in level_names(): - if type(level) is int: + if isinstance(level, int): command.logging_level = level level_name = logging.getLevelName(level) self.assertEqual(command.logging_level, warning if level_name == notset else level_name) @@ -171,9 +167,9 @@ def test_logging_level(self): except ValueError: pass except BaseException as e: - self.fail('Expected ValueError, but {} was raised'.format(type(e))) + self.fail(f'Expected ValueError, but {type(e)} was raised') else: - self.fail('Expected ValueError, but logging_level={}'.format(command.logging_level)) + self.fail(f'Expected ValueError, but logging_level={command.logging_level}') self.assertEqual(command.logging_level, current_value) @@ -211,13 +207,9 @@ def _test_boolean_option(self, option): except ValueError: pass except BaseException as error: - self.fail('Expected ValueError when setting {}={}, but {} was raised'.format( - option.name, repr(value), type(error))) + self.fail(f'Expected ValueError when setting {option.name}={repr(value)}, but {type(error)} was raised') else: - self.fail('Expected ValueError, but {}={} was accepted.'.format( - option.name, repr(option.fget(command)))) - - return + self.fail(f'Expected ValueError, but {option.name}={repr(option.fget(command))} was accepted.') if __name__ == "__main__": diff --git a/tests/searchcommands/test_configuration_settings.py b/tests/searchcommands/test_configuration_settings.py index dd07b57fc..65d0d3a4a 100644 --- a/tests/searchcommands/test_configuration_settings.py +++ b/tests/searchcommands/test_configuration_settings.py @@ -24,20 +24,19 @@ # * If a value is not set in code, the value specified in commands.conf is enforced # * If a value is set in code, it overrides the value specified in commands.conf -from __future__ import absolute_import, division, print_function, unicode_literals -from splunklib.searchcommands.decorators import Configuration from unittest import main, TestCase -from splunklib import six - import pytest +from splunklib.searchcommands.decorators import Configuration + + @pytest.mark.smoke class TestConfigurationSettings(TestCase): def test_generating_command(self): - from splunklib.searchcommands import Configuration, GeneratingCommand + from splunklib.searchcommands import GeneratingCommand @Configuration() class TestCommand(GeneratingCommand): @@ -48,7 +47,7 @@ def generate(self): command._protocol_version = 1 self.assertTrue( - [(name, value) for name, value in six.iteritems(command.configuration)], + list(command.configuration.items()), [('generating', True)]) self.assertIs(command.configuration.generates_timeorder, None) @@ -66,12 +65,12 @@ def generate(self): except AttributeError: pass except Exception as error: - self.fail('Expected AttributeError, not {}: {}'.format(type(error).__name__, error)) + self.fail(f'Expected AttributeError, not {type(error).__name__}: {error}') else: self.fail('Expected AttributeError') self.assertEqual( - [(name, value) for name, value in six.iteritems(command.configuration)], + list(command.configuration.items()), [('generates_timeorder', True), ('generating', True), ('local', True), ('retainsevents', True), ('streaming', True)]) @@ -79,7 +78,7 @@ def generate(self): command._protocol_version = 2 self.assertEqual( - [(name, value) for name, value in six.iteritems(command.configuration)], + list(command.configuration.items()), [('generating', True), ('type', 'stateful')]) self.assertIs(command.configuration.distributed, False) @@ -93,19 +92,17 @@ def generate(self): except AttributeError: pass except Exception as error: - self.fail('Expected AttributeError, not {}: {}'.format(type(error).__name__, error)) + self.fail(f'Expected AttributeError, not {type(error).__name__}: {error}') else: self.fail('Expected AttributeError') self.assertEqual( - [(name, value) for name, value in six.iteritems(command.configuration)], + list(command.configuration.items()), [('generating', True), ('type', 'streaming')]) - return - def test_streaming_command(self): - from splunklib.searchcommands import Configuration, StreamingCommand + from splunklib.searchcommands import StreamingCommand @Configuration() class TestCommand(StreamingCommand): @@ -117,7 +114,7 @@ def stream(self, records): command._protocol_version = 1 self.assertEqual( - [(name, value) for name, value in six.iteritems(command.configuration)], + list(command.configuration.items()), [('streaming', True)]) self.assertIs(command.configuration.clear_required_fields, None) @@ -136,19 +133,20 @@ def stream(self, records): except AttributeError: pass except Exception as error: - self.fail('Expected AttributeError, not {}: {}'.format(type(error).__name__, error)) + self.fail(f'Expected AttributeError, not {type(error).__name__}: {error}') else: self.fail('Expected AttributeError') self.assertEqual( - [(name, value) for name, value in six.iteritems(command.configuration)], - [('clear_required_fields', True), ('local', True), ('overrides_timeorder', True), ('required_fields', ['field_1', 'field_2', 'field_3']), ('streaming', True)]) + list(command.configuration.items()), + [('clear_required_fields', True), ('local', True), ('overrides_timeorder', True), + ('required_fields', ['field_1', 'field_2', 'field_3']), ('streaming', True)]) command = TestCommand() command._protocol_version = 2 self.assertEqual( - [(name, value) for name, value in six.iteritems(command.configuration)], + list(command.configuration.items()), [('type', 'streaming')]) self.assertIs(command.configuration.distributed, True) @@ -162,15 +160,14 @@ def stream(self, records): except AttributeError: pass except Exception as error: - self.fail('Expected AttributeError, not {}: {}'.format(type(error).__name__, error)) + self.fail(f'Expected AttributeError, not {type(error).__name__}: {error}') else: self.fail('Expected AttributeError') self.assertEqual( - [(name, value) for name, value in six.iteritems(command.configuration)], + list(command.configuration.items()), [('required_fields', ['field_1', 'field_2', 'field_3']), ('type', 'stateful')]) - return if __name__ == "__main__": main() diff --git a/tests/searchcommands/test_decorators.py b/tests/searchcommands/test_decorators.py index dd65aa0ab..d258729cb 100755 --- a/tests/searchcommands/test_decorators.py +++ b/tests/searchcommands/test_decorators.py @@ -15,35 +15,23 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals -try: - from unittest2 import main, TestCase -except ImportError: - from unittest import main, TestCase +from unittest import main, TestCase import sys from io import TextIOWrapper +import pytest from splunklib.searchcommands import Configuration, Option, environment, validators from splunklib.searchcommands.decorators import ConfigurationSetting from splunklib.searchcommands.internals import json_encode_string from splunklib.searchcommands.search_command import SearchCommand -try: - from tests.searchcommands import rebase_environment -except ImportError: - # Skip on Python 2.6 - pass - -from splunklib import six - -import pytest +from tests.searchcommands import rebase_environment @Configuration() class TestSearchCommand(SearchCommand): - boolean = Option( doc=''' **Syntax:** **boolean=**** @@ -121,7 +109,7 @@ class TestSearchCommand(SearchCommand): **Syntax:** **integer=**** **Description:** An integer value''', require=True, validate=validators.Integer()) - + float = Option( doc=''' **Syntax:** **float=**** @@ -234,49 +222,48 @@ def fix_up(cls, command_class): return ConfiguredSearchCommand.ConfigurationSettings for name, values, error_values in ( - ('clear_required_fields', - (True, False), - (None, 'anything other than a bool')), - ('distributed', - (True, False), - (None, 'anything other than a bool')), - ('generates_timeorder', - (True, False), - (None, 'anything other than a bool')), - ('generating', - (True, False), - (None, 'anything other than a bool')), - ('maxinputs', - (0, 50000, sys.maxsize), - (None, -1, sys.maxsize + 1, 'anything other than an int')), - ('overrides_timeorder', - (True, False), - (None, 'anything other than a bool')), - ('required_fields', - (['field_1', 'field_2'], set(['field_1', 'field_2']), ('field_1', 'field_2')), - (None, 0xdead, {'foo': 1, 'bar': 2})), - ('requires_preop', - (True, False), - (None, 'anything other than a bool')), - ('retainsevents', - (True, False), - (None, 'anything other than a bool')), - ('run_in_preview', - (True, False), - (None, 'anything other than a bool')), - ('streaming', - (True, False), - (None, 'anything other than a bool')), - ('streaming_preop', - (u'some unicode string', b'some byte string'), - (None, 0xdead)), - ('type', - # TODO: Do we need to validate byte versions of these strings? - ('events', 'reporting', 'streaming'), - ('eventing', 0xdead))): + ('clear_required_fields', + (True, False), + (None, 'anything other than a bool')), + ('distributed', + (True, False), + (None, 'anything other than a bool')), + ('generates_timeorder', + (True, False), + (None, 'anything other than a bool')), + ('generating', + (True, False), + (None, 'anything other than a bool')), + ('maxinputs', + (0, 50000, sys.maxsize), + (None, -1, sys.maxsize + 1, 'anything other than an int')), + ('overrides_timeorder', + (True, False), + (None, 'anything other than a bool')), + ('required_fields', + (['field_1', 'field_2'], set(['field_1', 'field_2']), ('field_1', 'field_2')), + (None, 0xdead, {'foo': 1, 'bar': 2})), + ('requires_preop', + (True, False), + (None, 'anything other than a bool')), + ('retainsevents', + (True, False), + (None, 'anything other than a bool')), + ('run_in_preview', + (True, False), + (None, 'anything other than a bool')), + ('streaming', + (True, False), + (None, 'anything other than a bool')), + ('streaming_preop', + ('some unicode string', b'some byte string'), + (None, 0xdead)), + ('type', + # TODO: Do we need to validate byte versions of these strings? + ('events', 'reporting', 'streaming'), + ('eventing', 0xdead))): for value in values: - settings_class = new_configuration_settings_class(name, value) # Setting property exists @@ -299,25 +286,24 @@ def fix_up(cls, command_class): self.assertIn(backing_field_name, settings_instance.__dict__), self.assertEqual(getattr(settings_instance, name), value) self.assertEqual(settings_instance.__dict__[backing_field_name], value) - pass for value in error_values: try: new_configuration_settings_class(name, value) except Exception as error: - self.assertIsInstance(error, ValueError, 'Expected ValueError, not {}({}) for {}={}'.format(type(error).__name__, error, name, repr(value))) + self.assertIsInstance(error, ValueError, + 'Expected ValueError, not {}({}) for {}={}'.format(type(error).__name__, + error, name, repr(value))) else: - self.fail('Expected ValueError, not success for {}={}'.format(name, repr(value))) + self.fail(f'Expected ValueError, not success for {name}={repr(value)}') settings_class = new_configuration_settings_class() settings_instance = settings_class(command=None) self.assertRaises(ValueError, setattr, settings_instance, name, value) - return - def test_new_configuration_setting(self): - class Test(object): + class Test: generating = ConfigurationSetting() @ConfigurationSetting(name='required_fields') @@ -366,13 +352,13 @@ def test_option(self): command = TestSearchCommand() options = command.options - itervalues = lambda: six.itervalues(options) + #itervalues = lambda: options.values() options.reset() missing = options.get_missing() - self.assertListEqual(missing, [option.name for option in itervalues() if option.is_required]) - self.assertListEqual(presets, [six.text_type(option) for option in itervalues() if option.value is not None]) - self.assertListEqual(presets, [six.text_type(option) for option in itervalues() if six.text_type(option) != option.name + '=None']) + self.assertListEqual(missing, [option.name for option in list(options.values()) if option.is_required]) + self.assertListEqual(presets, [str(option) for option in list(options.values()) if option.value is not None]) + self.assertListEqual(presets, [str(option) for option in list(options.values()) if str(option) != option.name + '=None']) test_option_values = { validators.Boolean: ('0', 'non-boolean value'), @@ -389,7 +375,7 @@ def test_option(self): validators.RegularExpression: ('\\s+', '(poorly formed regular expression'), validators.Set: ('bar', 'non-existent set entry')} - for option in itervalues(): + for option in list(options.values()): validator = option.validator if validator is None: @@ -401,57 +387,57 @@ def test_option(self): self.assertEqual( validator.format(option.value), validator.format(validator.__call__(legal_value)), - "{}={}".format(option.name, legal_value)) + f"{option.name}={legal_value}") try: option.value = illegal_value except ValueError: pass except BaseException as error: - self.assertFalse('Expected ValueError for {}={}, not this {}: {}'.format( - option.name, illegal_value, type(error).__name__, error)) + self.assertFalse( + f'Expected ValueError for {option.name}={illegal_value}, not this {type(error).__name__}: {error}') else: - self.assertFalse('Expected ValueError for {}={}, not a pass.'.format(option.name, illegal_value)) + self.assertFalse(f'Expected ValueError for {option.name}={illegal_value}, not a pass.') expected = { - u'foo': False, + 'foo': False, 'boolean': False, - 'code': u'foo == \"bar\"', + 'code': 'foo == \"bar\"', 'duration': 89999, - 'fieldname': u'some.field_name', - 'file': six.text_type(repr(__file__)), + 'fieldname': 'some.field_name', + 'file': str(repr(__file__)), 'integer': 100, 'float': 99.9, 'logging_configuration': environment.logging_configuration, - 'logging_level': u'WARNING', + 'logging_level': 'WARNING', 'map': 'foo', - 'match': u'123-45-6789', - 'optionname': u'some_option_name', + 'match': '123-45-6789', + 'optionname': 'some_option_name', 'record': False, - 'regularexpression': u'\\s+', + 'regularexpression': '\\s+', 'required_boolean': False, - 'required_code': u'foo == \"bar\"', + 'required_code': 'foo == \"bar\"', 'required_duration': 89999, - 'required_fieldname': u'some.field_name', - 'required_file': six.text_type(repr(__file__)), + 'required_fieldname': 'some.field_name', + 'required_file': str(repr(__file__)), 'required_integer': 100, 'required_float': 99.9, 'required_map': 'foo', - 'required_match': u'123-45-6789', - 'required_optionname': u'some_option_name', - 'required_regularexpression': u'\\s+', - 'required_set': u'bar', - 'set': u'bar', + 'required_match': '123-45-6789', + 'required_optionname': 'some_option_name', + 'required_regularexpression': '\\s+', + 'required_set': 'bar', + 'set': 'bar', 'show_configuration': False, } self.maxDiff = None tuplewrap = lambda x: x if isinstance(x, tuple) else (x,) - invert = lambda x: {v: k for k, v in six.iteritems(x)} + invert = lambda x: {v: k for k, v in list(x.items())} - for x in six.itervalues(command.options): - # isinstance doesn't work for some reason + for x in list(command.options.values()): + # isinstance doesn't work for some reason if type(x.value).__name__ == 'Code': self.assertEqual(expected[x.name], x.value.source) elif type(x.validator).__name__ == 'Map': @@ -459,26 +445,28 @@ def test_option(self): elif type(x.validator).__name__ == 'RegularExpression': self.assertEqual(expected[x.name], x.value.pattern) elif isinstance(x.value, TextIOWrapper): - self.assertEqual(expected[x.name], "'%s'" % x.value.name) - elif not isinstance(x.value, (bool,) + (float,) + (six.text_type,) + (six.binary_type,) + tuplewrap(six.integer_types)): + self.assertEqual(expected[x.name], f"'{x.value.name}'") + elif not isinstance(x.value, (bool,) + (float,) + (str,) + (bytes,) + tuplewrap(int)): self.assertEqual(expected[x.name], repr(x.value)) else: self.assertEqual(expected[x.name], x.value) expected = ( - 'foo="f" boolean="f" code="foo == \\"bar\\"" duration="24:59:59" fieldname="some.field_name" ' - 'file=' + json_encode_string(__file__) + ' float="99.9" integer="100" map="foo" match="123-45-6789" ' - 'optionname="some_option_name" record="f" regularexpression="\\\\s+" required_boolean="f" ' - 'required_code="foo == \\"bar\\"" required_duration="24:59:59" required_fieldname="some.field_name" ' - 'required_file=' + json_encode_string(__file__) + ' required_float="99.9" required_integer="100" required_map="foo" ' - 'required_match="123-45-6789" required_optionname="some_option_name" required_regularexpression="\\\\s+" ' - 'required_set="bar" set="bar" show_configuration="f"') + 'foo="f" boolean="f" code="foo == \\"bar\\"" duration="24:59:59" fieldname="some.field_name" ' + 'file=' + json_encode_string(__file__) + ' float="99.9" integer="100" map="foo" match="123-45-6789" ' + 'optionname="some_option_name" record="f" regularexpression="\\\\s+" required_boolean="f" ' + 'required_code="foo == \\"bar\\"" required_duration="24:59:59" required_fieldname="some.field_name" ' + 'required_file=' + json_encode_string( + __file__) + ' required_float="99.9" required_integer="100" required_map="foo" ' + 'required_match="123-45-6789" required_optionname="some_option_name" required_regularexpression="\\\\s+" ' + 'required_set="bar" set="bar" show_configuration="f"') - observed = six.text_type(command.options) + observed = str(command.options) self.assertEqual(observed, expected) - return +TestSearchCommand.__test__ = False + if __name__ == "__main__": main() diff --git a/tests/searchcommands/test_generator_command.py b/tests/searchcommands/test_generator_command.py index 63ae3ac83..af103977a 100644 --- a/tests/searchcommands/test_generator_command.py +++ b/tests/searchcommands/test_generator_command.py @@ -1,9 +1,8 @@ import io import time -from . import chunked_data_stream as chunky - from splunklib.searchcommands import Configuration, GeneratingCommand +from . import chunked_data_stream as chunky def test_simple_generator(): @@ -12,6 +11,7 @@ class GeneratorTest(GeneratingCommand): def generate(self): for num in range(1, 10): yield {'_time': time.time(), 'event_index': num} + generator = GeneratorTest() in_stream = io.BytesIO() in_stream.write(chunky.build_getinfo_chunk()) @@ -24,7 +24,7 @@ def generate(self): ds = chunky.ChunkedDataStream(out_stream) is_first_chunk = True finished_seen = False - expected = set(map(lambda i: str(i), range(1, 10))) + expected = set(str(i) for i in range(1, 10)) seen = set() for chunk in ds: if is_first_chunk: @@ -40,15 +40,18 @@ def generate(self): assert expected.issubset(seen) assert finished_seen + def test_allow_empty_input_for_generating_command(): """ Passing allow_empty_input for generating command will cause an error """ + @Configuration() class GeneratorTest(GeneratingCommand): def generate(self): for num in range(1, 3): yield {"_index": num} + generator = GeneratorTest() in_stream = io.BytesIO() out_stream = io.BytesIO() @@ -58,6 +61,7 @@ def generate(self): except ValueError as error: assert str(error) == "allow_empty_input cannot be False for Generating Commands" + def test_all_fieldnames_present_for_generated_records(): @Configuration() class GeneratorTest(GeneratingCommand): diff --git a/tests/searchcommands/test_internals_v1.py b/tests/searchcommands/test_internals_v1.py index eb85d040a..a6a68840b 100755 --- a/tests/searchcommands/test_internals_v1.py +++ b/tests/searchcommands/test_internals_v1.py @@ -14,7 +14,11 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals +from contextlib import closing +from unittest import main, TestCase +import os +import pytest +from functools import reduce from splunklib.searchcommands.internals import CommandLineParser, InputHeader, RecordWriterV1 from splunklib.searchcommands.decorators import Configuration, Option @@ -22,19 +26,9 @@ from splunklib.searchcommands.search_command import SearchCommand -from contextlib import closing -from splunklib.six import StringIO, BytesIO +from io import StringIO, BytesIO -from splunklib.six.moves import zip as izip -from unittest import main, TestCase - -import os -from splunklib import six -from splunklib.six.moves import range -from functools import reduce - -import pytest @pytest.mark.smoke class TestInternals(TestCase): @@ -61,7 +55,7 @@ def fix_up(cls, command_class): pass command = TestCommandLineParserCommand() CommandLineParser.parse(command, options) - for option in six.itervalues(command.options): + for option in list(command.options.values()): if option.name in ['logging_configuration', 'logging_level', 'record', 'show_configuration']: self.assertFalse(option.is_set) continue @@ -78,7 +72,7 @@ def fix_up(cls, command_class): pass command = TestCommandLineParserCommand() CommandLineParser.parse(command, options + fieldnames) - for option in six.itervalues(command.options): + for option in list(command.options.values()): if option.name in ['logging_configuration', 'logging_level', 'record', 'show_configuration']: self.assertFalse(option.is_set) continue @@ -93,8 +87,9 @@ def fix_up(cls, command_class): pass command = TestCommandLineParserCommand() CommandLineParser.parse(command, ['required_option=true'] + fieldnames) - for option in six.itervalues(command.options): - if option.name in ['unnecessary_option', 'logging_configuration', 'logging_level', 'record', 'show_configuration']: + for option in list(command.options.values()): + if option.name in ['unnecessary_option', 'logging_configuration', 'logging_level', 'record', + 'show_configuration']: self.assertFalse(option.is_set) continue self.assertTrue(option.is_set) @@ -112,7 +107,8 @@ def fix_up(cls, command_class): pass # Command line with unrecognized options - self.assertRaises(ValueError, CommandLineParser.parse, command, ['unrecognized_option_1=foo', 'unrecognized_option_2=bar']) + self.assertRaises(ValueError, CommandLineParser.parse, command, + ['unrecognized_option_1=foo', 'unrecognized_option_2=bar']) # Command line with a variety of quoted/escaped text options @@ -145,19 +141,19 @@ def fix_up(cls, command_class): pass r'"Hello World!"' ] - for string, expected_value in izip(strings, expected_values): + for string, expected_value in zip(strings, expected_values): command = TestCommandLineParserCommand() argv = ['text', '=', string] CommandLineParser.parse(command, argv) self.assertEqual(command.text, expected_value) - for string, expected_value in izip(strings, expected_values): + for string, expected_value in zip(strings, expected_values): command = TestCommandLineParserCommand() argv = [string] CommandLineParser.parse(command, argv) self.assertEqual(command.fieldnames[0], expected_value) - for string, expected_value in izip(strings, expected_values): + for string, expected_value in zip(strings, expected_values): command = TestCommandLineParserCommand() argv = ['text', '=', string] + strings CommandLineParser.parse(command, argv) @@ -176,25 +172,23 @@ def fix_up(cls, command_class): pass argv = [string] self.assertRaises(SyntaxError, CommandLineParser.parse, command, argv) - return - def test_command_line_parser_unquote(self): parser = CommandLineParser options = [ - r'foo', # unquoted string with no escaped characters - r'fo\o\ b\"a\\r', # unquoted string with some escaped characters - r'"foo"', # quoted string with no special characters - r'"""foobar1"""', # quoted string with quotes escaped like this: "" - r'"\"foobar2\""', # quoted string with quotes escaped like this: \" - r'"foo ""x"" bar"', # quoted string with quotes escaped like this: "" - r'"foo \"x\" bar"', # quoted string with quotes escaped like this: \" - r'"\\foobar"', # quoted string with an escaped backslash - r'"foo \\ bar"', # quoted string with an escaped backslash - r'"foobar\\"', # quoted string with an escaped backslash - r'foo\\\bar', # quoted string with an escaped backslash and an escaped 'b' - r'""', # pair of quotes - r''] # empty string + r'foo', # unquoted string with no escaped characters + r'fo\o\ b\"a\\r', # unquoted string with some escaped characters + r'"foo"', # quoted string with no special characters + r'"""foobar1"""', # quoted string with quotes escaped like this: "" + r'"\"foobar2\""', # quoted string with quotes escaped like this: \" + r'"foo ""x"" bar"', # quoted string with quotes escaped like this: "" + r'"foo \"x\" bar"', # quoted string with quotes escaped like this: \" + r'"\\foobar"', # quoted string with an escaped backslash + r'"foo \\ bar"', # quoted string with an escaped backslash + r'"foobar\\"', # quoted string with an escaped backslash + r'foo\\\bar', # quoted string with an escaped backslash and an escaped 'b' + r'""', # pair of quotes + r''] # empty string expected = [ r'foo', @@ -288,7 +282,7 @@ def test_input_header(self): 'sentence': 'hello world!'} input_header = InputHeader() - text = reduce(lambda value, item: value + '{}:{}\n'.format(item[0], item[1]), six.iteritems(collection), '') + '\n' + text = reduce(lambda value, item: value + f'{item[0]}:{item[1]}\n', list(collection.items()), '') + '\n' with closing(StringIO(text)) as input_file: input_header.read(input_file) @@ -308,13 +302,10 @@ def test_input_header(self): self.assertEqual(sorted(input_header.keys()), sorted(collection.keys())) self.assertEqual(sorted(input_header.values()), sorted(collection.values())) - return - def test_messages_header(self): @Configuration() class TestMessagesHeaderCommand(SearchCommand): - class ConfigurationSettings(SearchCommand.ConfigurationSettings): @classmethod @@ -346,7 +337,6 @@ def fix_up(cls, command_class): pass '\r\n') self.assertEqual(output_buffer.getvalue().decode('utf-8'), expected) - return _package_path = os.path.dirname(__file__) diff --git a/tests/searchcommands/test_internals_v2.py b/tests/searchcommands/test_internals_v2.py index ec9b3f666..f54549778 100755 --- a/tests/searchcommands/test_internals_v2.py +++ b/tests/searchcommands/test_internals_v2.py @@ -15,46 +15,37 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals +import gzip +import io +import json +import os +import random +import sys -from splunklib.searchcommands.internals import MetadataDecoder, MetadataEncoder, Recorder, RecordWriterV2 -from splunklib.searchcommands import SearchMetric -from splunklib import six -from splunklib.six.moves import range -from collections import OrderedDict -from collections import namedtuple, deque -from splunklib.six import BytesIO as BytesIO +import pytest from functools import wraps -from glob import iglob from itertools import chain -from splunklib.six.moves import filter as ifilter -from splunklib.six.moves import map as imap -from splunklib.six.moves import zip as izip from sys import float_info from tempfile import mktemp from time import time from types import MethodType -from sys import version_info as python_version -try: - from unittest2 import main, TestCase -except ImportError: - from unittest import main, TestCase +from unittest import main, TestCase -import splunklib.six.moves.cPickle as pickle -import gzip -import io -import json -import os -import random +from collections import OrderedDict +from collections import namedtuple, deque + +from splunklib.searchcommands.internals import MetadataDecoder, MetadataEncoder, Recorder, RecordWriterV2 +from splunklib.searchcommands import SearchMetric +from io import BytesIO +import pickle -import pytest # region Functions for producing random apps # Confirmed: [minint, maxint) covers the full range of values that xrange allows -minint = (-six.MAXSIZE - 1) // 2 -maxint = six.MAXSIZE // 2 +minint = (-sys.maxsize - 1) // 2 +maxint = sys.maxsize // 2 max_length = 1 * 1024 @@ -90,7 +81,7 @@ def random_integer(): def random_integers(): - return random_list(six.moves.range, minint, maxint) + return random_list(range, minint, maxint) def random_list(population, *args): @@ -98,7 +89,7 @@ def random_list(population, *args): def random_unicode(): - return ''.join(imap(lambda x: six.unichr(x), random.sample(range(MAX_NARROW_UNICODE), random.randint(0, max_length)))) + return ''.join([str(x) for x in random.sample(list(range(MAX_NARROW_UNICODE)), random.randint(0, max_length))]) # endregion @@ -118,7 +109,6 @@ def test_object_view(self): json_output = encoder.encode(view) self.assertEqual(self._json_input, json_output) - return def test_record_writer_with_random_data(self, save_recording=False): @@ -138,7 +128,7 @@ def test_record_writer_with_random_data(self, save_recording=False): for serial_number in range(0, 31): values = [serial_number, time(), random_bytes(), random_dict(), random_integers(), random_unicode()] - record = OrderedDict(izip(fieldnames, values)) + record = OrderedDict(list(zip(fieldnames, values))) #try: write_record(record) #except Exception as error: @@ -167,7 +157,7 @@ def test_record_writer_with_random_data(self, save_recording=False): test_data['metrics'] = metrics - for name, metric in six.iteritems(metrics): + for name, metric in list(metrics.items()): writer.write_metric(name, metric) self.assertEqual(writer._chunk_count, 0) @@ -182,8 +172,8 @@ def test_record_writer_with_random_data(self, save_recording=False): self.assertListEqual(writer._inspector['messages'], messages) self.assertDictEqual( - dict(ifilter(lambda k_v: k_v[0].startswith('metric.'), six.iteritems(writer._inspector))), - dict(imap(lambda k_v1: ('metric.' + k_v1[0], k_v1[1]), six.iteritems(metrics)))) + dict(k_v for k_v in list(writer._inspector.items()) if k_v[0].startswith('metric.')), + dict(('metric.' + k_v1[0], k_v1[1]) for k_v1 in list(metrics.items()))) writer.flush(finished=True) @@ -213,18 +203,15 @@ def test_record_writer_with_random_data(self, save_recording=False): # P2 [ ] TODO: Verify that RecordWriter gives consumers the ability to finish early by calling # RecordWriter.flush(finish=True). - return - def _compare_chunks(self, chunks_1, chunks_2): self.assertEqual(len(chunks_1), len(chunks_2)) n = 0 - for chunk_1, chunk_2 in izip(chunks_1, chunks_2): + for chunk_1, chunk_2 in zip(chunks_1, chunks_2): self.assertDictEqual( chunk_1.metadata, chunk_2.metadata, 'Chunk {0}: metadata error: "{1}" != "{2}"'.format(n, chunk_1.metadata, chunk_2.metadata)) self.assertMultiLineEqual(chunk_1.body, chunk_2.body, 'Chunk {0}: data error'.format(n)) n += 1 - return def _load_chunks(self, ifile): import re @@ -276,11 +263,11 @@ def _load_chunks(self, ifile): 'n': 12 } - _json_input = six.text_type(json.dumps(_dictionary, separators=(',', ':'))) + _json_input = str(json.dumps(_dictionary, separators=(',', ':'))) _package_path = os.path.dirname(os.path.abspath(__file__)) -class TestRecorder(object): +class TestRecorder: def __init__(self, test_case): @@ -293,7 +280,6 @@ def _not_implemented(self): raise NotImplementedError('class {} is not in playback or record mode'.format(self.__class__.__name__)) self.get = self.next_part = self.stop = MethodType(_not_implemented, self, self.__class__) - return @property def output(self): @@ -322,7 +308,6 @@ def stop(self): self._test_case.assertEqual(test_data['results'], self._output.getvalue()) self.stop = MethodType(stop, self, self.__class__) - return def record(self, path): @@ -357,7 +342,6 @@ def stop(self): pickle.dump(test, f) self.stop = MethodType(stop, self, self.__class__) - return def recorded(method): @@ -369,12 +353,11 @@ def _record(*args, **kwargs): return _record -class Test(object): +class Test: def __init__(self, fieldnames, data_generators): TestCase.__init__(self) - self._data_generators = list(chain((lambda: self._serial_number, time), data_generators)) self._fieldnames = list(chain(('_serial', '_time'), fieldnames)) self._recorder = TestRecorder(self) @@ -418,15 +401,17 @@ def _run(self): names = self.fieldnames for self._serial_number in range(0, 31): - record = OrderedDict(izip(names, self.row)) + record = OrderedDict(list(zip(names, self.row))) write_record(record) - return - # test = Test(['random_bytes', 'random_unicode'], [random_bytes, random_unicode]) # test.record() # test.playback() +Test.__test__ = False +TestRecorder.__test__ = False + + if __name__ == "__main__": main() diff --git a/tests/searchcommands/test_multibyte_processing.py b/tests/searchcommands/test_multibyte_processing.py index 4d6127fe9..1d021eed7 100644 --- a/tests/searchcommands/test_multibyte_processing.py +++ b/tests/searchcommands/test_multibyte_processing.py @@ -4,7 +4,6 @@ from os import path -from splunklib import six from splunklib.searchcommands import StreamingCommand, Configuration @@ -25,15 +24,13 @@ def get_input_file(name): def test_multibyte_chunked(): data = gzip.open(get_input_file("multibyte_input")) - if not six.PY2: - data = io.TextIOWrapper(data) + data = io.TextIOWrapper(data) cmd = build_test_command() cmd._process_protocol_v2(sys.argv, data, sys.stdout) def test_v1_searchcommand(): data = gzip.open(get_input_file("v1_search_input")) - if not six.PY2: - data = io.TextIOWrapper(data) + data = io.TextIOWrapper(data) cmd = build_test_command() cmd._process_protocol_v1(["test_script.py", "__EXECUTE__"], data, sys.stdout) diff --git a/tests/searchcommands/test_reporting_command.py b/tests/searchcommands/test_reporting_command.py index e5add818c..2111447d5 100644 --- a/tests/searchcommands/test_reporting_command.py +++ b/tests/searchcommands/test_reporting_command.py @@ -1,6 +1,6 @@ import io -import splunklib.searchcommands as searchcommands +from splunklib import searchcommands from . import chunked_data_stream as chunky @@ -15,7 +15,7 @@ def reduce(self, records): cmd = TestReportingCommand() ifile = io.BytesIO() - data = list() + data = [] for i in range(0, 10): data.append({"value": str(i)}) ifile.write(chunky.build_getinfo_chunk()) diff --git a/tests/searchcommands/test_search_command.py b/tests/searchcommands/test_search_command.py index c5ce36066..0aadc8db3 100755 --- a/tests/searchcommands/test_search_command.py +++ b/tests/searchcommands/test_search_command.py @@ -15,16 +15,6 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals - -from splunklib import six -from splunklib.searchcommands import Configuration, StreamingCommand -from splunklib.searchcommands.decorators import ConfigurationSetting, Option -from splunklib.searchcommands.search_command import SearchCommand -from splunklib.client import Service - -from splunklib.six import StringIO, BytesIO -from splunklib.six.moves import zip as izip from json.encoder import encode_basestring as encode_string from unittest import main, TestCase @@ -37,30 +27,35 @@ import pytest +import splunklib +from splunklib.searchcommands import Configuration, StreamingCommand +from splunklib.searchcommands.decorators import ConfigurationSetting, Option +from splunklib.searchcommands.search_command import SearchCommand +from splunklib.client import Service + +from io import StringIO, BytesIO + + def build_command_input(getinfo_metadata, execute_metadata, execute_body): - input = ('chunked 1.0,{},0\n{}'.format(len(six.ensure_binary(getinfo_metadata)), getinfo_metadata) + - 'chunked 1.0,{},{}\n{}{}'.format(len(six.ensure_binary(execute_metadata)), len(six.ensure_binary(execute_body)), execute_metadata, execute_body)) + input = (f'chunked 1.0,{len(splunklib.ensure_binary(getinfo_metadata))},0\n{getinfo_metadata}' + + f'chunked 1.0,{len(splunklib.ensure_binary(execute_metadata))},{len(splunklib.ensure_binary(execute_body))}\n{execute_metadata}{execute_body}') - ifile = BytesIO(six.ensure_binary(input)) + ifile = BytesIO(splunklib.ensure_binary(input)) - if not six.PY2: - ifile = TextIOWrapper(ifile) + ifile = TextIOWrapper(ifile) return ifile + @Configuration() class TestCommand(SearchCommand): - required_option_1 = Option(require=True) required_option_2 = Option(require=True) def echo(self, records): for record in records: if record.get('action') == 'raise_exception': - if six.PY2: - raise StandardError(self) - else: - raise Exception(self) + raise Exception(self) yield record def _execute(self, ifile, process): @@ -108,7 +103,7 @@ def stream(self, records): value = self.search_results_info if action == 'get_search_results_info' else None yield {'_serial': serial_number, 'data': value} serial_number += 1 - return + @pytest.mark.smoke class TestSearchCommand(TestCase): @@ -123,37 +118,37 @@ def test_process_scpv2(self): metadata = ( '{{' - '"action": "getinfo", "preview": false, "searchinfo": {{' - '"latest_time": "0",' - '"splunk_version": "20150522",' - '"username": "admin",' - '"app": "searchcommands_app",' - '"args": [' - '"logging_configuration={logging_configuration}",' - '"logging_level={logging_level}",' - '"record={record}",' - '"show_configuration={show_configuration}",' - '"required_option_1=value_1",' - '"required_option_2=value_2"' - '],' - '"search": "A%7C%20inputlookup%20tweets%20%7C%20countmatches%20fieldname%3Dword_count%20pattern%3D%22%5Cw%2B%22%20text%20record%3Dt%20%7C%20export%20add_timestamp%3Df%20add_offset%3Dt%20format%3Dcsv%20segmentation%3Draw",' - '"earliest_time": "0",' - '"session_key": "0JbG1fJEvXrL6iYZw9y7tmvd6nHjTKj7ggaE7a4Jv5R0UIbeYJ65kThn^3hiNeoqzMT_LOtLpVR3Y8TIJyr5bkHUElMijYZ8l14wU0L4n^Oa5QxepsZNUIIQCBm^",' - '"owner": "admin",' - '"sid": "1433261372.158",' - '"splunkd_uri": "https://127.0.0.1:8089",' - '"dispatch_dir": {dispatch_dir},' - '"raw_args": [' - '"logging_configuration={logging_configuration}",' - '"logging_level={logging_level}",' - '"record={record}",' - '"show_configuration={show_configuration}",' - '"required_option_1=value_1",' - '"required_option_2=value_2"' - '],' - '"maxresultrows": 10,' - '"command": "countmatches"' - '}}' + '"action": "getinfo", "preview": false, "searchinfo": {{' + '"latest_time": "0",' + '"splunk_version": "20150522",' + '"username": "admin",' + '"app": "searchcommands_app",' + '"args": [' + '"logging_configuration={logging_configuration}",' + '"logging_level={logging_level}",' + '"record={record}",' + '"show_configuration={show_configuration}",' + '"required_option_1=value_1",' + '"required_option_2=value_2"' + '],' + '"search": "A%7C%20inputlookup%20tweets%20%7C%20countmatches%20fieldname%3Dword_count%20pattern%3D%22%5Cw%2B%22%20text%20record%3Dt%20%7C%20export%20add_timestamp%3Df%20add_offset%3Dt%20format%3Dcsv%20segmentation%3Draw",' + '"earliest_time": "0",' + '"session_key": "0JbG1fJEvXrL6iYZw9y7tmvd6nHjTKj7ggaE7a4Jv5R0UIbeYJ65kThn^3hiNeoqzMT_LOtLpVR3Y8TIJyr5bkHUElMijYZ8l14wU0L4n^Oa5QxepsZNUIIQCBm^",' + '"owner": "admin",' + '"sid": "1433261372.158",' + '"splunkd_uri": "https://127.0.0.1:8089",' + '"dispatch_dir": {dispatch_dir},' + '"raw_args": [' + '"logging_configuration={logging_configuration}",' + '"logging_level={logging_level}",' + '"record={record}",' + '"show_configuration={show_configuration}",' + '"required_option_1=value_1",' + '"required_option_2=value_2"' + '],' + '"maxresultrows": 10,' + '"command": "countmatches"' + '}}' '}}') basedir = self._package_directory @@ -198,7 +193,7 @@ def test_process_scpv2(self): expected = ( 'chunked 1.0,68,0\n' - '{"inspector":{"messages":[["INFO","test command configuration: "]]}}\n' + '{"inspector":{"messages":[["INFO","test command configuration: "]]}}' 'chunked 1.0,17,32\n' '{"finished":true}test,__mv_test\r\n' 'data,\r\n' @@ -235,14 +230,18 @@ def test_process_scpv2(self): self.assertEqual(command_metadata.preview, input_header['preview']) self.assertEqual(command_metadata.searchinfo.app, 'searchcommands_app') - self.assertEqual(command_metadata.searchinfo.args, ['logging_configuration=' + logging_configuration, 'logging_level=ERROR', 'record=false', 'show_configuration=true', 'required_option_1=value_1', 'required_option_2=value_2']) + self.assertEqual(command_metadata.searchinfo.args, + ['logging_configuration=' + logging_configuration, 'logging_level=ERROR', 'record=false', + 'show_configuration=true', 'required_option_1=value_1', 'required_option_2=value_2']) self.assertEqual(command_metadata.searchinfo.dispatch_dir, os.path.dirname(input_header['infoPath'])) self.assertEqual(command_metadata.searchinfo.earliest_time, 0.0) self.assertEqual(command_metadata.searchinfo.latest_time, 0.0) self.assertEqual(command_metadata.searchinfo.owner, 'admin') self.assertEqual(command_metadata.searchinfo.raw_args, command_metadata.searchinfo.args) - self.assertEqual(command_metadata.searchinfo.search, 'A| inputlookup tweets | countmatches fieldname=word_count pattern="\\w+" text record=t | export add_timestamp=f add_offset=t format=csv segmentation=raw') - self.assertEqual(command_metadata.searchinfo.session_key, '0JbG1fJEvXrL6iYZw9y7tmvd6nHjTKj7ggaE7a4Jv5R0UIbeYJ65kThn^3hiNeoqzMT_LOtLpVR3Y8TIJyr5bkHUElMijYZ8l14wU0L4n^Oa5QxepsZNUIIQCBm^') + self.assertEqual(command_metadata.searchinfo.search, + 'A| inputlookup tweets | countmatches fieldname=word_count pattern="\\w+" text record=t | export add_timestamp=f add_offset=t format=csv segmentation=raw') + self.assertEqual(command_metadata.searchinfo.session_key, + '0JbG1fJEvXrL6iYZw9y7tmvd6nHjTKj7ggaE7a4Jv5R0UIbeYJ65kThn^3hiNeoqzMT_LOtLpVR3Y8TIJyr5bkHUElMijYZ8l14wU0L4n^Oa5QxepsZNUIIQCBm^') self.assertEqual(command_metadata.searchinfo.sid, '1433261372.158') self.assertEqual(command_metadata.searchinfo.splunk_version, '20150522') self.assertEqual(command_metadata.searchinfo.splunkd_uri, 'https://127.0.0.1:8089') @@ -266,5 +265,8 @@ def test_process_scpv2(self): _package_directory = os.path.dirname(os.path.abspath(__file__)) +TestCommand.__test__ = False +TestStreamingCommand.__test__ = False + if __name__ == "__main__": main() diff --git a/tests/searchcommands/test_streaming_command.py b/tests/searchcommands/test_streaming_command.py index ffe6a7376..afb2e8caa 100644 --- a/tests/searchcommands/test_streaming_command.py +++ b/tests/searchcommands/test_streaming_command.py @@ -1,7 +1,7 @@ import io -from . import chunked_data_stream as chunky from splunklib.searchcommands import StreamingCommand, Configuration +from . import chunked_data_stream as chunky def test_simple_streaming_command(): @@ -16,7 +16,7 @@ def stream(self, records): cmd = TestStreamingCommand() ifile = io.BytesIO() ifile.write(chunky.build_getinfo_chunk()) - data = list() + data = [] for i in range(0, 10): data.append({"in_index": str(i)}) ifile.write(chunky.build_data_chunk(data, finished=True)) @@ -44,7 +44,7 @@ def stream(self, records): cmd = TestStreamingCommand() ifile = io.BytesIO() ifile.write(chunky.build_getinfo_chunk()) - data = list() + data = [] for i in range(0, 10): data.append({"in_index": str(i)}) ifile.write(chunky.build_data_chunk(data, finished=True)) @@ -53,14 +53,14 @@ def stream(self, records): cmd._process_protocol_v2([], ifile, ofile) ofile.seek(0) output_iter = chunky.ChunkedDataStream(ofile).__iter__() - output_iter.next() - output_records = [i for i in output_iter.next().data] + next(output_iter) + output_records = list(next(output_iter).data) # Assert that count of records having "odd_field" is 0 - assert len(list(filter(lambda r: "odd_field" in r, output_records))) == 0 + assert len(list(r for r in output_records if "odd_field" in r)) == 0 # Assert that count of records having "even_field" is 10 - assert len(list(filter(lambda r: "even_field" in r, output_records))) == 10 + assert len(list(r for r in output_records if "even_field" in r)) == 10 def test_field_preservation_positive(): @@ -78,7 +78,7 @@ def stream(self, records): cmd = TestStreamingCommand() ifile = io.BytesIO() ifile.write(chunky.build_getinfo_chunk()) - data = list() + data = [] for i in range(0, 10): data.append({"in_index": str(i)}) ifile.write(chunky.build_data_chunk(data, finished=True)) @@ -87,11 +87,11 @@ def stream(self, records): cmd._process_protocol_v2([], ifile, ofile) ofile.seek(0) output_iter = chunky.ChunkedDataStream(ofile).__iter__() - output_iter.next() - output_records = [i for i in output_iter.next().data] + next(output_iter) + output_records = list(next(output_iter).data) # Assert that count of records having "odd_field" is 10 - assert len(list(filter(lambda r: "odd_field" in r, output_records))) == 10 + assert len(list(r for r in output_records if "odd_field" in r)) == 10 # Assert that count of records having "even_field" is 10 - assert len(list(filter(lambda r: "even_field" in r, output_records))) == 10 + assert len(list(r for r in output_records if "even_field" in r)) == 10 diff --git a/tests/searchcommands/test_validators.py b/tests/searchcommands/test_validators.py index cc524b307..7b815491c 100755 --- a/tests/searchcommands/test_validators.py +++ b/tests/searchcommands/test_validators.py @@ -15,20 +15,15 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals - -from splunklib.searchcommands import validators from random import randint from unittest import main, TestCase import os -import re import sys import tempfile -from splunklib import six -from splunklib.six.moves import range - import pytest +from splunklib.searchcommands import validators + # P2 [ ] TODO: Verify that all format methods produce 'None' when value is None @@ -52,14 +47,12 @@ def test_boolean(self): for value in truth_values: for variant in value, value.capitalize(), value.upper(): - s = six.text_type(variant) + s = str(variant) self.assertEqual(validator.__call__(s), truth_values[value]) self.assertIsNone(validator.__call__(None)) self.assertRaises(ValueError, validator.__call__, 'anything-else') - return - def test_duration(self): # Duration validator should parse and format time intervals of the form @@ -68,7 +61,7 @@ def test_duration(self): validator = validators.Duration() for seconds in range(0, 25 * 60 * 60, 59): - value = six.text_type(seconds) + value = str(seconds) self.assertEqual(validator(value), seconds) self.assertEqual(validator(validator.format(seconds)), seconds) value = '%d:%02d' % (seconds / 60, seconds % 60) @@ -97,8 +90,6 @@ def test_duration(self): self.assertRaises(ValueError, validator, '00:00:60') self.assertRaises(ValueError, validator, '00:60:00') - return - def test_fieldname(self): pass @@ -140,8 +131,6 @@ def test_file(self): if os.path.exists(full_path): os.unlink(full_path) - return - def test_integer(self): # Point of interest: @@ -165,14 +154,10 @@ def test_integer(self): validator = validators.Integer() def test(integer): - for s in str(integer), six.text_type(integer): - value = validator.__call__(s) - self.assertEqual(value, integer) - if six.PY2: - self.assertIsInstance(value, long) - else: - self.assertIsInstance(value, int) - self.assertEqual(validator.format(integer), six.text_type(integer)) + value = validator.__call__(integer) + self.assertEqual(value, integer) + self.assertIsInstance(value, int) + self.assertEqual(validator.format(integer), str(integer)) test(2 * minsize) test(minsize) @@ -204,8 +189,6 @@ def test(integer): self.assertRaises(ValueError, validator.__call__, minsize - 1) self.assertRaises(ValueError, validator.__call__, maxsize + 1) - return - def test_float(self): # Float validator test @@ -219,11 +202,11 @@ def test(float_val): float_val = float(float_val) except ValueError: assert False - for s in str(float_val), six.text_type(float_val): - value = validator.__call__(s) - self.assertAlmostEqual(value, float_val) - self.assertIsInstance(value, float) - self.assertEqual(validator.format(float_val), six.text_type(float_val)) + + value = validator.__call__(float_val) + self.assertAlmostEqual(value, float_val) + self.assertIsInstance(value, float) + self.assertEqual(validator.format(float_val), str(float_val)) test(2 * minsize) test(minsize) @@ -261,8 +244,6 @@ def test(float_val): self.assertRaises(ValueError, validator.__call__, minsize - 1) self.assertRaises(ValueError, validator.__call__, maxsize + 1) - return - def test_list(self): validator = validators.List() diff --git a/tests/test_all.py b/tests/test_all.py index 7789f8fd9..e74217970 100755 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -16,15 +16,11 @@ """Runs all the Splunk SDK for Python unit tests.""" -from __future__ import absolute_import import os -try: - import unittest2 as unittest # We must be sure to get unittest2--not unittest--on Python 2.6 -except ImportError: - import unittest +import unittest os.chdir(os.path.dirname(os.path.abspath(__file__))) suite = unittest.defaultTestLoader.discover('.') if __name__ == '__main__': - unittest.TextTestRunner().run(suite) \ No newline at end of file + unittest.TextTestRunner().run(suite) diff --git a/tests/test_app.py b/tests/test_app.py index 3dbc4cffb..39b68a081 100755 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -14,11 +14,9 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import -from tests import testlib import logging - -import splunklib.client as client +from tests import testlib +from splunklib import client class TestApp(testlib.SDKTestCase): @@ -26,7 +24,7 @@ class TestApp(testlib.SDKTestCase): app_name = None def setUp(self): - super(TestApp, self).setUp() + super().setUp() if self.app is None: for app in self.service.apps: if app.name.startswith('delete-me'): @@ -37,18 +35,17 @@ def setUp(self): # than entities like indexes, this is okay. self.app_name = testlib.tmpname() self.app = self.service.apps.create(self.app_name) - logging.debug("Creating app %s", self.app_name) - else: - logging.debug("App %s already exists. Skipping creation.", self.app_name) + logging.debug(f"Creating app {self.app_name}") + logging.debug(f"App {self.app_name} already exists. Skipping creation.") if self.service.restart_required: self.service.restart(120) - return def tearDown(self): - super(TestApp, self).tearDown() + super().tearDown() # The rest of this will leave Splunk in a state requiring a restart. # It doesn't actually matter, though. self.service = client.connect(**self.opts.kwargs) + app_name = '' for app in self.service.apps: app_name = app.name if app_name.startswith('delete-me'): @@ -90,7 +87,7 @@ def test_delete(self): self.assertTrue(name in self.service.apps) self.service.apps.delete(name) self.assertFalse(name in self.service.apps) - self.clear_restart_message() # We don't actually have to restart here. + self.clear_restart_message() # We don't actually have to restart here. def test_package(self): p = self.app.package() @@ -103,9 +100,7 @@ def test_updateInfo(self): p = self.app.updateInfo() self.assertTrue(p is not None) + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest unittest.main() diff --git a/tests/test_binding.py b/tests/test_binding.py index aa3a13911..c54dc3c8e 100755 --- a/tests/test_binding.py +++ b/tests/test_binding.py @@ -14,15 +14,11 @@ # License for the specific language governing permissions and limitations # under the License. - -from __future__ import absolute_import -from io import BytesIO +from http import server as BaseHTTPServer +from io import BytesIO, StringIO from threading import Thread +from urllib.request import Request, urlopen -from splunklib.six.moves import BaseHTTPServer -from splunklib.six.moves.urllib.request import Request, urlopen -from splunklib.six.moves.urllib.error import HTTPError -import splunklib.six as six from xml.etree.ElementTree import XML import json @@ -30,14 +26,12 @@ from tests import testlib import unittest import socket -import sys import ssl -import splunklib.six.moves.http_cookies -import splunklib.binding as binding +import splunklib +from splunklib import binding from splunklib.binding import HTTPError, AuthenticationError, UrlEncoded -import splunklib.data as data -from splunklib import six +from splunklib import data import pytest @@ -65,14 +59,17 @@ def load(response): return data.load(response.body.read()) + class BindingTestCase(unittest.TestCase): context = None + def setUp(self): logging.info("%s", self.__class__.__name__) self.opts = testlib.parse([], {}, ".env") self.context = binding.connect(**self.opts.kwargs) logging.debug("Connected to splunkd.") + class TestResponseReader(BindingTestCase): def test_empty(self): response = binding.ResponseReader(BytesIO(b"")) @@ -106,7 +103,7 @@ def test_read_partial(self): def test_readable(self): txt = "abcd" - response = binding.ResponseReader(six.StringIO(txt)) + response = binding.ResponseReader(StringIO(txt)) self.assertTrue(response.readable()) def test_readinto_bytearray(self): @@ -124,9 +121,6 @@ def test_readinto_bytearray(self): self.assertTrue(response.empty) def test_readinto_memoryview(self): - import sys - if sys.version_info < (2, 7, 0): - return # memoryview is new to Python 2.7 txt = b"Checking readinto works as expected" response = binding.ResponseReader(BytesIO(txt)) arr = bytearray(10) @@ -142,7 +136,6 @@ def test_readinto_memoryview(self): self.assertTrue(response.empty) - class TestUrlEncoded(BindingTestCase): def test_idempotent(self): a = UrlEncoded('abc') @@ -173,6 +166,7 @@ def test_chars(self): def test_repr(self): self.assertEqual(repr(UrlEncoded('% %')), "UrlEncoded('% %')") + class TestAuthority(unittest.TestCase): def test_authority_default(self): self.assertEqual(binding._authority(), @@ -198,6 +192,7 @@ def test_all_fields(self): port="471"), "http://splunk.utopia.net:471") + class TestUserManipulation(BindingTestCase): def setUp(self): BindingTestCase.setUp(self) @@ -278,12 +273,12 @@ class TestSocket(BindingTestCase): def test_socket(self): socket = self.context.connect() socket.write(("POST %s HTTP/1.1\r\n" % \ - self.context._abspath("some/path/to/post/to")).encode('utf-8')) + self.context._abspath("some/path/to/post/to")).encode('utf-8')) socket.write(("Host: %s:%s\r\n" % \ - (self.context.host, self.context.port)).encode('utf-8')) + (self.context.host, self.context.port)).encode('utf-8')) socket.write("Accept-Encoding: identity\r\n".encode('utf-8')) socket.write(("Authorization: %s\r\n" % \ - self.context.token).encode('utf-8')) + self.context.token).encode('utf-8')) socket.write("X-Splunk-Input-Mode: Streaming\r\n".encode('utf-8')) socket.write("\r\n".encode('utf-8')) socket.close() @@ -308,15 +303,17 @@ def test_socket_gethostbyname(self): self.context.host = socket.gethostbyname(self.context.host) self.assertTrue(self.context.connect()) + class TestUnicodeConnect(BindingTestCase): def test_unicode_connect(self): opts = self.opts.kwargs.copy() - opts['host'] = six.text_type(opts['host']) + opts['host'] = str(opts['host']) context = binding.connect(**opts) # Just check to make sure the service is alive response = context.get("/services") self.assertEqual(response.status, 200) + @pytest.mark.smoke class TestAutologin(BindingTestCase): def test_with_autologin(self): @@ -332,6 +329,7 @@ def test_without_autologin(self): self.assertRaises(AuthenticationError, self.context.get, "/services") + class TestAbspath(BindingTestCase): def setUp(self): BindingTestCase.setUp(self) @@ -339,7 +337,6 @@ def setUp(self): if 'app' in self.kwargs: del self.kwargs['app'] if 'owner' in self.kwargs: del self.kwargs['owner'] - def test_default(self): path = self.context._abspath("foo", owner=None, app=None) self.assertTrue(isinstance(path, UrlEncoded)) @@ -371,12 +368,12 @@ def test_sharing_app(self): self.assertEqual(path, "/servicesNS/nobody/MyApp/foo") def test_sharing_global(self): - path = self.context._abspath("foo", owner="me", app="MyApp",sharing="global") + path = self.context._abspath("foo", owner="me", app="MyApp", sharing="global") self.assertTrue(isinstance(path, UrlEncoded)) self.assertEqual(path, "/servicesNS/nobody/MyApp/foo") def test_sharing_system(self): - path = self.context._abspath("foo bar", owner="me", app="MyApp",sharing="system") + path = self.context._abspath("foo bar", owner="me", app="MyApp", sharing="system") self.assertTrue(isinstance(path, UrlEncoded)) self.assertEqual(path, "/servicesNS/nobody/system/foo%20bar") @@ -444,6 +441,7 @@ def test_context_with_owner_as_email(self): self.assertEqual(path, "/servicesNS/me%40me.com/system/foo") self.assertEqual(path, UrlEncoded("/servicesNS/me@me.com/system/foo")) + # An urllib2 based HTTP request handler, used to test the binding layers # support for pluggable request handlers. def urllib2_handler(url, message, **kwargs): @@ -452,13 +450,9 @@ def urllib2_handler(url, message, **kwargs): headers = dict(message.get('headers', [])) req = Request(url, data, headers) try: - # If running Python 2.7.9+, disable SSL certificate validation - if sys.version_info >= (2, 7, 9): - response = urlopen(req, context=ssl._create_unverified_context()) - else: - response = urlopen(req) + response = urlopen(req, context=ssl._create_unverified_context()) except HTTPError as response: - pass # Propagate HTTP errors via the returned response message + pass # Propagate HTTP errors via the returned response message return { 'status': response.code, 'reason': response.msg, @@ -466,6 +460,7 @@ def urllib2_handler(url, message, **kwargs): 'body': BytesIO(response.read()) } + def isatom(body): """Answers if the given response body looks like ATOM.""" root = XML(body) @@ -475,6 +470,7 @@ def isatom(body): root.find(XNAME_ID) is not None and \ root.find(XNAME_TITLE) is not None + class TestPluggableHTTP(testlib.SDKTestCase): # Verify pluggable HTTP reqeust handlers. def test_handlers(self): @@ -491,27 +487,24 @@ def test_handlers(self): body = context.get(path).body.read() self.assertTrue(isatom(body)) + def urllib2_insert_cookie_handler(url, message, **kwargs): method = message['method'].lower() data = message.get('body', b"") if method == 'post' else None headers = dict(message.get('headers', [])) req = Request(url, data, headers) try: - # If running Python 2.7.9+, disable SSL certificate validation - if sys.version_info >= (2, 7, 9): - response = urlopen(req, context=ssl._create_unverified_context()) - else: - response = urlopen(req) + response = urlopen(req, context=ssl._create_unverified_context()) except HTTPError as response: - pass # Propagate HTTP errors via the returned response message + pass # Propagate HTTP errors via the returned response message # Mimic the insertion of 3rd party cookies into the response. # An example is "sticky session"/"insert cookie" persistence # of a load balancer for a SHC. - header_list = [(k, v) for k, v in response.info().items()] + header_list = [(k, v) for k, v in response.info().items()] header_list.append(("Set-Cookie", "BIGipServer_splunk-shc-8089=1234567890.12345.0000; path=/; Httponly; Secure")) header_list.append(("Set-Cookie", "home_made=yummy")) - + return { 'status': response.code, 'reason': response.msg, @@ -519,6 +512,7 @@ def urllib2_insert_cookie_handler(url, message, **kwargs): 'body': BytesIO(response.read()) } + class TestCookiePersistence(testlib.SDKTestCase): # Verify persistence of 3rd party inserted cookies. def test_3rdPartyInsertedCookiePersistence(self): @@ -541,6 +535,7 @@ def test_3rdPartyInsertedCookiePersistence(self): self.assertEqual(persisted_cookies['BIGipServer_splunk-shc-8089'], "1234567890.12345.0000") self.assertEqual(persisted_cookies['home_made'], "yummy") + @pytest.mark.smoke class TestLogout(BindingTestCase): def test_logout(self): @@ -566,7 +561,7 @@ def setUp(self): self.context = binding.connect(**self.opts.kwargs) # Skip these tests if running below Splunk 6.2, cookie-auth didn't exist before - import splunklib.client as client + from splunklib import client service = client.Service(**self.opts.kwargs) # TODO: Workaround the fact that skipTest is not defined by unittest2.TestCase service.login() @@ -653,14 +648,14 @@ def test_login_with_multiple_cookies(self): except AuthenticationError as ae: self.assertEqual(str(ae), "Login failed.") # Bring in a valid cookie now - for key, value in self.context.get_cookies().items(): + for key, value in list(self.context.get_cookies().items()): new_context.get_cookies()[key] = value self.assertEqual(len(new_context.get_cookies()), 2) self.assertTrue('bad' in list(new_context.get_cookies().keys())) self.assertTrue('cookie' in list(new_context.get_cookies().values())) - for k, v in self.context.get_cookies().items(): + for k, v in list(self.context.get_cookies().items()): self.assertEqual(new_context.get_cookies()[k], v) self.assertEqual(new_context.get("apps/local").status, 200) @@ -681,80 +676,81 @@ def test_login_fails_without_cookie_or_token(self): class TestNamespace(unittest.TestCase): def test_namespace(self): tests = [ - ({ }, - { 'sharing': None, 'owner': None, 'app': None }), + ({}, + {'sharing': None, 'owner': None, 'app': None}), - ({ 'owner': "Bob" }, - { 'sharing': None, 'owner': "Bob", 'app': None }), + ({'owner': "Bob"}, + {'sharing': None, 'owner': "Bob", 'app': None}), - ({ 'app': "search" }, - { 'sharing': None, 'owner': None, 'app': "search" }), + ({'app': "search"}, + {'sharing': None, 'owner': None, 'app': "search"}), - ({ 'owner': "Bob", 'app': "search" }, - { 'sharing': None, 'owner': "Bob", 'app': "search" }), + ({'owner': "Bob", 'app': "search"}, + {'sharing': None, 'owner': "Bob", 'app': "search"}), - ({ 'sharing': "user", 'owner': "Bob@bob.com" }, - { 'sharing': "user", 'owner': "Bob@bob.com", 'app': None }), + ({'sharing': "user", 'owner': "Bob@bob.com"}, + {'sharing': "user", 'owner': "Bob@bob.com", 'app': None}), - ({ 'sharing': "user" }, - { 'sharing': "user", 'owner': None, 'app': None }), + ({'sharing': "user"}, + {'sharing': "user", 'owner': None, 'app': None}), - ({ 'sharing': "user", 'owner': "Bob" }, - { 'sharing': "user", 'owner': "Bob", 'app': None }), + ({'sharing': "user", 'owner': "Bob"}, + {'sharing': "user", 'owner': "Bob", 'app': None}), - ({ 'sharing': "user", 'app': "search" }, - { 'sharing': "user", 'owner': None, 'app': "search" }), + ({'sharing': "user", 'app': "search"}, + {'sharing': "user", 'owner': None, 'app': "search"}), - ({ 'sharing': "user", 'owner': "Bob", 'app': "search" }, - { 'sharing': "user", 'owner': "Bob", 'app': "search" }), + ({'sharing': "user", 'owner': "Bob", 'app': "search"}, + {'sharing': "user", 'owner': "Bob", 'app': "search"}), - ({ 'sharing': "app" }, - { 'sharing': "app", 'owner': "nobody", 'app': None }), + ({'sharing': "app"}, + {'sharing': "app", 'owner': "nobody", 'app': None}), - ({ 'sharing': "app", 'owner': "Bob" }, - { 'sharing': "app", 'owner': "nobody", 'app': None }), + ({'sharing': "app", 'owner': "Bob"}, + {'sharing': "app", 'owner': "nobody", 'app': None}), - ({ 'sharing': "app", 'app': "search" }, - { 'sharing': "app", 'owner': "nobody", 'app': "search" }), + ({'sharing': "app", 'app': "search"}, + {'sharing': "app", 'owner': "nobody", 'app': "search"}), - ({ 'sharing': "app", 'owner': "Bob", 'app': "search" }, - { 'sharing': "app", 'owner': "nobody", 'app': "search" }), + ({'sharing': "app", 'owner': "Bob", 'app': "search"}, + {'sharing': "app", 'owner': "nobody", 'app': "search"}), - ({ 'sharing': "global" }, - { 'sharing': "global", 'owner': "nobody", 'app': None }), + ({'sharing': "global"}, + {'sharing': "global", 'owner': "nobody", 'app': None}), - ({ 'sharing': "global", 'owner': "Bob" }, - { 'sharing': "global", 'owner': "nobody", 'app': None }), + ({'sharing': "global", 'owner': "Bob"}, + {'sharing': "global", 'owner': "nobody", 'app': None}), - ({ 'sharing': "global", 'app': "search" }, - { 'sharing': "global", 'owner': "nobody", 'app': "search" }), + ({'sharing': "global", 'app': "search"}, + {'sharing': "global", 'owner': "nobody", 'app': "search"}), - ({ 'sharing': "global", 'owner': "Bob", 'app': "search" }, - { 'sharing': "global", 'owner': "nobody", 'app': "search" }), + ({'sharing': "global", 'owner': "Bob", 'app': "search"}, + {'sharing': "global", 'owner': "nobody", 'app': "search"}), - ({ 'sharing': "system" }, - { 'sharing': "system", 'owner': "nobody", 'app': "system" }), + ({'sharing': "system"}, + {'sharing': "system", 'owner': "nobody", 'app': "system"}), - ({ 'sharing': "system", 'owner': "Bob" }, - { 'sharing': "system", 'owner': "nobody", 'app': "system" }), + ({'sharing': "system", 'owner': "Bob"}, + {'sharing': "system", 'owner': "nobody", 'app': "system"}), - ({ 'sharing': "system", 'app': "search" }, - { 'sharing': "system", 'owner': "nobody", 'app': "system" }), + ({'sharing': "system", 'app': "search"}, + {'sharing': "system", 'owner': "nobody", 'app': "system"}), - ({ 'sharing': "system", 'owner': "Bob", 'app': "search" }, - { 'sharing': "system", 'owner': "nobody", 'app': "system" }), + ({'sharing': "system", 'owner': "Bob", 'app': "search"}, + {'sharing': "system", 'owner': "nobody", 'app': "system"}), - ({ 'sharing': 'user', 'owner': '-', 'app': '-'}, - { 'sharing': 'user', 'owner': '-', 'app': '-'})] + ({'sharing': 'user', 'owner': '-', 'app': '-'}, + {'sharing': 'user', 'owner': '-', 'app': '-'})] for kwargs, expected in tests: namespace = binding.namespace(**kwargs) - for k, v in six.iteritems(expected): + for k, v in list(expected.items()): self.assertEqual(namespace[k], v) def test_namespace_fails(self): self.assertRaises(ValueError, binding.namespace, sharing="gobble") + @pytest.mark.smoke class TestBasicAuthentication(unittest.TestCase): def setUp(self): @@ -765,13 +761,13 @@ def setUp(self): opts["password"] = self.opts.kwargs["password"] self.context = binding.connect(**opts) - import splunklib.client as client + from splunklib import client service = client.Service(**opts) if getattr(unittest.TestCase, 'assertIsNotNone', None) is None: def assertIsNotNone(self, obj, msg=None): - if obj is None: - raise self.failureException(msg or '%r is not None' % obj) + if obj is None: + raise self.failureException(msg or '%r is not None' % obj) def test_basic_in_auth_headers(self): self.assertIsNotNone(self.context._auth_headers) @@ -782,6 +778,7 @@ def test_basic_in_auth_headers(self): self.assertEqual(self.context._auth_headers[0][1][:6], "Basic ") self.assertEqual(self.context.get("/services").status, 200) + @pytest.mark.smoke class TestTokenAuthentication(BindingTestCase): def test_preexisting_token(self): @@ -797,14 +794,14 @@ def test_preexisting_token(self): socket = newContext.connect() socket.write(("POST %s HTTP/1.1\r\n" % \ - self.context._abspath("some/path/to/post/to")).encode('utf-8')) + self.context._abspath("some/path/to/post/to")).encode('utf-8')) socket.write(("Host: %s:%s\r\n" % \ - (self.context.host, self.context.port)).encode('utf-8')) - socket.write(("Accept-Encoding: identity\r\n").encode('utf-8')) + (self.context.host, self.context.port)).encode('utf-8')) + socket.write("Accept-Encoding: identity\r\n".encode('utf-8')) socket.write(("Authorization: %s\r\n" % \ - self.context.token).encode('utf-8')) + self.context.token).encode('utf-8')) socket.write("X-Splunk-Input-Mode: Streaming\r\n".encode('utf-8')) - socket.write(("\r\n").encode('utf-8')) + socket.write("\r\n".encode('utf-8')) socket.close() def test_preexisting_token_sans_splunk(self): @@ -824,18 +821,17 @@ def test_preexisting_token_sans_splunk(self): self.assertEqual(response.status, 200) socket = newContext.connect() - socket.write(("POST %s HTTP/1.1\r\n" %\ - self.context._abspath("some/path/to/post/to")).encode('utf-8')) - socket.write(("Host: %s:%s\r\n" %\ - (self.context.host, self.context.port)).encode('utf-8')) + socket.write(("POST %s HTTP/1.1\r\n" % \ + self.context._abspath("some/path/to/post/to")).encode('utf-8')) + socket.write(("Host: %s:%s\r\n" % \ + (self.context.host, self.context.port)).encode('utf-8')) socket.write("Accept-Encoding: identity\r\n".encode('utf-8')) - socket.write(("Authorization: %s\r\n" %\ - self.context.token).encode('utf-8')) - socket.write(("X-Splunk-Input-Mode: Streaming\r\n").encode('utf-8')) - socket.write(("\r\n").encode('utf-8')) + socket.write(("Authorization: %s\r\n" % \ + self.context.token).encode('utf-8')) + socket.write("X-Splunk-Input-Mode: Streaming\r\n".encode('utf-8')) + socket.write("\r\n".encode('utf-8')) socket.close() - def test_connect_with_preexisting_token_sans_user_and_pass(self): token = self.context.token opts = self.opts.kwargs.copy() @@ -849,12 +845,12 @@ def test_connect_with_preexisting_token_sans_user_and_pass(self): socket = newContext.connect() socket.write(("POST %s HTTP/1.1\r\n" % \ - self.context._abspath("some/path/to/post/to")).encode('utf-8')) + self.context._abspath("some/path/to/post/to")).encode('utf-8')) socket.write(("Host: %s:%s\r\n" % \ - (self.context.host, self.context.port)).encode('utf-8')) + (self.context.host, self.context.port)).encode('utf-8')) socket.write("Accept-Encoding: identity\r\n".encode('utf-8')) socket.write(("Authorization: %s\r\n" % \ - self.context.token).encode('utf-8')) + self.context.token).encode('utf-8')) socket.write("X-Splunk-Input-Mode: Streaming\r\n".encode('utf-8')) socket.write("\r\n".encode('utf-8')) socket.close() @@ -864,34 +860,37 @@ class TestPostWithBodyParam(unittest.TestCase): def test_post(self): def handler(url, message, **kwargs): - assert six.ensure_str(url) == "https://localhost:8089/servicesNS/testowner/testapp/foo/bar" - assert six.ensure_str(message["body"]) == "testkey=testvalue" + assert url == "https://localhost:8089/servicesNS/testowner/testapp/foo/bar" + assert message["body"] == b"testkey=testvalue" return splunklib.data.Record({ "status": 200, "headers": [], }) + ctx = binding.Context(handler=handler) ctx.post("foo/bar", owner="testowner", app="testapp", body={"testkey": "testvalue"}) def test_post_with_params_and_body(self): def handler(url, message, **kwargs): assert url == "https://localhost:8089/servicesNS/testowner/testapp/foo/bar?extrakey=extraval" - assert six.ensure_str(message["body"]) == "testkey=testvalue" + assert message["body"] == b"testkey=testvalue" return splunklib.data.Record({ "status": 200, "headers": [], }) + ctx = binding.Context(handler=handler) ctx.post("foo/bar", extrakey="extraval", owner="testowner", app="testapp", body={"testkey": "testvalue"}) def test_post_with_params_and_no_body(self): def handler(url, message, **kwargs): assert url == "https://localhost:8089/servicesNS/testowner/testapp/foo/bar" - assert six.ensure_str(message["body"]) == "extrakey=extraval" + assert message["body"] == b"extrakey=extraval" return splunklib.data.Record({ "status": 200, "headers": [], }) + ctx = binding.Context(handler=handler) ctx.post("foo/bar", extrakey="extraval", owner="testowner", app="testapp") @@ -903,12 +902,13 @@ def wrapped(handler_self): handler_self.send_response(response_code) handler_self.end_headers() handler_self.wfile.write(body) + return wrapped -class MockServer(object): +class MockServer: def __init__(self, port=9093, **handlers): - methods = {"do_" + k: _wrap_handler(v) for (k, v) in handlers.items()} + methods = {"do_" + k: _wrap_handler(v) for (k, v) in list(handlers.items())} def init(handler_self, socket, address, server): BaseHTTPServer.BaseHTTPRequestHandler.__init__(handler_self, socket, address, server) @@ -925,6 +925,7 @@ def log(*args): # To silence server access logs def run(): self._svr.handle_request() + self._thread = Thread(target=run) self._thread.daemon = True @@ -943,7 +944,7 @@ def test_post_with_body_urlencoded(self): def check_response(handler): length = int(handler.headers.get('content-length', 0)) body = handler.rfile.read(length) - assert six.ensure_str(body) == "foo=bar" + assert body.decode('utf-8') == "foo=bar" with MockServer(POST=check_response): ctx = binding.connect(port=9093, scheme='http', token="waffle") @@ -953,19 +954,20 @@ def test_post_with_body_string(self): def check_response(handler): length = int(handler.headers.get('content-length', 0)) body = handler.rfile.read(length) - assert six.ensure_str(handler.headers['content-type']) == 'application/json' + assert handler.headers['content-type'] == 'application/json' assert json.loads(body)["baz"] == "baf" with MockServer(POST=check_response): - ctx = binding.connect(port=9093, scheme='http', token="waffle", headers=[("Content-Type", "application/json")]) + ctx = binding.connect(port=9093, scheme='http', token="waffle", + headers=[("Content-Type", "application/json")]) ctx.post("/", foo="bar", body='{"baz": "baf"}') def test_post_with_body_dict(self): def check_response(handler): length = int(handler.headers.get('content-length', 0)) body = handler.rfile.read(length) - assert six.ensure_str(handler.headers['content-type']) == 'application/x-www-form-urlencoded' - assert six.ensure_str(body) == 'baz=baf&hep=cat' or six.ensure_str(body) == 'hep=cat&baz=baf' + assert handler.headers['content-type'] == 'application/x-www-form-urlencoded' + assert body.decode('utf-8') == 'baz=baf&hep=cat' or body.decode('utf-8') == 'hep=cat&baz=baf' with MockServer(POST=check_response): ctx = binding.connect(port=9093, scheme='http', token="waffle") @@ -973,8 +975,4 @@ def check_response(handler): if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest unittest.main() diff --git a/tests/test_collection.py b/tests/test_collection.py index 0fd9a1c33..bf74e30cc 100755 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -14,14 +14,12 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from tests import testlib import logging from contextlib import contextmanager -import splunklib.client as client -from splunklib.six.moves import range +from splunklib import client collections = [ 'apps', @@ -41,9 +39,9 @@ class CollectionTestCase(testlib.SDKTestCase): def setUp(self): - super(CollectionTestCase, self).setUp() + super().setUp() if self.service.splunk_version[0] >= 5 and 'modular_input_kinds' not in collections: - collections.append('modular_input_kinds') # Not supported before Splunk 5.0 + collections.append('modular_input_kinds') # Not supported before Splunk 5.0 else: logging.info("Skipping modular_input_kinds; not supported by Splunk %s" % \ '.'.join(str(x) for x in self.service.splunk_version)) @@ -69,59 +67,51 @@ def test_metadata(self): found_fields_keys = set(metadata.fields.keys()) self.assertTrue(found_access_keys >= expected_access_keys, msg='metadata.access is missing keys on ' + \ - '%s (found: %s, expected: %s)' % \ - (coll, found_access_keys, - expected_access_keys)) + f'{coll} (found: {found_access_keys}, expected: {expected_access_keys})') self.assertTrue(found_fields_keys >= expected_fields_keys, msg='metadata.fields is missing keys on ' + \ - '%s (found: %s, expected: %s)' % \ - (coll, found_fields_keys, - expected_fields_keys)) + f'{coll} (found: {found_fields_keys}, expected: {expected_fields_keys})') def test_list(self): for coll_name in collections: coll = getattr(self.service, coll_name) expected = [ent.name for ent in coll.list(count=10, sort_mode="auto")] if len(expected) == 0: - logging.debug("No entities in collection %s; skipping test.", coll_name) + logging.debug(f"No entities in collection {coll_name}; skipping test.") found = [ent.name for ent in coll.list()][:10] self.assertEqual(expected, found, - msg='on %s (expected: %s, found: %s)' % \ - (coll_name, expected, found)) + msg=f'on {coll_name} (expected {expected}, found {found})') def test_list_with_count(self): N = 5 for coll_name in collections: coll = getattr(self.service, coll_name) - expected = [ent.name for ent in coll.list(count=N+5)][:N] - N = len(expected) # in case there are v2") self.assertEqual(result, - {'e1': {'a1': 'v1', 'e2': {'$text': 'v2', 'a1': 'v1'}}}) + {'e1': {'a1': 'v1', 'e2': {'$text': 'v2', 'a1': 'v1'}}}) def test_real(self): """Test some real Splunk response examples.""" @@ -120,12 +119,8 @@ def test_invalid(self): if sys.version_info[1] >= 7: self.assertRaises(et.ParseError, data.load, "") else: - if six.PY2: - from xml.parsers.expat import ExpatError - self.assertRaises(ExpatError, data.load, "") - else: - from xml.etree.ElementTree import ParseError - self.assertRaises(ParseError, data.load, "") + from xml.etree.ElementTree import ParseError + self.assertRaises(ParseError, data.load, "") self.assertRaises(KeyError, data.load, "a") @@ -166,8 +161,8 @@ def test_dict(self): """) - self.assertEqual(result, - {'content': {'n1': {'n1n1': "n1v1"}, 'n2': {'n2n1': "n2v1"}}}) + self.assertEqual(result, + {'content': {'n1': {'n1n1': "n1v1"}, 'n2': {'n2n1': "n2v1"}}}) result = data.load(""" @@ -179,8 +174,8 @@ def test_dict(self): """) - self.assertEqual(result, - {'content': {'n1': ['1', '2', '3', '4']}}) + self.assertEqual(result, + {'content': {'n1': ['1', '2', '3', '4']}}) def test_list(self): result = data.load("""""") @@ -222,8 +217,8 @@ def test_list(self): v4 """) - self.assertEqual(result, - {'content': [{'n1':"v1"}, {'n2':"v2"}, {'n3':"v3"}, {'n4':"v4"}]}) + self.assertEqual(result, + {'content': [{'n1': "v1"}, {'n2': "v2"}, {'n3': "v3"}, {'n4': "v4"}]}) result = data.load(""" @@ -233,7 +228,7 @@ def test_list(self): """) self.assertEqual(result, - {'build': '101089', 'cpu_arch': 'i386', 'isFree': '0'}) + {'build': '101089', 'cpu_arch': 'i386', 'isFree': '0'}) def test_record(self): d = data.record() @@ -244,17 +239,14 @@ def test_record(self): 'bar.zrp.peem': 9}) self.assertEqual(d['foo'], 5) self.assertEqual(d['bar.baz'], 6) - self.assertEqual(d['bar'], {'baz': 6, 'qux': 7, 'zrp': {'meep': 8, 'peem':9}}) + self.assertEqual(d['bar'], {'baz': 6, 'qux': 7, 'zrp': {'meep': 8, 'peem': 9}}) self.assertEqual(d.foo, 5) self.assertEqual(d.bar.baz, 6) - self.assertEqual(d.bar, {'baz': 6, 'qux': 7, 'zrp': {'meep': 8, 'peem':9}}) + self.assertEqual(d.bar, {'baz': 6, 'qux': 7, 'zrp': {'meep': 8, 'peem': 9}}) self.assertRaises(KeyError, d.__getitem__, 'boris') if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest - unittest.main() + import unittest + unittest.main() diff --git a/tests/test_event_type.py b/tests/test_event_type.py index 5ae2c7ecd..9e4959771 100755 --- a/tests/test_event_type.py +++ b/tests/test_event_type.py @@ -14,17 +14,15 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from tests import testlib -import logging -import splunklib.client as client class TestRead(testlib.SDKTestCase): def test_read(self): for event_type in self.service.event_types.list(count=1): self.check_entity(event_type) + class TestCreate(testlib.SDKTestCase): def test_create(self): self.event_type_name = testlib.tmpname() @@ -42,22 +40,23 @@ def test_create(self): self.assertEqual(self.event_type_name, event_type.name) def tearDown(self): - super(TestCreate, self).setUp() + super().setUp() try: self.service.event_types.delete(self.event_type_name) except KeyError: pass + class TestEventType(testlib.SDKTestCase): def setUp(self): - super(TestEventType, self).setUp() + super().setUp() self.event_type_name = testlib.tmpname() self.event_type = self.service.event_types.create( self.event_type_name, search="index=_internal *") def tearDown(self): - super(TestEventType, self).setUp() + super().setUp() try: self.service.event_types.delete(self.event_type_name) except KeyError: @@ -69,16 +68,13 @@ def test_delete(self): self.assertFalse(self.event_type_name in self.service.event_types) def test_update(self): - kwargs = {} - kwargs['search'] = "index=_audit *" - kwargs['description'] = "An audit event" - kwargs['priority'] = '3' + kwargs = {'search': "index=_audit *", 'description': "An audit event", 'priority': '3'} self.event_type.update(**kwargs) self.event_type.refresh() self.assertEqual(self.event_type['search'], kwargs['search']) self.assertEqual(self.event_type['description'], kwargs['description']) self.assertEqual(self.event_type['priority'], kwargs['priority']) - + def test_enable_disable(self): self.assertEqual(self.event_type['disabled'], '0') self.event_type.disable() @@ -88,9 +84,8 @@ def test_enable_disable(self): self.event_type.refresh() self.assertEqual(self.event_type['disabled'], '0') + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_fired_alert.py b/tests/test_fired_alert.py index 2480d4150..fb185dbec 100755 --- a/tests/test_fired_alert.py +++ b/tests/test_fired_alert.py @@ -14,22 +14,19 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from tests import testlib -import logging -import splunklib.client as client class FiredAlertTestCase(testlib.SDKTestCase): def setUp(self): - super(FiredAlertTestCase, self).setUp() + super().setUp() self.index_name = testlib.tmpname() self.assertFalse(self.index_name in self.service.indexes) self.index = self.service.indexes.create(self.index_name) saved_searches = self.service.saved_searches self.saved_search_name = testlib.tmpname() self.assertFalse(self.saved_search_name in saved_searches) - query = "search index=%s" % self.index_name + query = f"search index={self.index_name}" kwargs = {'alert_type': 'always', 'alert.severity': "3", 'alert.suppress': "0", @@ -43,7 +40,7 @@ def setUp(self): query, **kwargs) def tearDown(self): - super(FiredAlertTestCase, self).tearDown() + super().tearDown() if self.service.splunk_version >= (5,): self.service.indexes.delete(self.index_name) for saved_search in self.service.saved_searches: @@ -57,7 +54,7 @@ def test_new_search_is_empty(self): self.assertEqual(len(self.saved_search.history()), 0) self.assertEqual(len(self.saved_search.fired_alerts), 0) self.assertFalse(self.saved_search_name in self.service.fired_alerts) - + def test_alerts_on_events(self): self.assertEqual(self.saved_search.alert_count, 0) self.assertEqual(len(self.saved_search.fired_alerts), 0) @@ -71,14 +68,17 @@ def test_alerts_on_events(self): self.index.refresh() self.index.submit('This is a test ' + testlib.tmpname(), sourcetype='sdk_use', host='boris') + def f(): self.index.refresh() - return int(self.index['totalEventCount']) == eventCount+1 + return int(self.index['totalEventCount']) == eventCount + 1 + self.assertEventuallyTrue(f, timeout=50) def g(): self.saved_search.refresh() return self.saved_search.alert_count == 1 + self.assertEventuallyTrue(g, timeout=200) alerts = self.saved_search.fired_alerts @@ -90,9 +90,8 @@ def test_read(self): for alert in alert_group.alerts: alert.content + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_index.py b/tests/test_index.py index 9e2a53298..fb876496b 100755 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -14,30 +14,23 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import -from __future__ import print_function -from tests import testlib import logging -import os import time -import splunklib.client as client -try: - import unittest -except ImportError: - import unittest2 as unittest - import pytest +from tests import testlib +from splunklib import client + class IndexTest(testlib.SDKTestCase): def setUp(self): - super(IndexTest, self).setUp() + super().setUp() self.index_name = testlib.tmpname() self.index = self.service.indexes.create(self.index_name) self.assertEventuallyTrue(lambda: self.index.refresh()['disabled'] == '0') def tearDown(self): - super(IndexTest, self).tearDown() + super().tearDown() # We can't delete an index with the REST API before Splunk # 5.0. In 4.x, we just have to leave them lying around until # someone cares to go clean them up. Unique naming prevents @@ -92,14 +85,14 @@ def test_disable_enable(self): # self.assertEqual(self.index['totalEventCount'], '0') def test_prefresh(self): - self.assertEqual(self.index['disabled'], '0') # Index is prefreshed + self.assertEqual(self.index['disabled'], '0') # Index is prefreshed def test_submit(self): event_count = int(self.index['totalEventCount']) self.assertEqual(self.index['sync'], '0') self.assertEqual(self.index['disabled'], '0') self.index.submit("Hello again!", sourcetype="Boris", host="meep") - self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count+1, timeout=50) + self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count + 1, timeout=50) def test_submit_namespaced(self): s = client.connect(**{ @@ -114,14 +107,14 @@ def test_submit_namespaced(self): self.assertEqual(i['sync'], '0') self.assertEqual(i['disabled'], '0') i.submit("Hello again namespaced!", sourcetype="Boris", host="meep") - self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count+1, timeout=50) + self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count + 1, timeout=50) def test_submit_via_attach(self): event_count = int(self.index['totalEventCount']) cn = self.index.attach() cn.send(b"Hello Boris!\r\n") cn.close() - self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count+1, timeout=60) + self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count + 1, timeout=60) def test_submit_via_attach_using_token_header(self): # Remove the prefix from the token @@ -133,14 +126,14 @@ def test_submit_via_attach_using_token_header(self): cn = i.attach() cn.send(b"Hello Boris 5!\r\n") cn.close() - self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count+1, timeout=60) + self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count + 1, timeout=60) def test_submit_via_attached_socket(self): event_count = int(self.index['totalEventCount']) f = self.index.attached_socket with f() as sock: sock.send(b'Hello world!\r\n') - self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count+1, timeout=60) + self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count + 1, timeout=60) def test_submit_via_attach_with_cookie_header(self): # Skip this test if running below Splunk 6.2, cookie-auth didn't exist before @@ -156,7 +149,7 @@ def test_submit_via_attach_with_cookie_header(self): cn = service.indexes[self.index_name].attach() cn.send(b"Hello Boris!\r\n") cn.close() - self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count+1, timeout=60) + self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count + 1, timeout=60) def test_submit_via_attach_with_multiple_cookie_headers(self): # Skip this test if running below Splunk 6.2, cookie-auth didn't exist before @@ -171,7 +164,7 @@ def test_submit_via_attach_with_multiple_cookie_headers(self): cn = service.indexes[self.index_name].attach() cn.send(b"Hello Boris!\r\n") cn.close() - self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count+1, timeout=60) + self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count + 1, timeout=60) @pytest.mark.app def test_upload(self): @@ -181,11 +174,10 @@ def test_upload(self): path = self.pathInApp("file_to_upload", ["log.txt"]) self.index.upload(path) - self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count+4, timeout=60) + self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count + 4, timeout=60) + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_input.py b/tests/test_input.py index c7d48dc38..26943cd99 100755 --- a/tests/test_input.py +++ b/tests/test_input.py @@ -13,22 +13,12 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import -from __future__ import print_function - +import logging +import pytest from splunklib.binding import HTTPError from tests import testlib -import logging -from splunklib import six -try: - import unittest -except ImportError: - import unittest2 as unittest - -import splunklib.client as client - -import pytest +from splunklib import client def highest_port(service, base_port, *kinds): @@ -42,7 +32,7 @@ def highest_port(service, base_port, *kinds): class TestTcpInputNameHandling(testlib.SDKTestCase): def setUp(self): - super(TestTcpInputNameHandling, self).setUp() + super().setUp() self.base_port = highest_port(self.service, 10000, 'tcp', 'splunktcp', 'udp') + 1 def tearDown(self): @@ -50,11 +40,11 @@ def tearDown(self): port = int(input.name.split(':')[-1]) if port >= self.base_port: input.delete() - super(TestTcpInputNameHandling, self).tearDown() + super().tearDown() def create_tcp_input(self, base_port, kind, **options): port = base_port - while True: # Find the next unbound port + while True: # Find the next unbound port try: input = self.service.inputs.create(str(port), kind, **options) return input @@ -75,7 +65,7 @@ def test_cannot_create_with_restrictToHost_in_name(self): ) def test_create_tcp_ports_with_restrictToHost(self): - for kind in ['tcp', 'splunktcp']: # Multiplexed UDP ports are not supported + for kind in ['tcp', 'splunktcp']: # Multiplexed UDP ports are not supported # Make sure we can create two restricted inputs on the same port boris = self.service.inputs.create(str(self.base_port), kind, restrictToHost='boris') natasha = self.service.inputs.create(str(self.base_port), kind, restrictToHost='natasha') @@ -110,7 +100,7 @@ def test_unrestricted_to_restricted_collision(self): unrestricted.delete() def test_update_restrictToHost_fails(self): - for kind in ['tcp', 'splunktcp']: # No UDP, since it's broken in Splunk + for kind in ['tcp', 'splunktcp']: # No UDP, since it's broken in Splunk boris = self.create_tcp_input(self.base_port, kind, restrictToHost='boris') self.assertRaises( @@ -149,7 +139,6 @@ def test_read_invalid_input(self): self.assertTrue("HTTP 404 Not Found" in str(he)) def test_inputs_list_on_one_kind_with_count(self): - N = 10 expected = [x.name for x in self.service.inputs.list('monitor')[:10]] found = [x.name for x in self.service.inputs.list('monitor', count=10)] self.assertEqual(expected, found) @@ -181,21 +170,22 @@ def test_oneshot(self): def f(): index.refresh() - return int(index['totalEventCount']) == eventCount+4 + return int(index['totalEventCount']) == eventCount + 4 + self.assertEventuallyTrue(f, timeout=60) def test_oneshot_on_nonexistant_file(self): name = testlib.tmpname() self.assertRaises(HTTPError, - self.service.inputs.oneshot, name) + self.service.inputs.oneshot, name) class TestInput(testlib.SDKTestCase): def setUp(self): - super(TestInput, self).setUp() + super().setUp() inputs = self.service.inputs - unrestricted_port = str(highest_port(self.service, 10000, 'tcp', 'splunktcp', 'udp')+1) - restricted_port = str(highest_port(self.service, int(unrestricted_port)+1, 'tcp', 'splunktcp')+1) + unrestricted_port = str(highest_port(self.service, 10000, 'tcp', 'splunktcp', 'udp') + 1) + restricted_port = str(highest_port(self.service, int(unrestricted_port) + 1, 'tcp', 'splunktcp') + 1) test_inputs = [{'kind': 'tcp', 'name': unrestricted_port, 'host': 'sdk-test'}, {'kind': 'udp', 'name': unrestricted_port, 'host': 'sdk-test'}, {'kind': 'tcp', 'name': 'boris:' + restricted_port, 'host': 'sdk-test'}] @@ -209,8 +199,8 @@ def setUp(self): inputs.create(restricted_port, 'tcp', restrictToHost='boris') def tearDown(self): - super(TestInput, self).tearDown() - for entity in six.itervalues(self._test_entities): + super().tearDown() + for entity in list(self._test_entities.values()): try: self.service.inputs.delete( kind=entity.kind, @@ -233,7 +223,7 @@ def test_lists_modular_inputs(self): self.uncheckedRestartSplunk() inputs = self.service.inputs - if ('abcd','test2') not in inputs: + if ('abcd', 'test2') not in inputs: inputs.create('abcd', 'test2', field1='boris') input = inputs['abcd', 'test2'] @@ -241,7 +231,7 @@ def test_lists_modular_inputs(self): def test_create(self): inputs = self.service.inputs - for entity in six.itervalues(self._test_entities): + for entity in list(self._test_entities.values()): self.check_entity(entity) self.assertTrue(isinstance(entity, client.Input)) @@ -252,7 +242,7 @@ def test_get_kind_list(self): def test_read(self): inputs = self.service.inputs - for this_entity in six.itervalues(self._test_entities): + for this_entity in list(self._test_entities.values()): kind, name = this_entity.kind, this_entity.name read_entity = inputs[name, kind] self.assertEqual(this_entity.kind, read_entity.kind) @@ -268,7 +258,7 @@ def test_read_indiviually(self): def test_update(self): inputs = self.service.inputs - for entity in six.itervalues(self._test_entities): + for entity in list(self._test_entities.values()): kind, name = entity.kind, entity.name kwargs = {'host': 'foo'} entity.update(**kwargs) @@ -278,19 +268,19 @@ def test_update(self): @pytest.mark.skip('flaky') def test_delete(self): inputs = self.service.inputs - remaining = len(self._test_entities)-1 - for input_entity in six.itervalues(self._test_entities): + remaining = len(self._test_entities) - 1 + for input_entity in list(self._test_entities.values()): name = input_entity.name kind = input_entity.kind self.assertTrue(name in inputs) - self.assertTrue((name,kind) in inputs) + self.assertTrue((name, kind) in inputs) if remaining == 0: inputs.delete(name) self.assertFalse(name in inputs) else: if not name.startswith('boris'): self.assertRaises(client.AmbiguousReferenceException, - inputs.delete, name) + inputs.delete, name) self.service.inputs.delete(name, kind) self.assertFalse((name, kind) in inputs) self.assertRaises(client.HTTPError, @@ -299,8 +289,6 @@ def test_delete(self): if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_job.py b/tests/test_job.py index 18f3189a9..bab74f652 100755 --- a/tests/test_job.py +++ b/tests/test_job.py @@ -14,9 +14,6 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import -from __future__ import print_function - from io import BytesIO from time import sleep @@ -24,13 +21,10 @@ from tests import testlib -try: - import unittest2 as unittest -except ImportError: - import unittest +import unittest -import splunklib.client as client -import splunklib.results as results +from splunklib import client +from splunklib import results from splunklib.binding import _log_duration, HTTPError @@ -84,21 +78,21 @@ def test_export(self): self.assertTrue(len(nonmessages) <= 3) def test_export_docstring_sample(self): - import splunklib.client as client - import splunklib.results as results + from splunklib import client + from splunklib import results service = self.service # cheat rr = results.JSONResultsReader(service.jobs.export("search * | head 5", output_mode='json')) for result in rr: if isinstance(result, results.Message): # Diagnostic messages may be returned in the results - pass #print '%s: %s' % (result.type, result.message) + pass #print(f'{result.type}: {result.message}') elif isinstance(result, dict): # Normal events are returned as dicts - pass #print result + pass #print(result) assert rr.is_preview == False def test_results_docstring_sample(self): - import splunklib.results as results + from splunklib import results service = self.service # cheat job = service.jobs.create("search * | head 5") while not job.is_done(): @@ -107,42 +101,42 @@ def test_results_docstring_sample(self): for result in rr: if isinstance(result, results.Message): # Diagnostic messages may be returned in the results - pass #print '%s: %s' % (result.type, result.message) + pass #print(f'{result.type}: {result.message}') elif isinstance(result, dict): # Normal events are returned as dicts - pass #print result + pass #print(result) assert rr.is_preview == False def test_preview_docstring_sample(self): - import splunklib.client as client - import splunklib.results as results + from splunklib import client + from splunklib import results service = self.service # cheat job = service.jobs.create("search * | head 5") rr = results.JSONResultsReader(job.preview(output_mode='json')) for result in rr: if isinstance(result, results.Message): # Diagnostic messages may be returned in the results - pass #print '%s: %s' % (result.type, result.message) + pass #print(f'{result.type}: {result.message}') elif isinstance(result, dict): # Normal events are returned as dicts - pass #print result + pass #print(result) if rr.is_preview: - pass #print "Preview of a running search job." + pass #print("Preview of a running search job.") else: - pass #print "Job is finished. Results are final." + pass #print("Job is finished. Results are final.") def test_oneshot_docstring_sample(self): - import splunklib.client as client - import splunklib.results as results + from splunklib import client + from splunklib import results service = self.service # cheat rr = results.JSONResultsReader(service.jobs.oneshot("search * | head 5", output_mode='json')) for result in rr: if isinstance(result, results.Message): # Diagnostic messages may be returned in the results - pass #print '%s: %s' % (result.type, result.message) + pass #print(f'{result.type}: {result.message}') elif isinstance(result, dict): # Normal events are returned as dicts - pass #print result + pass #print(result) assert rr.is_preview == False def test_normal_job_with_garbage_fails(self): @@ -188,7 +182,6 @@ def check_job(self, job): 'statusBuckets', 'ttl'] for key in keys: self.assertTrue(key in job.content) - return def test_read_jobs(self): jobs = self.service.jobs @@ -212,11 +205,11 @@ def test_get_job(self): class TestJobWithDelayedDone(testlib.SDKTestCase): def setUp(self): - super(TestJobWithDelayedDone, self).setUp() + super().setUp() self.job = None def tearDown(self): - super(TestJobWithDelayedDone, self).tearDown() + super().tearDown() if self.job is not None: self.job.cancel() self.assertEventuallyTrue(lambda: self.job.sid not in self.service.jobs) @@ -243,7 +236,6 @@ def is_preview_enabled(): return self.job.content['isPreviewEnabled'] == '1' self.assertEventuallyTrue(is_preview_enabled) - return @pytest.mark.app def test_setpriority(self): @@ -279,12 +271,11 @@ def f(): return int(self.job.content['priority']) == new_priority self.assertEventuallyTrue(f, timeout=sleep_duration + 5) - return class TestJob(testlib.SDKTestCase): def setUp(self): - super(TestJob, self).setUp() + super().setUp() self.query = "search index=_internal | head 3" self.job = self.service.jobs.create( query=self.query, @@ -292,7 +283,7 @@ def setUp(self): latest_time="now") def tearDown(self): - super(TestJob, self).tearDown() + super().tearDown() self.job.cancel() @_log_duration diff --git a/tests/test_kvstore_batch.py b/tests/test_kvstore_batch.py index d32b665e6..9c2f3afe1 100755 --- a/tests/test_kvstore_batch.py +++ b/tests/test_kvstore_batch.py @@ -14,22 +14,17 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from tests import testlib -from splunklib.six.moves import range -try: - import unittest -except ImportError: - import unittest2 as unittest -import splunklib.client as client + +from splunklib import client + class KVStoreBatchTestCase(testlib.SDKTestCase): def setUp(self): - super(KVStoreBatchTestCase, self).setUp() - #self.service.namespace['owner'] = 'nobody' + super().setUp() self.service.namespace['app'] = 'search' confs = self.service.kvstore - if ('test' in confs): + if 'test' in confs: confs['test'].delete() confs.create('test') @@ -69,15 +64,13 @@ def test_insert_find_update_data(self): self.assertEqual(testData[x][0]['data'], '#' + str(x + 1)) self.assertEqual(testData[x][0]['num'], x + 1) - def tearDown(self): confs = self.service.kvstore if ('test' in confs): confs['test'].delete() + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_kvstore_conf.py b/tests/test_kvstore_conf.py index a24537288..beca1f69c 100755 --- a/tests/test_kvstore_conf.py +++ b/tests/test_kvstore_conf.py @@ -14,18 +14,12 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from tests import testlib -try: - import unittest -except ImportError: - import unittest2 as unittest -import splunklib.client as client +from splunklib import client class KVStoreConfTestCase(testlib.SDKTestCase): def setUp(self): - super(KVStoreConfTestCase, self).setUp() - #self.service.namespace['owner'] = 'nobody' + super().setUp() self.service.namespace['app'] = 'search' self.confs = self.service.kvstore if ('test' in self.confs): @@ -40,7 +34,7 @@ def test_create_delete_collection(self): self.confs.create('test') self.assertTrue('test' in self.confs) self.confs['test'].delete() - self.assertTrue(not 'test' in self.confs) + self.assertTrue('test' not in self.confs) def test_update_collection(self): self.confs.create('test') @@ -93,8 +87,5 @@ def tearDown(self): self.confs['test'].delete() if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest unittest.main() diff --git a/tests/test_kvstore_data.py b/tests/test_kvstore_data.py index 6ddeae688..5627921f0 100755 --- a/tests/test_kvstore_data.py +++ b/tests/test_kvstore_data.py @@ -14,20 +14,15 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import import json from tests import testlib -from splunklib.six.moves import range -try: - import unittest -except ImportError: - import unittest2 as unittest -import splunklib.client as client + +from splunklib import client + class KVStoreDataTestCase(testlib.SDKTestCase): def setUp(self): - super(KVStoreDataTestCase, self).setUp() - #self.service.namespace['owner'] = 'nobody' + super().setUp() self.service.namespace['app'] = 'search' self.confs = self.service.kvstore if ('test' in self.confs): @@ -74,7 +69,6 @@ def test_query_data(self): data = self.col.query(limit=2, skip=9) self.assertEqual(len(data), 1) - def test_invalid_insert_update(self): self.assertRaises(client.HTTPError, lambda: self.col.insert('NOT VALID DATA')) id = self.col.insert(json.dumps({'foo': 'bar'}))['_key'] @@ -89,16 +83,15 @@ def test_params_data_type_conversion(self): self.assertEqual(len(data), 20) for x in range(20): self.assertEqual(data[x]['data'], 39 - x) - self.assertTrue(not 'ignore' in data[x]) - self.assertTrue(not '_key' in data[x]) + self.assertTrue('ignore' not in data[x]) + self.assertTrue('_key' not in data[x]) def tearDown(self): if ('test' in self.confs): self.confs['test'].delete() + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_logger.py b/tests/test_logger.py index 7e9f5c8e0..0541d79ab 100755 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -14,13 +14,13 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from tests import testlib -import splunklib.client as client +from splunklib import client LEVELS = ["INFO", "WARN", "ERROR", "DEBUG", "CRIT"] + class LoggerTestCase(testlib.SDKTestCase): def check_logger(self, logger): self.check_entity(logger) @@ -44,9 +44,8 @@ def test_crud(self): logger.refresh() self.assertEqual(self.service.loggers['AuditLogger']['level'], saved) + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_message.py b/tests/test_message.py index cd76c783b..0c94402e5 100755 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -14,10 +14,10 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from tests import testlib -import splunklib.client as client +from splunklib import client + class MessageTest(testlib.SDKTestCase): def setUp(self): @@ -31,6 +31,7 @@ def tearDown(self): testlib.SDKTestCase.tearDown(self) self.service.messages.delete(self.message_name) + class TestCreateDelete(testlib.SDKTestCase): def test_create_delete(self): message_name = testlib.tmpname() @@ -46,11 +47,10 @@ def test_create_delete(self): def test_invalid_name(self): self.assertRaises(client.InvalidNameException, self.service.messages.create, None, value="What?") self.assertRaises(client.InvalidNameException, self.service.messages.create, 42, value="Who, me?") - self.assertRaises(client.InvalidNameException, self.service.messages.create, [1,2,3], value="Who, me?") + self.assertRaises(client.InvalidNameException, self.service.messages.create, [1, 2, 3], value="Who, me?") + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_modular_input.py b/tests/test_modular_input.py index ae6e797db..688b26b6b 100755 --- a/tests/test_modular_input.py +++ b/tests/test_modular_input.py @@ -14,12 +14,6 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import -from __future__ import print_function -try: - import unittest2 as unittest -except ImportError: - import unittest from tests import testlib import pytest @@ -27,7 +21,7 @@ @pytest.mark.smoke class ModularInputKindTestCase(testlib.SDKTestCase): def setUp(self): - super(ModularInputKindTestCase, self).setUp() + super().setUp() self.uncheckedRestartSplunk() @pytest.mark.app @@ -38,7 +32,7 @@ def test_lists_modular_inputs(self): self.uncheckedRestartSplunk() inputs = self.service.inputs - if ('abcd','test2') not in inputs: + if ('abcd', 'test2') not in inputs: inputs.create('abcd', 'test2', field1='boris') input = inputs['abcd', 'test2'] @@ -55,5 +49,8 @@ def check_modular_input_kind(self, m): self.assertEqual('test2', m['title']) self.assertEqual('simple', m['streaming_mode']) + if __name__ == "__main__": + import unittest + unittest.main() diff --git a/tests/test_modular_input_kinds.py b/tests/test_modular_input_kinds.py index c6b7391ea..303804754 100755 --- a/tests/test_modular_input_kinds.py +++ b/tests/test_modular_input_kinds.py @@ -14,20 +14,16 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import -from __future__ import print_function from tests import testlib -try: - import unittest -except ImportError: - import unittest2 as unittest -import splunklib.client as client + +from splunklib import client import pytest + class ModularInputKindTestCase(testlib.SDKTestCase): def setUp(self): - super(ModularInputKindTestCase, self).setUp() + super().setUp() self.uncheckedRestartSplunk() @pytest.mark.app @@ -40,9 +36,9 @@ def test_list_arguments(self): test1 = self.service.modular_input_kinds['test1'] - expected_args = set(["name", "resname", "key_id", "no_description", "empty_description", - "arg_required_on_edit", "not_required_on_edit", "required_on_create", - "not_required_on_create", "number_field", "string_field", "boolean_field"]) + expected_args = {"name", "resname", "key_id", "no_description", "empty_description", "arg_required_on_edit", + "not_required_on_edit", "required_on_create", "not_required_on_create", "number_field", + "string_field", "boolean_field"} found_args = set(test1.arguments.keys()) self.assertEqual(expected_args, found_args) @@ -77,9 +73,8 @@ def test_list_modular_inputs(self): for m in self.service.modular_input_kinds: self.check_modular_input_kind(m) + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_results.py b/tests/test_results.py index 5fdca2b91..035951245 100755 --- a/tests/test_results.py +++ b/tests/test_results.py @@ -14,14 +14,11 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import - from io import BytesIO -from splunklib.six import StringIO from tests import testlib from time import sleep -import splunklib.results as results +from splunklib import results import io @@ -164,9 +161,8 @@ def assert_parsed_results_equals(self, xml_text, expected_results): actual_results = [x for x in results_reader] self.assertEqual(expected_results, actual_results) + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_role.py b/tests/test_role.py index 16205d057..ca9f50090 100755 --- a/tests/test_role.py +++ b/tests/test_role.py @@ -14,20 +14,20 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from tests import testlib import logging -import splunklib.client as client +from splunklib import client + class RoleTestCase(testlib.SDKTestCase): def setUp(self): - super(RoleTestCase, self).setUp() + super().setUp() self.role_name = testlib.tmpname() self.role = self.service.roles.create(self.role_name) def tearDown(self): - super(RoleTestCase, self).tearDown() + super().tearDown() for role in self.service.roles: if role.name.startswith('delete-me'): self.service.roles.delete(role.name) @@ -91,7 +91,6 @@ def test_invalid_revoke(self): def test_revoke_capability_not_granted(self): self.role.revoke('change_own_password') - def test_update(self): kwargs = {} if 'user' in self.role['imported_roles']: @@ -105,9 +104,8 @@ def test_update(self): self.assertEqual(self.role['imported_roles'], kwargs['imported_roles']) self.assertEqual(int(self.role['srchJobsQuota']), kwargs['srchJobsQuota']) + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_saved_search.py b/tests/test_saved_search.py index d1f8f57c2..8d559bc5d 100755 --- a/tests/test_saved_search.py +++ b/tests/test_saved_search.py @@ -14,22 +14,21 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import import datetime from tests import testlib import logging from time import sleep -import splunklib.client as client -from splunklib.six.moves import zip +from splunklib import client import pytest + @pytest.mark.smoke class TestSavedSearch(testlib.SDKTestCase): def setUp(self): - super(TestSavedSearch, self).setUp() + super().setUp() saved_searches = self.service.saved_searches logging.debug("Saved searches namespace: %s", saved_searches.service.namespace) self.saved_search_name = testlib.tmpname() @@ -37,7 +36,7 @@ def setUp(self): self.saved_search = saved_searches.create(self.saved_search_name, query) def tearDown(self): - super(TestSavedSearch, self).setUp() + super().setUp() for saved_search in self.service.saved_searches: if saved_search.name.startswith('delete-me'): try: @@ -91,7 +90,6 @@ def test_delete(self): self.assertRaises(client.HTTPError, self.saved_search.refresh) - def test_update(self): is_visible = testlib.to_bool(self.saved_search['is_visible']) self.saved_search.update(is_visible=not is_visible) @@ -148,7 +146,7 @@ def test_dispatch(self): def test_dispatch_with_options(self): try: - kwargs = { 'dispatch.buckets': 100 } + kwargs = {'dispatch.buckets': 100} job = self.saved_search.dispatch(**kwargs) while not job.is_ready(): sleep(0.1) @@ -165,7 +163,7 @@ def test_history(self): while not job.is_ready(): sleep(0.1) history = self.saved_search.history() - self.assertEqual(len(history), N+1) + self.assertEqual(len(history), N + 1) self.assertTrue(job.sid in [j.sid for j in history]) finally: job.cancel() @@ -199,13 +197,8 @@ def test_scheduled_times(self): for x in scheduled_times])) time_pairs = list(zip(scheduled_times[:-1], scheduled_times[1:])) for earlier, later in time_pairs: - diff = later-earlier - # diff is an instance of datetime.timedelta, which - # didn't get a total_seconds() method until Python 2.7. - # Since we support Python 2.6, we have to calculate the - # total seconds ourselves. - total_seconds = diff.days*24*60*60 + diff.seconds - self.assertEqual(total_seconds/60.0, 5) + diff = later - earlier + self.assertEqual(diff.total_seconds() / 60.0, 5) def test_no_equality(self): self.assertRaises(client.IncomparableException, @@ -214,7 +207,7 @@ def test_no_equality(self): def test_suppress(self): suppressed_time = self.saved_search['suppressed'] self.assertGreaterEqual(suppressed_time, 0) - new_suppressed_time = suppressed_time+100 + new_suppressed_time = suppressed_time + 100 self.saved_search.suppress(new_suppressed_time) self.assertLessEqual(self.saved_search['suppressed'], new_suppressed_time) @@ -248,8 +241,6 @@ def test_acl_fails_without_owner(self): ) if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_service.py b/tests/test_service.py index 34afef2c8..fb6e7730e 100755 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -14,12 +14,11 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import -from tests import testlib import unittest +from tests import testlib -import splunklib.client as client -from splunklib.client import AuthenticationError +from splunklib import client +from splunklib.binding import AuthenticationError from splunklib.client import Service from splunklib.binding import HTTPError @@ -57,7 +56,7 @@ def test_info_with_namespace(self): try: self.assertEqual(self.service.info.licenseState, 'OK') except HTTPError as he: - self.fail("Couldn't get the server info, probably got a 403! %s" % he.message) + self.fail(f"Couldn't get the server info, probably got a 403! {he.message}") self.service.namespace["owner"] = owner self.service.namespace["app"] = app @@ -187,7 +186,7 @@ def setUp(self): def assertIsNotNone(self, obj, msg=None): if obj is None: - raise self.failureException(msg or '%r is not None' % obj) + raise self.failureException(msg or f'{obj} is not None') def test_login_and_store_cookie(self): self.assertIsNotNone(self.service.get_cookies()) @@ -364,8 +363,5 @@ def test_proper_namespace_with_service_namespace(self): if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + unittest.main() diff --git a/tests/test_storage_passwords.py b/tests/test_storage_passwords.py index 4f2fee81f..578b4fb02 100644 --- a/tests/test_storage_passwords.py +++ b/tests/test_storage_passwords.py @@ -14,11 +14,9 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from tests import testlib -import logging -import splunklib.client as client +from splunklib import client class Tests(testlib.SDKTestCase): @@ -153,11 +151,10 @@ def test_read(self): p = self.storage_passwords.create("changeme", username) self.assertEqual(start_count + 1, len(self.storage_passwords)) - for sp in self.storage_passwords: - self.assertTrue(p.name in self.storage_passwords) - # Name works with or without a trailing colon - self.assertTrue((":" + username + ":") in self.storage_passwords) - self.assertTrue((":" + username) in self.storage_passwords) + self.assertTrue(p.name in self.storage_passwords) + # Name works with or without a trailing colon + self.assertTrue((":" + username + ":") in self.storage_passwords) + self.assertTrue((":" + username) in self.storage_passwords) p.delete() self.assertEqual(start_count, len(self.storage_passwords)) @@ -233,9 +230,8 @@ def test_spaces_in_username(self): p.delete() self.assertEqual(start_count, len(self.storage_passwords)) + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_user.py b/tests/test_user.py index b8a97f811..389588141 100755 --- a/tests/test_user.py +++ b/tests/test_user.py @@ -14,20 +14,19 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from tests import testlib -import logging -import splunklib.client as client +from splunklib import client + class UserTestCase(testlib.SDKTestCase): def check_user(self, user): self.check_entity(user) # Verify expected fields exist [user[f] for f in ['email', 'password', 'realname', 'roles']] - + def setUp(self): - super(UserTestCase, self).setUp() + super().setUp() self.username = testlib.tmpname() self.user = self.service.users.create( self.username, @@ -35,7 +34,7 @@ def setUp(self): roles=['power', 'user']) def tearDown(self): - super(UserTestCase, self).tearDown() + super().tearDown() for user in self.service.users: if user.name.startswith('delete-me'): self.service.users.delete(user.name) @@ -84,9 +83,8 @@ def test_delete_is_case_insensitive(self): self.assertFalse(self.username in users) self.assertFalse(self.username.upper() in users) + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_utils.py b/tests/test_utils.py index 5eedbaba3..922d380f8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,15 +1,8 @@ -from __future__ import absolute_import -from tests import testlib - import unittest import os import sys - -try: - from utils import * -except ImportError: - raise Exception("Add the SDK repository to your PYTHONPATH to run the test cases " - "(e.g., export PYTHONPATH=~/splunk-sdk-python.") +from tests import testlib +from utils import dslice TEST_DICT = { 'username': 'admin', @@ -22,7 +15,7 @@ class TestUtils(testlib.SDKTestCase): def setUp(self): - super(TestUtils, self).setUp() + super().setUp() # Test dslice when a dict is passed to change key names def test_dslice_dict_args(self): @@ -96,10 +89,7 @@ def checkFilePermissions(dir_path): path = os.path.join(dir_path, file) if os.path.isfile(path): permission = oct(os.stat(path).st_mode) - if sys.version_info >= (3, 0): - self.assertEqual(permission, '0o100644') - else : - self.assertEqual(permission, '0100644') + self.assertEqual(permission, '0o100644') else: checkFilePermissions(path) @@ -108,8 +98,6 @@ def checkFilePermissions(dir_path): if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/testlib.py b/tests/testlib.py index 4a99e026a..ac8a3e1ef 100644 --- a/tests/testlib.py +++ b/tests/testlib.py @@ -15,35 +15,26 @@ # under the License. """Shared unit test utilities.""" -from __future__ import absolute_import -from __future__ import print_function import contextlib +import os +import time +import logging import sys -from splunklib import six # Run the test suite on the SDK without installing it. sys.path.insert(0, '../') -import splunklib.client as client from time import sleep from datetime import datetime, timedelta -try: - import unittest2 as unittest -except ImportError: - import unittest +import unittest -try: - from utils import parse -except ImportError: - raise Exception("Add the SDK repository to your PYTHONPATH to run the test cases " - "(e.g., export PYTHONPATH=~/splunk-sdk-python.") +from utils import parse + +from splunklib import client -import os -import time -import logging logging.basicConfig( filename='test.log', @@ -62,10 +53,9 @@ class WaitTimedOutError(Exception): def to_bool(x): if x == '1': return True - elif x == '0': + if x == '0': return False - else: - raise ValueError("Not a boolean value: %s", x) + raise ValueError("Not a boolean value: %s", x) def tmpname(): @@ -102,7 +92,7 @@ def assertEventuallyTrue(self, predicate, timeout=30, pause_time=0.5, logging.debug("wait finished after %s seconds", datetime.now() - start) def check_content(self, entity, **kwargs): - for k, v in six.iteritems(kwargs): + for k, v in list(kwargs): self.assertEqual(entity[k], str(v)) def check_entity(self, entity): @@ -154,9 +144,7 @@ def clear_restart_message(self): try: self.service.delete("messages/restart_required") except client.HTTPError as he: - if he.status == 404: - pass - else: + if he.status != 404: raise @contextlib.contextmanager @@ -179,7 +167,7 @@ def install_app_from_collection(self, name): self.service.post("apps/local", **kwargs) except client.HTTPError as he: if he.status == 400: - raise IOError("App %s not found in app collection" % name) + raise IOError(f"App {name} not found in app collection") if self.service.restart_required: self.service.restart(120) self.installedApps.append(name) @@ -268,6 +256,6 @@ def tearDown(self): except HTTPError as error: if not (os.name == 'nt' and error.status == 500): raise - print('Ignoring failure to delete {0} during tear down: {1}'.format(appName, error)) + print(f'Ignoring failure to delete {appName} during tear down: {error}') if self.service.restart_required: self.clear_restart_message() diff --git a/tox.ini b/tox.ini index 8b8bcb1b5..b5bcf34cb 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = clean,docs,py27,py37 +envlist = clean,docs,py37,py39 skipsdist = {env:TOXBUILD:false} [testenv:pep8] @@ -29,7 +29,6 @@ allowlist_externals = make deps = pytest pytest-cov xmlrunner - unittest2 unittest-xml-reporting python-dotenv deprecation diff --git a/utils/__init__.py b/utils/__init__.py index bd0900c3d..3a2b48de5 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -14,14 +14,14 @@ """Utility module shared by the SDK unit tests.""" -from __future__ import absolute_import from utils.cmdopts import * -from splunklib import six + def config(option, opt, value, parser): assert opt == "--config" parser.load(value) + # Default Splunk cmdline rules RULES_SPLUNK = { 'config': { @@ -30,7 +30,7 @@ def config(option, opt, value, parser): 'callback': config, 'type': "string", 'nargs': "1", - 'help': "Load options from config file" + 'help': "Load options from config file" }, 'scheme': { 'flags': ["--scheme"], @@ -40,30 +40,30 @@ def config(option, opt, value, parser): 'host': { 'flags': ["--host"], 'default': "localhost", - 'help': "Host name (default 'localhost')" + 'help': "Host name (default 'localhost')" }, - 'port': { + 'port': { 'flags': ["--port"], 'default': "8089", - 'help': "Port number (default 8089)" + 'help': "Port number (default 8089)" }, 'app': { - 'flags': ["--app"], + 'flags': ["--app"], 'help': "The app context (optional)" }, 'owner': { - 'flags': ["--owner"], + 'flags': ["--owner"], 'help': "The user context (optional)" }, 'username': { 'flags': ["--username"], 'default': None, - 'help': "Username to login with" + 'help': "Username to login with" }, 'password': { - 'flags': ["--password"], + 'flags': ["--password"], 'default': None, - 'help': "Password to login with" + 'help': "Password to login with" }, 'version': { 'flags': ["--version"], @@ -84,28 +84,30 @@ def config(option, opt, value, parser): FLAGS_SPLUNK = list(RULES_SPLUNK.keys()) + # value: dict, args: [(dict | list | str)*] def dslice(value, *args): """Returns a 'slice' of the given dictionary value containing only the requested keys. The keys can be requested in a variety of ways, as an arg list of keys, as a list of keys, or as a dict whose key(s) represent - the source keys and whose corresponding values represent the resulting - key(s) (enabling key rename), or any combination of the above.""" + the source keys and whose corresponding values represent the resulting + key(s) (enabling key rename), or any combination of the above.""" result = {} for arg in args: if isinstance(arg, dict): - for k, v in six.iteritems(arg): - if k in value: + for k, v in (list(arg.items())): + if k in value: result[v] = value[k] elif isinstance(arg, list): for k in arg: - if k in value: + if k in value: result[k] = value[k] else: - if arg in value: + if arg in value: result[arg] = value[arg] return result + def parse(argv, rules=None, config=None, **kwargs): """Parse the given arg vector with the default Splunk command rules.""" parser_ = parser(rules, **kwargs) @@ -113,8 +115,8 @@ def parse(argv, rules=None, config=None, **kwargs): parser_.loadenv(config) return parser_.parse(argv).result + def parser(rules=None, **kwargs): """Instantiate a parser with the default Splunk command rules.""" rules = RULES_SPLUNK if rules is None else dict(RULES_SPLUNK, **rules) return Parser(rules, **kwargs) - diff --git a/utils/cmdopts.py b/utils/cmdopts.py index b0cbb7328..e9fffb3b8 100644 --- a/utils/cmdopts.py +++ b/utils/cmdopts.py @@ -14,8 +14,6 @@ """Command line utilities shared by command line tools & unit tests.""" -from __future__ import absolute_import -from __future__ import print_function from os import path from optparse import OptionParser import sys @@ -25,7 +23,7 @@ # Print the given message to stderr, and optionally exit def error(message, exitcode = None): - print("Error: %s" % message, file=sys.stderr) + print(f"Error: {message}", file=sys.stderr) if exitcode is not None: sys.exit(exitcode)