引言 参照的是左程云的课程:https://space.bilibili.com/8888480/lists/3509640?type=series
是class019的内容,本笔记重点介绍Python在算法竞赛和笔试中的高效输入输出处理技巧,包含子矩阵最大累加和问题的完整实现,以及与Java的性能对比分析。
019【必备】算法笔试中处理输入和输出 概述 在算法竞赛和大厂笔试中,输入输出的处理效率往往是程序性能的关键瓶颈。Python虽然在执行速度上不如C++和Java,但通过合理的IO优化技巧,同样可以在大多数场景下取得良好的性能表现。
Python vs Java IO性能对比
特性
Python
Java
默认IO
input()、print() 较慢
Scanner、System.out 较慢
高效IO
sys.stdin.read()、批量输出
BufferedReader、PrintWriter
内存管理
自动垃圾回收,相对简单
需要考虑静态空间分配
大整数
原生支持任意精度
需要BigInteger类
编程复杂度
语法简洁,容易上手
类型安全,但代码较长
Python编程风格选择 填函数风格(OJ平台推荐) 适用于LeetCode、牛客网等在线判题平台,平台会自动处理输入输出。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 class Solution : def sumOfSubMatrix (self, mat, n ): """ 只需实现核心算法逻辑 平台自动调用并验证结果 """ return self .maxSumSubmatrix(mat, n, n) def maxSumSubmatrix (self, mat, rows, cols ): max_sum = float ('-inf' ) for i in range (rows): arr = [0 ] * cols for j in range (i, rows): for k in range (cols): arr[k] += mat[j][k] max_sum = max (max_sum, self .maxSumSubarray(arr)) return max_sum def maxSumSubarray (self, arr ): max_sum = float ('-inf' ) cur = 0 for num in arr: cur += num max_sum = max (max_sum, cur) cur = max (cur, 0 ) return max_sum
ACM风格(竞赛笔试推荐) 适用于ACM竞赛、大厂笔试等需要自己处理输入输出的场景。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 import sysdef main (): """ 完整的输入输出处理 需要自己解析数据格式 """ lines = sys.stdin.read().split() ptr = 0 output = [] while ptr < len (lines): n = int (lines[ptr]) ptr += 1 m = int (lines[ptr]) ptr += 1 mat = [] for i in range (n): row = [] for j in range (m): row.append(int (lines[ptr])) ptr += 1 mat.append(row) result = maxSumSubmatrix(mat, n, m) output.append(str (result)) print ('\n' .join(output)) if __name__ == '__main__' : main()
高效输入输出技巧 输入优化策略 推荐:一次性读取大数据 1 2 3 4 5 6 7 8 9 10 11 import syslines = sys.stdin.read().split() ptr = 0 while ptr < len (lines): n = int (lines[ptr]) ptr += 1
优点 :
IO次数最少,效率最高
相当于Java的BufferedReader一次性读取
适合大数据量场景
可选:按行读取 1 2 3 4 5 for line in sys.stdin: parts = line.strip().split() total = sum (int (num) for num in parts) print (total)
适用场景 :
每行数据格式不同
无法提前确定数据量
相当于Java的逐行readLine()
1 2 3 4 n = int (input ()) for i in range (n): x = int (input ())
输出优化策略 推荐:批量输出 1 2 3 4 5 6 7 output = [] for i in range (n): result = solve(data[i]) output.append(str (result)) print ('\n' .join(output))
不推荐:频繁print() 1 2 3 for i in range (n): print (solve(data[i]))
内存优化:静态空间vs动态空间 推荐:静态空间分配 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 MAXN = 201 MAXM = 201 mat = [[0 ] * MAXM for _ in range (MAXN)] arr = [0 ] * MAXM def solve (): global n, m for i in range (n): for j in range (m): mat[i][j] = read_next_int() for i in range (m): arr[i] = 0
不推荐:频繁动态分配 1 2 3 4 5 6 7 8 def solve (): mat = [] for i in range (n): row = [] for j in range (m): row.append(read_next_int()) mat.append(row)
子矩阵最大累加和问题详解 问题描述与算法思路 问题 :给定包含正数、负数、零的矩阵,求累加和最大的子矩阵。
核心思想 :将二维问题转化为一维最大子数组和问题
枚举子矩阵的上下边界(第i行到第j行)
将每列在这个范围内的元素累加,得到一维数组
对一维数组使用Kadane算法求最大子数组和
完整实现(填函数风格) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 class Code01_FillFunction : def sumOfSubMatrix (self, mat, n ): """主方法,求n×n矩阵的最大子矩阵和""" return self .maxSumSubmatrix(mat, n, n) @staticmethod def maxSumSubmatrix (mat, n, m ): """求子矩阵的最大累加和""" max_sum = float ('-inf' ) for i in range (n): arr = [0 ] * m for j in range (i, n): for k in range (m): arr[k] += mat[j][k] max_sum = max (max_sum, Code01_FillFunction.maxSumSubarray(arr, m)) return max_sum @staticmethod def maxSumSubarray (arr, m ): """Kadane算法求最大子数组和""" max_sum = float ('-inf' ) cur = 0 for i in range (m): cur += arr[i] max_sum = max (max_sum, cur) cur = 0 if cur < 0 else cur return max_sum
算法复杂度 :
时间复杂度 :O(n² × m)
空间复杂度 :O(m)
ACM风格实现(静态空间优化) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 import sysMAXN = 201 MAXM = 201 mat = [[0 ] * MAXM for _ in range (MAXN)] arr = [0 ] * MAXM n = m = 0 def main (): global n, m tokens = sys.stdin.read().split() idx = 0 output = [] while idx < len (tokens): n = int (tokens[idx]) idx += 1 m = int (tokens[idx]) idx += 1 for i in range (n): for j in range (m): mat[i][j] = int (tokens[idx]) idx += 1 output.append(str (maxSumSubmatrix())) print ('\n' .join(output)) def maxSumSubmatrix (): """使用静态空间的子矩阵最大和算法""" max_sum = float ('-inf' ) for i in range (n): for x in range (m): arr[x] = 0 for j in range (i, n): for k in range (m): arr[k] += mat[j][k] max_sum = max (max_sum, maxSumSubarray()) return max_sum def maxSumSubarray (): """一维最大子数组和""" max_sum = float ('-inf' ) cur = 0 for i in range (m): cur += arr[i] max_sum = max (max_sum, cur) cur = 0 if cur < 0 else cur return max_sum if __name__ == '__main__' : main()
执行过程示例 以矩阵为例:
1 2 3 [[-90, 48, 78], [ 64,-40, 64], [-81, -7, 66]]
枚举过程 :
i=0, j=0 (第0行):arr=[-90, 48, 78] → 最大子数组和=126
i=0, j=1 (第0-1行):arr=[-26, 8, 142] → 最大子数组和=150
i=0, j=2 (第0-2行):arr=[-107, 1, 208] → 最大子数组和=209
i=1, j=1 (第1行):arr=[64, -40, 64] → 最大子数组和=88
i=1, j=2 (第1-2行):arr=[-17, -47, 130] → 最大子数组和=130
i=2, j=2 (第2行):arr=[-81, -7, 66] → 最大子数组和=66
最终结果 :209(来自第0-2行,第2-2列的子矩阵)
Python高级IO优化 FastReader快读类 FastReader 是一个极致追求输入速度的工具,适用于数据量极大的算法竞赛场景。
主要作用
极速读取输入,尤其是大量数字(如百万级数据)。
通过 一次性读取大块数据(8KB) 到内存,减少系统I/O调用次数。
按字节处理并手动解析数字,比标准 input() 或 sys.stdin.readline() 更快。
工作原理
缓冲区:用 self.buffer 存储从输入流一次性读取的大块数据。
按字节解析:用 readByte 方法逐字节读取,跳过非数字字符,自己实现整数解析(包括负数)。
高效:只处理数字和符号,省略 split、strip 等高层方法,极致优化输入。
典型用法 适合极端大数据输入、对时间卡得很紧的OJ平台。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 import sysclass FastReader : def __init__ (self, file=sys.stdin ): self .file = file self .buffer = "" self .idx = 0 def _read (self ): """一次性读取8KB数据到缓冲区""" self .buffer = self .file.read(8192 ) self .idx = 0 def readByte (self ): """读取下一个字节""" if self .idx >= len (self .buffer): self ._read() if self .buffer == "" : return -1 byte = self .buffer[self .idx] self .idx += 1 return ord (byte) def readInt (self ): """快速读取整数""" num = 0 minus = False b = self .readByte() while b != -1 and (b < ord ('0' ) or b > ord ('9' )) and b != ord ('-' ): b = self .readByte() if b == ord ('-' ): minus = True b = self .readByte() while b != -1 and (ord ('0' ) <= b <= ord ('9' )): num = num * 10 + (b - ord ('0' )) b = self .readByte() return -num if minus else num def readLong (self ): """读取长整数(Python中与int相同)""" return self .readInt()
FastWriter快写类 FastWriter 是一个高效输出工具,适用于需要频繁输出、输出量大的场合。
主要作用
减少输出次数:把所有输出内容先缓存在内存里,最后统一输出,减少系统调用。
链式调用和兼容Java风格,用起来很方便。
工作原理
缓冲区:所有待输出内容先存入 self.buffer 列表。
批量输出:调用 flush() 时,一次性将所有内容写入输出流。
兼容性:支持 write(写字符串)、writeln/println(写一行),用法灵活。
典型用法
适合数据量大、频繁输出的算法题/竞赛场景,防止 print() 太慢导致超时。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 import sysclass FastWriter : def __init__ (self, file=sys.stdout ): self .file = file self .buffer = [] def write (self, s ): """写入字符串到缓冲区""" self .buffer.append(str (s)) return self def writeln (self, s="" ): """写入一行""" self .buffer.append(str (s) + "\n" ) return self def println (self, s="" ): """兼容Java习惯的方法名""" return self .writeln(s) def flush (self ): """刷新缓冲区,实际写入文件""" if self .buffer: self .file.write("" .join(self .buffer)) self .file.flush() self .buffer = [] def close (self ): """关闭写入器""" self .flush() if self .file != sys.stdout: self .file.close()
Kattio类(Python版) Kattio 类是一个高效的输入输出(I/O)工具类,最初流行于 Java 的竞赛编程圈。它的 Python 版本(如你上面给出的代码)主要是用来简化和加快处理标准输入输出,尤其适合数据量较大、输入格式“特殊”或需要频繁读取单个数据的场景,比如各类算法竞赛、OJ(Online Judge)平台等。
主要作用
高效读取输入:普通的 input() 在数据量大时会变慢,Kattio 通过缓冲和一次性读取一行数据,提升了读取效率。
简化输入格式处理:常见的输入格式如多行多列、混合类型(int、float、str)都能方便读取,不用每次都写 split、map 一大堆。
输出简便:带有封装的 println 方法,输出不再需要手动 print(…, file=…)。
工作原理
维护一个缓冲区(self.buffer),每次读取一整行并分割成单词。
提供 next() 方法按顺序读取下一个字符串,nextInt() 读取下一个整数,nextDouble() 读取下一个浮点数等。
适配输入和输出流(默认为标准输入输出,但也可以重定向到文件),并提供 close 方法在需要时关闭流。
参考Java版Kattio的Python实现,处理特殊输入格式:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 import sysfrom typing import Optional class Kattio : """ 高效IO类,适用于特殊格式输入 效率略低于FastReader,但兼容性更好 """ def __init__ (self, input_stream=sys.stdin, output_stream=sys.stdout ): self .input = input_stream self .output = output_stream self .buffer = [] self .idx = 0 def _fill_buffer (self ): """填充缓冲区""" line = self .input .readline() if line == '' : return self .buffer = line.strip().split() self .idx = 0 def next (self ) -> Optional [str ]: """读取下一个字符串""" while self .idx >= len (self .buffer): self ._fill_buffer() if not self .buffer: return None result = self .buffer[self .idx] self .idx += 1 return result def nextInt (self ) -> int : """读取下一个整数""" return int (self .next ()) def nextLong (self ) -> int : """读取下一个长整数""" return int (self .next ()) def nextDouble (self ) -> float : """读取下一个浮点数""" return float (self .next ()) def println (self, s ): """输出一行""" print (s, file=self .output) def close (self ): """关闭IO流""" if self .input != sys.stdin: self .input .close() if self .output != sys.stdout: self .output.close()
Python常用数据结构快速参考 基础容器操作 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 from collections import deque, defaultdictimport bisectarr = [1 , 2 , 3 ] arr.append(4 ) arr.insert(0 , 0 ) arr.pop() arr.pop(0 ) print (arr[1 ]) dq = deque() dq.append(1 ) dq.appendleft(2 ) dq.pop() dq.popleft() s1 = {1 , 2 , 3 } s2 = {2 , 3 , 4 } print (s1 | s2) print (s1 & s2) print (s1 - s2) d = defaultdict(int ) d["key" ] += 1
排序与查找 1 2 3 4 5 6 7 8 9 10 11 arr = [3 , 1 , 4 , 1 , 5 ] arr.sort() arr.sort(reverse=True ) sorted_arr = sorted (arr) custom_sorted = sorted (arr, key=lambda x: -x) arr = [1 , 2 , 4 , 7 , 9 ] idx = bisect.bisect_left(arr, 4 ) idx = bisect.bisect_right(arr, 4 )
大整数与高精度 1 2 3 4 5 6 7 8 big_num = 10 **100 result = big_num * big_num import decimaldecimal.getcontext().prec = 50 a = decimal.Decimal('1' ) / decimal.Decimal('3' )
实战技巧与注意事项 性能优化技巧 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 for i in range (n): for j in range (len (arr)): pass arr_len = len (arr) for i in range (n): for j in range (arr_len): pass def process (): for i in range (n): result += global_data[i] def process (): local_data = global_data for i in range (n): result += local_data[i] s = "" for i in range (n): s += str (i) parts = [] for i in range (n): parts.append(str (i)) s = "" .join(parts)
常见陷阱与解决方案 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 import syssys.setrecursionlimit(10000 ) import mathdef is_equal (a, b, eps=1e-9 ): return abs (a - b) < eps matrix = [[0 ] * m] * n matrix = [[0 ] * m for _ in range (n)] from collections import defaultdictcount = defaultdict(int ) count[key] += 1
调试技巧 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 DEBUG = False def debug_print (*args ): if DEBUG: print ("DEBUG:" , *args) import sysif DEBUG: sys.stdin = open ('input.txt' , 'r' ) sys.stdout = open ('output.txt' , 'w' ) import timestart_time = time.time() print (f"执行时间: {time.time() - start_time:.3 f} 秒" )