diff --git a/.gitignore b/.gitignore index 05505fe7..74dc1d1e 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,6 @@ tests/searchcommands/apps/app_with_logging_configuration/*.log venv/ .tox test-reports/ +.venv/ +orca.json +__pycache__/ diff --git a/splunklib/searchcommands/internals.py b/splunklib/searchcommands/internals.py index 1ea2833d..05c13303 100644 --- a/splunklib/searchcommands/internals.py +++ b/splunklib/searchcommands/internals.py @@ -772,14 +772,18 @@ class RecordWriterV2(RecordWriter): def flush(self, finished=None, partial=None): RecordWriter.flush(self, finished, partial) # validates arguments and the state of this instance + + if partial: + finished = False - if partial or not finished: - # Don't flush partial chunks, since the SCP v2 protocol does not - # provide a way to send partial chunks yet. - return + # self.write_chunk(finished) + # Note: when the stdout buffer is flushed, it does not mean we are "finished" if not self.is_flushed: - self.write_chunk(finished=True) + self.write_chunk(finished) + elif finished: + self._write_chunk((('finished', True),), '') + # self.write_chunk(finished) def write_chunk(self, finished=None): inspector = self._inspector @@ -841,4 +845,4 @@ def _write_chunk(self, metadata, body): self.write(metadata) self.write(body) self._ofile.flush() - self._flushed = True + self._flushed = False diff --git a/splunklib/searchcommands/search_command.py b/splunklib/searchcommands/search_command.py index 30b1d1c2..c13144af 100644 --- a/splunklib/searchcommands/search_command.py +++ b/splunklib/searchcommands/search_command.py @@ -404,7 +404,7 @@ def flush(self): :return: :const:`None` """ - self._record_writer.flush(finished=False) + self._record_writer.flush(partial=True) def prepare(self): """ Prepare for execution. 
@@ -801,6 +801,7 @@ def _process_protocol_v2(self, argv, ifile, ofile): # noinspection PyBroadException try: debug('Executing under protocol_version=2') + self._records = self._records_protocol_v2 self._metadata.action = 'execute' self._execute(ifile, None) except SystemExit: @@ -872,12 +873,18 @@ def _execute(self, ifile, process): :rtype: NoneType """ - if self.protocol_version == 1: - self._record_writer.write_records(process(self._records(ifile))) - self.finish() - else: - assert self._protocol_version == 2 - self._execute_v2(ifile, process) + # TODO: refactor and account for DVPL-12129 + # if self.protocol_version == 1: + # self._record_writer.write_records(process(self._records(ifile))) + # else: + # assert self._protocol_version == 2 + # self._execute_v2(ifile, process) + + # DVPL-12129 + # Must process all the records via self._records prior to writing + # See self._records_protocol_v1 and self._records_protocol_v2 + self._record_writer.write_records(process(self._records(ifile))) + self.finish() @staticmethod def _as_binary_stream(ifile): @@ -940,6 +947,57 @@ def _read_chunk(istream): def _records_protocol_v1(self, ifile): return self._read_csv_records(ifile) + + def _records_protocol_v2(self, ifile): + istream = self._as_binary_stream(ifile) + + while True: + result = self._read_chunk(istream) + + if not result: + return + + metadata, body = result + action = getattr(metadata, 'action', None) + + if action != 'execute': + raise RuntimeError('Expected execute action, not {}'.format(action)) + + finished = getattr(metadata, 'finished', False) + self._record_writer.is_flushed = False + + # DVPL-12129 + # Read records from here and must not write the records (self._record_writer.write_records) right away + # so that it can be processed in batch as per maxresultrows. Otherwise we are writing smaller chunks of results, + # hence ineffective for streaming commands.
+ if len(body) > 0: + reader = csv.reader(StringIO(body), dialect=CsvDialect) + + try: + fieldnames = next(reader) + except StopIteration: + return + + mv_fieldnames = dict([(name, name[len('__mv_'):]) for name in fieldnames if name.startswith('__mv_')]) + + if len(mv_fieldnames) == 0: + for values in reader: + yield OrderedDict(izip(fieldnames, values)) + else: + for values in reader: + record = OrderedDict() + for fieldname, value in izip(fieldnames, values): + if fieldname.startswith('__mv_'): + if len(value) > 0: + record[mv_fieldnames[fieldname]] = self._decode_list(value) + elif fieldname not in record: + record[fieldname] = value + yield record + + if finished: + return + + self.flush() def _read_csv_records(self, ifile): reader = csv.reader(ifile, dialect=CsvDialect) @@ -966,6 +1024,8 @@ def _read_csv_records(self, ifile): record[fieldname] = value yield record + # Leaving this method here for generating_command.py for the time being + # TODO: refactor and account for DVPL-12129 def _execute_v2(self, ifile, process): istream = self._as_binary_stream(ifile) diff --git a/tests/searchcommands/test_internals_v2.py b/tests/searchcommands/test_internals_v2.py index ec9b3f66..d0973d8e 100755 --- a/tests/searchcommands/test_internals_v2.py +++ b/tests/searchcommands/test_internals_v2.py @@ -136,7 +136,7 @@ def test_record_writer_with_random_data(self, save_recording=False): write_record = writer.write_record - for serial_number in range(0, 31): + for serial_number in range(0, 502): values = [serial_number, time(), random_bytes(), random_dict(), random_integers(), random_unicode()] record = OrderedDict(izip(fieldnames, values)) #try: @@ -170,12 +170,12 @@ def test_record_writer_with_random_data(self, save_recording=False): for name, metric in six.iteritems(metrics): writer.write_metric(name, metric) - self.assertEqual(writer._chunk_count, 0) - self.assertEqual(writer._record_count, 31) - self.assertEqual(writer.pending_record_count, 31) + 
self.assertEqual(writer._chunk_count, 50) + self.assertEqual(writer._record_count, 2) + self.assertEqual(writer.pending_record_count, 2) self.assertGreater(writer._buffer.tell(), 0) - self.assertEqual(writer._total_record_count, 0) - self.assertEqual(writer.committed_record_count, 0) + self.assertEqual(writer._total_record_count, 500) + self.assertEqual(writer.committed_record_count, 500) fieldnames.sort() writer._fieldnames.sort() self.assertListEqual(writer._fieldnames, fieldnames) @@ -185,15 +185,15 @@ def test_record_writer_with_random_data(self, save_recording=False): dict(ifilter(lambda k_v: k_v[0].startswith('metric.'), six.iteritems(writer._inspector))), dict(imap(lambda k_v1: ('metric.' + k_v1[0], k_v1[1]), six.iteritems(metrics)))) - writer.flush(finished=True) + writer.flush(partial=True) - self.assertEqual(writer._chunk_count, 1) + self.assertEqual(writer._chunk_count, 51) self.assertEqual(writer._record_count, 0) self.assertEqual(writer.pending_record_count, 0) self.assertEqual(writer._buffer.tell(), 0) self.assertEqual(writer._buffer.getvalue(), '') - self.assertEqual(writer._total_record_count, 31) - self.assertEqual(writer.committed_record_count, 31) + self.assertEqual(writer._total_record_count, 502) + self.assertEqual(writer.committed_record_count, 502) self.assertRaises(AssertionError, writer.flush, finished=True, partial=True) self.assertRaises(AssertionError, writer.flush, finished='non-boolean') diff --git a/tests/searchcommands/test_reporting_command.py b/tests/searchcommands/test_reporting_command.py index e5add818..43cab639 100644 --- a/tests/searchcommands/test_reporting_command.py +++ b/tests/searchcommands/test_reporting_command.py @@ -10,7 +10,7 @@ class TestReportingCommand(searchcommands.ReportingCommand): def reduce(self, records): value = 0 for record in records: - value += int(record["value"]) + value += int(record.get("value")) yield {'sum': value} cmd = TestReportingCommand()