# softmax from scratch: naive, safe, online, block online, and batch block online
import torch
# naive softmax: exp(x_i) / sum_j exp(x_j)
X = torch.tensor([-0.3, 0.2, 0.5, 0.7, 0.1, 0.8])
X_exp_sum = X.exp().sum()
X_softmax_hand = torch.exp(X) / X_exp_sum
print('naive softmax result: ', X_softmax_hand)
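# sanity check: the hand-rolled result should match PyTorch's built-in softmax
print('matches torch.softmax: ', torch.allclose(X_softmax_hand, torch.softmax(X, dim=0)))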
# safe softmax: subtract the max before exponentiating to avoid overflow
X_max = X.max()
X_exp_sum_sub_max = torch.exp(X - X_max).sum()
X_safe_softmax_hand = torch.exp(X - X_max) / X_exp_sum_sub_max
print('safe softmax result: ', X_safe_softmax_hand)
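# the max subtraction cancels in the ratio, so safe softmax equals naive softmax;
# the difference is purely numerical: plain exp overflows for large logits
print('safe == naive: ', torch.allclose(X_safe_softmax_hand, X_softmax_hand))
print(torch.exp(torch.tensor(1000.)))          # inf: naive softmax would break here
print(torch.exp(torch.tensor(1000.) - 1000.))  # 1.0: the safe form stays finite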
# online softmax: keep a running max m and a running rescaled sum s,
# so new elements can be folded in one at a time
X_pre = X[:-1]
print('input X: ', X)
print('first t-1 elements: ', X_pre)
print('new element X[-1]: ', X[-1])
# state after the first t-1 elements
X_max_pre = X_pre.max()
X_sum_pre = torch.exp(X_pre - X_max_pre).sum()
# fold in the new element at time t:
#   m_t = max(m_{t-1}, x_t)
#   s_t = s_{t-1} * exp(m_{t-1} - m_t) + exp(x_t - m_t)
X_max_cur = torch.max(X_max_pre, X[-1])  # X[-1] is the new element
X_sum_cur = X_sum_pre * torch.exp(X_max_pre - X_max_cur) + torch.exp(X[-1] - X_max_cur)
# finally normalize with the up-to-date max and sum
X_online_softmax = torch.exp(X - X_max_cur) / X_sum_cur
print('online softmax result: ', X_online_softmax)
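# the same update written as a one-pass loop over all elements, which is the
# usual statement of online softmax (a minimal sketch with O(1) running state)
m, s = torch.tensor(float('-inf')), torch.tensor(0.)
for x in X:
    m_new = torch.max(m, x)
    s = s * torch.exp(m - m_new) + torch.exp(x - m_new)
    m = m_new
print('looped online softmax matches: ', torch.allclose(torch.exp(X - m) / s, X_online_softmax))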
# block online softmax: split X into blocks, compute each block's max & sum
# independently, then merge the per-block statistics with the same update
X_block = torch.split(X, split_size_or_sections=3, dim=0)
print('blocks: ', X_block)
# each block's local max & sum can be computed in parallel
X_block_0_max = X_block[0].max()
X_block_0_sum = torch.exp(X_block[0] - X_block_0_max).sum()
X_block_1_max = X_block[1].max()
X_block_1_sum = torch.exp(X_block[1] - X_block_1_max).sum()
# merge the two blocks' statistics (both sums get rescaled to the merged max)
X_block_1_max_update = torch.max(X_block_0_max, X_block_1_max)
X_block_1_sum_update = X_block_0_sum * torch.exp(X_block_0_max - X_block_1_max_update) \
    + X_block_1_sum * torch.exp(X_block_1_max - X_block_1_max_update)
X_block_online_softmax = torch.exp(X - X_block_1_max_update) / X_block_1_sum_update
print('block online softmax result: ', X_block_online_softmax)
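# the merged statistics equal the global max and sum, so the block result
# should agree with the one-shot safe softmax
print('block == safe: ', torch.allclose(X_block_online_softmax, X_safe_softmax_hand))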
# batch online softmax: the same block merge, applied row-wise to a batch
X_batch = torch.randn(4, 6)
_, d = X_batch.shape
X_batch_block_0 = X_batch[:, :d//2]
X_batch_block_1 = X_batch[:, d//2:]
# each block's per-row max & sum can be computed in parallel
X_batch_0_max, _ = X_batch_block_0.max(dim=1, keepdim=True)
X_batch_0_sum = torch.exp(X_batch_block_0 - X_batch_0_max).sum(dim=1, keepdim=True)
X_batch_1_max, _ = X_batch_block_1.max(dim=1, keepdim=True)
X_batch_1_sum = torch.exp(X_batch_block_1 - X_batch_1_max).sum(dim=1, keepdim=True)
# merge the per-block statistics row-wise with the online update
X_batch_1_max_update = torch.maximum(X_batch_0_max, X_batch_1_max)  # element-wise max across blocks
X_batch_1_sum_update = X_batch_0_sum * torch.exp(X_batch_0_max - X_batch_1_max_update) \
    + X_batch_1_sum * torch.exp(X_batch_1_max - X_batch_1_max_update)  # reuse the stored block sums
X_batch_online_softmax = torch.exp(X_batch - X_batch_1_max_update) / X_batch_1_sum_update
print('batch online softmax result: ', X_batch_online_softmax)
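# row-wise sanity check against PyTorch's built-in softmax
print('batch matches torch.softmax: ',
      torch.allclose(X_batch_online_softmax, torch.softmax(X_batch, dim=1)))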