---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[6], line 1
----> 1 y = x.reshape(4, 4)
2 y
RuntimeError: shape '[4, 4]' is invalid for input of size 12
When reshaping, however, the product of the row and column counts must equal the total number of elements.
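A minimal sketch of this rule, assuming `x` is the 12-element tensor implied by the error message above. It checks the element count with `numel()` and uses `-1` so that PyTorch infers one dimension itself.

```python
import torch

x = torch.arange(12)  # 12 elements, matching "input of size 12" in the error

# A target shape is valid only if its dimensions multiply to x.numel().
valid = x.reshape(3, 4)
assert valid.numel() == x.numel() == 12

# Passing -1 lets PyTorch infer that dimension from the element count.
inferred = x.reshape(-1, 4)
print(inferred.shape)
```

Using `-1` avoids computing the missing dimension by hand, but the remaining dimensions must still divide the element count evenly.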
z = torch.zeros((2, 3, 4))
o = torch.ones((3, 2, 4))
print(z, '\n', o)

x = torch.arange(12).reshape((3, 4))
y = torch.tensor([[2, 1, 4], [1, 2, 3], [4, 3, 2]])
torch.cat((x, y), dim=0), torch.cat((x, y), dim=1)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[13], line 3
1 x = torch.arange(12).reshape((3,4))
2 y = torch.tensor([[2,1,4],[1,2,3],[4,3,2]])
----> 3 torch.cat((x,y),dim=0), torch.cat((x,y),dim=1)
RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 4 but got size 3 for tensor number 1 in the list.
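An error is raised if the sizes do not match: every dimension other than the one being concatenated along must be the same in both tensors.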
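The following sketch reuses the two tensors from the failing cell to show which direction works. The variable names `wide` and `dim0_ok` are introduced here only for illustration.

```python
import torch

x = torch.arange(12).reshape((3, 4))                  # shape (3, 4)
y = torch.tensor([[2, 1, 4], [1, 2, 3], [4, 3, 2]])   # shape (3, 3)

# dim=1 works: both tensors have 3 rows, so the columns are placed side by side.
wide = torch.cat((x, y), dim=1)

# dim=0 needs equal column counts (4 vs 3), so it raises RuntimeError.
try:
    torch.cat((x, y), dim=0)
    dim0_ok = True
except RuntimeError:
    dim0_ok = False

print(wide.shape, dim0_ok)
```

So in the cell above it is the `dim=0` call that fails; the `dim=1` call on its own would succeed.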
x = torch.arange(12).reshape((3, 4))
y = torch.tensor([[2, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
x == y

y = torch.arange(4)
before = id(y)
y = y + x
id(y) == before
False
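Ordinary operations on a variable generally allocate new memory for the result. When the tensors are very large this can consume a lot of memory, so in-place operations can be used instead.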
y = torch.arange(4)
before = id(y)
y += x
id(y) == before
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[22], line 3
1 y = torch.arange(4)
2 before = id(y)
----> 3 y += x
4 id(y) == before
RuntimeError: output with shape [4] doesn't match the broadcast shape [3, 4]
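The error above occurs because an in-place operation cannot enlarge its target: `y` has shape `[4]`, while the broadcast result of `y + x` has shape `[3, 4]`. A minimal sketch of two ways around this, assuming the same `x` as in the failing cell (the names `big`, `y2`, and `ptr` are illustrative):

```python
import torch

x = torch.arange(12).reshape((3, 4))   # same x as in the failing cell
y = torch.arange(4)

# Option 1: make the in-place target as large as the broadcast result.
big = torch.zeros_like(x)              # shape (3, 4)
before = id(big)
big += y                               # y is broadcast into big; big keeps its shape
same_object = id(big) == before

# Option 2: use operands whose shapes already match.
y2 = torch.arange(4)
ptr = y2.data_ptr()                    # address of the underlying storage
y2 += torch.arange(4, 8)
same_storage = y2.data_ptr() == ptr

print(same_object, same_storage)
```

Comparing `data_ptr()` checks the underlying storage directly, which is a slightly stronger test than comparing `id()` of the Python object.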
Type conversion
y = torch.arange(4)
x = torch.arange(4, 8, 1)
z = torch.zeros_like(y)
print('id(z):', id(z))
z[:] = x + y
print('id(z):', id(z))
id(z): 140157950377472
id(z): 140157950377472
Converting between PyTorch and NumPy
A = x.numpy()
B = torch.tensor(A)
type(A), type(B)
(numpy.ndarray, torch.Tensor)
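One detail worth checking in this round trip is which conversions copy the data. As a sketch, assuming a CPU tensor: `x.numpy()` and `torch.from_numpy` share memory with their source, while `torch.tensor` makes a copy. The name `C` is added here for illustration.

```python
import numpy as np
import torch

x = torch.arange(4, 8)
A = x.numpy()            # shares memory with x (CPU tensor)
B = torch.tensor(A)      # copies the data
C = torch.from_numpy(A)  # shares memory with A

A[0] = 100               # modify the NumPy array in place
print(x[0].item(), B[0].item(), C[0].item())
```

After the write to `A`, both `x` and `C` see the new value, while the copy `B` keeps the original one.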
Converting a size-1 tensor to a Python scalar
a = torch.tensor([3.5])
a, a.item(), float(a), int(a)
(tensor([3.5000]), 3.5, 3.5, 3)
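These conversions only apply to tensors with exactly one element. The sketch below shows what happens otherwise. The exception type has differed between PyTorch versions, so both `RuntimeError` and `ValueError` are caught here as an assumption, and `tolist()` is shown as the alternative for larger tensors.

```python
import torch

a = torch.tensor([3.5])
b = torch.tensor([3.5, 1.5])

scalar = a.item()        # fine: exactly one element

try:
    b.item()             # two elements: cannot be turned into a single scalar
    multi_ok = True
except (RuntimeError, ValueError):
    multi_ok = False

values = b.tolist()      # converts every element to a Python float
print(scalar, multi_ok, values)
```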
2.2 Data preprocessing
2.2.1 Generating the dataset file
import os

os.makedirs(os.path.join('..', 'data'), exist_ok=True)
data_file = os.path.join('..', 'data', 'house_tiny.csv')
with open(data_file, 'w') as f:
    f.write('NumRooms,Alley,Price\n')  # column names
    f.write('NA,Pave,127500\n')        # one sample per line
    f.write('2,NA,106000\n')
    f.write('4,NA,178100\n')
    f.write('NA,NA,140000\n')
This creates a CSV file of house-price data with three columns: number of rooms, alley type, and price.
import pandas as pd

data = pd.read_csv(data_file)
print(data)
   NumRooms Alley   Price
0       NaN  Pave  127500
1       2.0   NaN  106000
2       4.0   NaN  178100
3       NaN   NaN  140000
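The `NaN` entries appear because `pd.read_csv` treats the string `NA` as a missing value by default. A small sketch that counts the missing entries per column follows. To stay self-contained it reads the same rows from an in-memory string rather than from `house_tiny.csv`, and the names `csv_text` and `missing` are illustrative.

```python
import io

import pandas as pd

# Same rows as the file written above, kept in memory for this sketch.
csv_text = (
    "NumRooms,Alley,Price\n"
    "NA,Pave,127500\n"
    "2,NA,106000\n"
    "4,NA,178100\n"
    "NA,NA,140000\n"
)

data = pd.read_csv(io.StringIO(csv_text))

# isna() marks missing cells; sum() counts them per column.
missing = data.isna().sum()
print(missing)
```

Counting missing values first helps decide how each column should be handled before the data is converted to tensors.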