0%

数据结构与算法自学笔记(5)- 算法笔试中处理输入和输出

引言

参照的是左程云的课程:https://space.bilibili.com/8888480/lists/3509640?type=series

是class019的内容,本笔记重点介绍Python在算法竞赛和笔试中的高效输入输出处理技巧,包含子矩阵最大累加和问题的完整实现,以及与Java的性能对比分析。

019【必备】算法笔试中处理输入和输出

概述

在算法竞赛和大厂笔试中,输入输出的处理效率往往是程序性能的关键瓶颈。Python虽然在执行速度上不如C++和Java,但通过合理的IO优化技巧,同样可以在大多数场景下取得良好的性能表现。

Python vs Java IO性能对比

特性 Python Java
默认IO input()print() 较慢 ScannerSystem.out 较慢
高效IO sys.stdin.read()、批量输出 BufferedReaderPrintWriter
内存管理 自动垃圾回收,相对简单 需要考虑静态空间分配
大整数 原生支持任意精度 需要BigInteger
编程复杂度 语法简洁,容易上手 类型安全,但代码较长

Python编程风格选择

填函数风格(OJ平台推荐)

适用于LeetCode、牛客网等在线判题平台,平台会自动处理输入输出。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class Solution:
def sumOfSubMatrix(self, mat, n):
"""
只需实现核心算法逻辑
平台自动调用并验证结果
"""
return self.maxSumSubmatrix(mat, n, n)

def maxSumSubmatrix(self, mat, rows, cols):
max_sum = float('-inf')
for i in range(rows):
arr = [0] * cols # 辅助数组
for j in range(i, rows):
for k in range(cols):
arr[k] += mat[j][k]
max_sum = max(max_sum, self.maxSumSubarray(arr))
return max_sum

def maxSumSubarray(self, arr):
max_sum = float('-inf')
cur = 0
for num in arr:
cur += num
max_sum = max(max_sum, cur)
cur = max(cur, 0) # 负数时重置为0
return max_sum

ACM风格(竞赛笔试推荐)

适用于ACM竞赛、大厂笔试等需要自己处理输入输出的场景。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import sys

def main():
"""
完整的输入输出处理
需要自己解析数据格式
"""
lines = sys.stdin.read().split()
ptr = 0
output = []

while ptr < len(lines):
n = int(lines[ptr])
ptr += 1
m = int(lines[ptr])
ptr += 1

# 构建矩阵
mat = []
for i in range(n):
row = []
for j in range(m):
row.append(int(lines[ptr]))
ptr += 1
mat.append(row)

# 计算结果
result = maxSumSubmatrix(mat, n, m)
output.append(str(result))

# 批量输出
print('\n'.join(output))

if __name__ == '__main__':
main()

高效输入输出技巧

输入优化策略

推荐:一次性读取大数据

1
2
3
4
5
6
7
8
9
10
11
import sys

# 最高效:适用于已知数据量的场景
lines = sys.stdin.read().split()
ptr = 0

# 逐个解析数据
while ptr < len(lines):
n = int(lines[ptr])
ptr += 1
# 处理数据...

优点

  • IO次数最少,效率最高
  • 相当于Java的BufferedReader一次性读取
  • 适合大数据量场景

可选:按行读取

1
2
3
4
5
# 适用:需要按行处理不同格式数据
for line in sys.stdin:
parts = line.strip().split()
total = sum(int(num) for num in parts)
print(total)

适用场景

  • 每行数据格式不同
  • 无法提前确定数据量
  • 相当于Java的逐行readLine()

不推荐:频繁调用input()

1
2
3
4
# 效率低:类似Java的Scanner
n = int(input()) # 大数据时很慢
for i in range(n):
x = int(input()) # 每次都要系统调用

输出优化策略

推荐:批量输出

1
2
3
4
5
6
7
# 收集所有结果,最后一次性输出
output = []
for i in range(n):
result = solve(data[i])
output.append(str(result))

print('\n'.join(output)) # 一次性输出

不推荐:频繁print()

1
2
3
# 每次print都会刷新缓冲区,效率低
for i in range(n):
print(solve(data[i])) # 大数据时很慢

内存优化:静态空间vs动态空间

推荐:静态空间分配

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 全局预分配,避免频繁内存分配
MAXN = 201
MAXM = 201
mat = [[0] * MAXM for _ in range(MAXN)] # 静态矩阵
arr = [0] * MAXM # 静态辅助数组

def solve():
global n, m
# 复用已分配的空间
for i in range(n):
for j in range(m):
mat[i][j] = read_next_int()

# 使用前先清空
for i in range(m):
arr[i] = 0

不推荐:频繁动态分配

1
2
3
4
5
6
7
8
def solve():
# 每次都重新分配内存
mat = [] # 动态创建
for i in range(n):
row = [] # 每行都新建
for j in range(m):
row.append(read_next_int())
mat.append(row)

子矩阵最大累加和问题详解

问题描述与算法思路

问题:给定包含正数、负数、零的矩阵,求累加和最大的子矩阵。

核心思想:将二维问题转化为一维最大子数组和问题

  1. 枚举子矩阵的上下边界(第i行到第j行)
  2. 将每列在这个范围内的元素累加,得到一维数组
  3. 对一维数组使用Kadane算法求最大子数组和

完整实现(填函数风格)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class Code01_FillFunction:
def sumOfSubMatrix(self, mat, n):
"""主方法,求n×n矩阵的最大子矩阵和"""
return self.maxSumSubmatrix(mat, n, n)

@staticmethod
def maxSumSubmatrix(mat, n, m):
"""求子矩阵的最大累加和"""
max_sum = float('-inf')

# 枚举上边界
for i in range(n):
arr = [0] * m # 辅助数组,每次重置

# 枚举下边界(从i到n-1)
for j in range(i, n):
# 将第j行累加到辅助数组
for k in range(m):
arr[k] += mat[j][k]

# 求当前辅助数组的最大子数组和
max_sum = max(max_sum, Code01_FillFunction.maxSumSubarray(arr, m))

return max_sum

@staticmethod
def maxSumSubarray(arr, m):
"""Kadane算法求最大子数组和"""
max_sum = float('-inf')
cur = 0

for i in range(m):
cur += arr[i]
max_sum = max(max_sum, cur)
cur = 0 if cur < 0 else cur # 负数时重置

return max_sum

算法复杂度

  • 时间复杂度:O(n² × m)
  • 空间复杂度:O(m)

ACM风格实现(静态空间优化)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import sys

# 静态空间分配,避免频繁内存分配
MAXN = 201
MAXM = 201
mat = [[0] * MAXM for _ in range(MAXN)]
arr = [0] * MAXM
n = m = 0

def main():
global n, m
tokens = sys.stdin.read().split()
idx = 0
output = []

while idx < len(tokens):
n = int(tokens[idx])
idx += 1
m = int(tokens[idx])
idx += 1

# 读取矩阵数据到静态空间
for i in range(n):
for j in range(m):
mat[i][j] = int(tokens[idx])
idx += 1

# 计算结果并收集输出
output.append(str(maxSumSubmatrix()))

# 批量输出所有结果
print('\n'.join(output))

def maxSumSubmatrix():
"""使用静态空间的子矩阵最大和算法"""
max_sum = float('-inf')

for i in range(n):
# 清空辅助数组(复用静态空间)
for x in range(m):
arr[x] = 0

for j in range(i, n):
# 累加第j行到辅助数组
for k in range(m):
arr[k] += mat[j][k]
max_sum = max(max_sum, maxSumSubarray())

return max_sum

def maxSumSubarray():
"""一维最大子数组和"""
max_sum = float('-inf')
cur = 0

for i in range(m):
cur += arr[i]
max_sum = max(max_sum, cur)
cur = 0 if cur < 0 else cur

return max_sum

if __name__ == '__main__':
main()

执行过程示例

以矩阵为例:

1
2
3
[[-90, 48, 78],
[ 64,-40, 64],
[-81, -7, 66]]

枚举过程

  1. i=0, j=0(第0行):arr=[-90, 48, 78] → 最大子数组和=126
  2. i=0, j=1(第0-1行):arr=[-26, 8, 142] → 最大子数组和=150
  3. i=0, j=2(第0-2行):arr=[-107, 1, 208] → 最大子数组和=209
  4. i=1, j=1(第1行):arr=[64, -40, 64] → 最大子数组和=88
  5. i=1, j=2(第1-2行):arr=[-17, -47, 130] → 最大子数组和=130
  6. i=2, j=2(第2行):arr=[-81, -7, 66] → 最大子数组和=66

最终结果:209(来自第0-2行,第2-2列的子矩阵)

Python高级IO优化

FastReader快读类

FastReader 是一个极致追求输入速度的工具,适用于数据量极大的算法竞赛场景。

主要作用

  • 极速读取输入,尤其是大量数字(如百万级数据)。
  • 通过 一次性读取大块数据(8KB) 到内存,减少系统I/O调用次数。
  • 按字节处理并手动解析数字,比标准 input() 或 sys.stdin.readline() 更快。

工作原理

  • 缓冲区:用 self.buffer 存储从输入流一次性读取的大块数据。
  • 按字节解析:用 readByte 方法逐字节读取,跳过非数字字符,自己实现整数解析(包括负数)。
  • 高效:只处理数字和符号,省略 split、strip 等高层方法,极致优化输入。

典型用法
适合极端大数据输入、对时间卡得很紧的OJ平台。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import sys

class FastReader:
def __init__(self, file=sys.stdin):
self.file = file
self.buffer = ""
self.idx = 0

def _read(self):
"""一次性读取8KB数据到缓冲区"""
self.buffer = self.file.read(8192)
self.idx = 0

def readByte(self):
"""读取下一个字节"""
if self.idx >= len(self.buffer):
self._read()
if self.buffer == "":
return -1
byte = self.buffer[self.idx]
self.idx += 1
return ord(byte)

def readInt(self):
"""快速读取整数"""
num = 0
minus = False
b = self.readByte()

# 跳过非数字字符
while b != -1 and (b < ord('0') or b > ord('9')) and b != ord('-'):
b = self.readByte()

if b == ord('-'):
minus = True
b = self.readByte()

# 读取数字
while b != -1 and (ord('0') <= b <= ord('9')):
num = num * 10 + (b - ord('0'))
b = self.readByte()

return -num if minus else num

def readLong(self):
"""读取长整数(Python中与int相同)"""
return self.readInt()

FastWriter快写类

FastWriter 是一个高效输出工具,适用于需要频繁输出、输出量大的场合。

主要作用

  • 减少输出次数:把所有输出内容先缓存在内存里,最后统一输出,减少系统调用。
  • 链式调用和兼容Java风格,用起来很方便。

工作原理

  • 缓冲区:所有待输出内容先存入 self.buffer 列表。
  • 批量输出:调用 flush() 时,一次性将所有内容写入输出流。
  • 兼容性:支持 write(写字符串)、writeln/println(写一行),用法灵活。

典型用法

适合数据量大、频繁输出的算法题/竞赛场景,防止 print() 太慢导致超时。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import sys

class FastWriter:
def __init__(self, file=sys.stdout):
self.file = file
self.buffer = []

def write(self, s):
"""写入字符串到缓冲区"""
self.buffer.append(str(s))
return self # 支持链式调用

def writeln(self, s=""):
"""写入一行"""
self.buffer.append(str(s) + "\n")
return self

def println(self, s=""):
"""兼容Java习惯的方法名"""
return self.writeln(s)

def flush(self):
"""刷新缓冲区,实际写入文件"""
if self.buffer:
self.file.write("".join(self.buffer))
self.file.flush()
self.buffer = []

def close(self):
"""关闭写入器"""
self.flush()
if self.file != sys.stdout:
self.file.close()

Kattio类(Python版)

Kattio 类是一个高效的输入输出(I/O)工具类,最初流行于 Java 的竞赛编程圈。它的 Python 版本(如你上面给出的代码)主要是用来简化和加快处理标准输入输出,尤其适合数据量较大、输入格式“特殊”或需要频繁读取单个数据的场景,比如各类算法竞赛、OJ(Online Judge)平台等。

主要作用

  • 高效读取输入:普通的 input() 在数据量大时会变慢,Kattio 通过缓冲和一次性读取一行数据,提升了读取效率。
  • 简化输入格式处理:常见的输入格式如多行多列、混合类型(int、float、str)都能方便读取,不用每次都写 split、map 一大堆。
  • 输出简便:带有封装的 println 方法,输出不再需要手动 print(…, file=…)。

工作原理

  • 维护一个缓冲区(self.buffer),每次读取一整行并分割成单词。
  • 提供 next() 方法按顺序读取下一个字符串,nextInt() 读取下一个整数,nextDouble() 读取下一个浮点数等。
  • 适配输入和输出流(默认为标准输入输出,但也可以重定向到文件),并提供 close 方法在需要时关闭流。

参考Java版Kattio的Python实现,处理特殊输入格式:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import sys
from typing import Optional

class Kattio:
"""
高效IO类,适用于特殊格式输入
效率略低于FastReader,但兼容性更好
"""
def __init__(self, input_stream=sys.stdin, output_stream=sys.stdout):
self.input = input_stream
self.output = output_stream
self.buffer = []
self.idx = 0

def _fill_buffer(self):
"""填充缓冲区"""
line = self.input.readline()
if line == '':
return
self.buffer = line.strip().split()
self.idx = 0

def next(self) -> Optional[str]:
"""读取下一个字符串"""
while self.idx >= len(self.buffer):
self._fill_buffer()
if not self.buffer:
return None
result = self.buffer[self.idx]
self.idx += 1
return result

def nextInt(self) -> int:
"""读取下一个整数"""
return int(self.next())

def nextLong(self) -> int:
"""读取下一个长整数"""
return int(self.next())

def nextDouble(self) -> float:
"""读取下一个浮点数"""
return float(self.next())

def println(self, s):
"""输出一行"""
print(s, file=self.output)

def close(self):
"""关闭IO流"""
if self.input != sys.stdin:
self.input.close()
if self.output != sys.stdout:
self.output.close()

Python常用数据结构快速参考

基础容器操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from collections import deque, defaultdict
import bisect

# 列表(动态数组)
arr = [1, 2, 3]
arr.append(4) # O(1) 尾部添加
arr.insert(0, 0) # O(n) 头部插入
arr.pop() # O(1) 尾部删除
arr.pop(0) # O(n) 头部删除
print(arr[1]) # O(1) 随机访问

# 双端队列(可当栈或队列)
dq = deque()
dq.append(1) # 队尾入队
dq.appendleft(2) # 队首入队
dq.pop() # 队尾出队
dq.popleft() # 队首出队

# 集合操作
s1 = {1, 2, 3}
s2 = {2, 3, 4}
print(s1 | s2) # 并集
print(s1 & s2) # 交集
print(s1 - s2) # 差集

# 字典操作
d = defaultdict(int) # 默认值为0
d["key"] += 1 # 自动初始化并自增

排序与查找

1
2
3
4
5
6
7
8
9
10
11
# 排序
arr = [3, 1, 4, 1, 5]
arr.sort() # 原地升序排序
arr.sort(reverse=True) # 原地降序排序
sorted_arr = sorted(arr) # 返回新的排序数组
custom_sorted = sorted(arr, key=lambda x: -x) # 自定义排序

# 二分查找
arr = [1, 2, 4, 7, 9] # 必须有序
idx = bisect.bisect_left(arr, 4) # 查找插入位置
idx = bisect.bisect_right(arr, 4) # 查找插入位置(右侧)

大整数与高精度

1
2
3
4
5
6
7
8
# Python原生支持任意精度整数
big_num = 10**100 # 10的100次方
result = big_num * big_num # 自动处理大整数运算

# 高精度除法
import decimal
decimal.getcontext().prec = 50 # 设置精度
a = decimal.Decimal('1') / decimal.Decimal('3')

实战技巧与注意事项

性能优化技巧

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# 1. 避免在循环中重复计算
# 错误写法
for i in range(n):
for j in range(len(arr)): # 每次都计算len(arr)
pass

# 正确写法
arr_len = len(arr)
for i in range(n):
for j in range(arr_len):
pass

# 2. 使用局部变量访问全局数据
# 错误写法
def process():
for i in range(n):
result += global_data[i] # 每次都查找全局变量

# 正确写法
def process():
local_data = global_data # 本地化全局变量
for i in range(n):
result += local_data[i]

# 3. 字符串拼接优化
# 错误写法
s = ""
for i in range(n):
s += str(i) # 每次都创建新字符串

# 正确写法
parts = []
for i in range(n):
parts.append(str(i))
s = "".join(parts)

常见陷阱与解决方案

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 1. 递归深度限制
import sys
sys.setrecursionlimit(10000) # 设置递归深度限制

# 2. 浮点数精度问题
import math
def is_equal(a, b, eps=1e-9):
return abs(a - b) < eps

# 3. 列表初始化陷阱
# 错误:所有行共享同一个列表
matrix = [[0] * m] * n

# 正确:每行都是独立的列表
matrix = [[0] * m for _ in range(n)]

# 4. 字典默认值
from collections import defaultdict
# 使用defaultdict避免KeyError
count = defaultdict(int)
count[key] += 1 # 自动初始化为0

调试技巧

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 条件编译式调试
DEBUG = False

def debug_print(*args):
if DEBUG:
print("DEBUG:", *args)

# 输入输出重定向(本地测试)
import sys
if DEBUG:
sys.stdin = open('input.txt', 'r')
sys.stdout = open('output.txt', 'w')

# 计时器
import time
start_time = time.time()
# ... 算法代码 ...
print(f"执行时间: {time.time() - start_time:.3f}秒")