修正が必要なコード

以下に,メインとなるファイルと,エラーに関係しそうなモジュールを貼り付けます.
もしほかにコードで引用しているモジュールなど追加情報が必要でしたらお申し付けください.

If you need additional information such as modules this code import, please do not hesitate to tell me.


main_file

#!/usr/bin/env python
# coding: utf-8

# In[1]:


from make_embedding import preprocessing, divide_comments_by_video, get_comment_embedding, \
    get_title_desc_embedding, initialize_vgg_19, get_image_embedding_vgg_19, \
        cal_cos_sim_video_embedding, cal_attn_weight_embedding, ThumbFrameDataset
from get_common_thumb_frame import get_common_thumb_frame
from make_comment_bilstm import BiLSTM, create_batches
from video_dataset import VideoDataset, collate_fn


# In[2]:


import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix, roc_curve, roc_auc_score
import matplotlib.pyplot as plt
import pickle
import warnings

from transformers import BertModel, BertTokenizer

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import MultiheadAttention
from torch.optim import AdamW
from torch.utils.data import DataLoader, Subset
from torch.multiprocessing import set_start_method

from torchvision import transforms
from torchvision import models
from torchvision.models.vgg import VGG19_Weights

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

import pytorch_lightning as pl


# In[3]:


try:
    set_start_method('spawn')
except RuntimeError:
    pass


# In[4]:


random_state=42
d = 768
max_length = 200
max_epochs = 200
patience = 10
num_workers = 16
n_splits = 5
video_batch_size = 1
lstm_hidden_size = 768//2
j = 0.1

batch_size = 4
comment_batch_size = 64
frame_batch_size = 256
dropout_rate = 0.5
lstm_batch_size = 128
lstm_dropout = 0.1 # 0.1-0.5 大規模データでは過学習のリスクが少ないので小さい値から始められる
lr = 1e-4
input_size = 768
hidden_dim = 128
num_layers = 2 # 1-3がいい
bidirectional = True
num_heads = 2
weight_decay = 1e-5
fig_save_name = 'reproduct_choi'
name='reproduct_choi'

torch.set_float32_matmul_precision('medium')
warnings.filterwarnings("ignore", category=FutureWarning)


# In[5]:


model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
bert_model = BertModel.from_pretrained(model_name)


# In[6]:


df = preprocessing(random_state)


# In[7]:


common_ids_list = get_common_thumb_frame(df)
df = df[df['video_id'].isin(common_ids_list)]
# thumbnail: 288,video_frame: 213


# In[8]:


print(len(common_ids_list))


# In[9]:


num_real = len(df[df['label']==1])
num_fake = len(df[df['label']==0])

num_videos = len(df['video_id'].drop_duplicates())
print(f'動画数:{num_videos},データ数:{len(df["label"])}, リアル数:{num_real},フェイク数:{num_fake}')


# In[10]:


##### データ削減 #####
df = df.sample(frac=0.05, random_state=42)
#####################


# In[11]:


df_list = divide_comments_by_video(df)
df_drop = df.drop_duplicates(subset='video_id')


# In[12]:


# df = df_list[0]
# print(len(df))
# comment_embeddings = get_comment_embedding(df, tokenizer, bert_model, max_length, batch_size=128)
# print(comment_embeddings.shape)

# 1229
# torch.Size([1229, 200, 768])


# In[13]:


# title_desc_embeddings = get_title_desc_embedding(df_drop, tokenizer, bert_model, max_length=max_length)

# #### 書き込み #####
# with open(f'pickle/title_desc_embeddings_maxlength={max_length}.pkl', 'wb') as f:
#     pickle.dump(title_desc_embeddings, f)


# In[14]:


# ##### 読み込み #####
# with open(f'pickle/title_desc_embeddings_maxlength={max_length}.pkl', 'rb') as f:
#     title_desc_embeddings = pickle.load(f)


# In[15]:


# print(title_desc_embeddings.shape)
# torch.Size([117, 200, 768])


# In[16]:


def plot_roc_curve(fpr, tpr, random_state, batch_size, max_length, fig_save_name):
    plt.figure()
    plt.plot(fpr, tpr, label='ROC curve')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve')
    plt.legend(loc="lower right")
    filename = f'fig/{fig_save_name}_{random_state}_batch_size={batch_size}_max_length={max_length}.png'
    plt.savefig(filename)
    plt.close()

def plot_fpr_threshold(thresholds, fpr, random_state, batch_size, max_length, fig_save_name):
    plt.figure()
    plt.plot(thresholds, fpr)
    plt.xlabel('Thresholds')
    plt.ylabel('False Positive Rate')
    plt.title('Threshold vs. FPR')
    plt.gca().invert_xaxis()
    plt.grid(True)
    filename = f'fig/{fig_save_name}_{random_state}_batch_size={batch_size}_max_length={max_length}.png'
    plt.savefig(filename)
    plt.close()

def plot_tpr_threshold(thresholds, tpr, random_state, batch_size, max_length, fig_save_name):
    plt.figure()
    plt.plot(thresholds, tpr)
    plt.xlabel('Thresholds')
    plt.ylabel('True Positive Rate')
    plt.title('Threshold vs. FPR')
    plt.gca().invert_xaxis()
    plt.grid(True)
    filename = f'fig/{fig_save_name}_{random_state}_batch_size={batch_size}_max_length={max_length}.png'
    plt.savefig(filename)
    plt.close()


# In[17]:


from video_dataset import CommentProcessor


# In[18]:


# サブタスク2 タイトル処理
class TitleDescProcessor(nn.Module):
    def __init__(self, d):
        super(TitleDescProcessor, self).__init__()
        self.fc = nn.Linear(d, 2*d)

    def forward(self, x):
        # x shape: (batch_size, num_titles, max_length, embedding_dim)
        # 平均値プーリング
        x = torch.mean(x, dim=2)
        # x shape: (batch_size, num_titles, embedding_dim)
        x = self.fc(x)
        # x shape: (batch_size, num_titles, embedding_dim*2)
        # 平均値プーリング
        x = torch.mean(x, dim=1)

        return x


# In[19]:


# ##### 使用例 #####
# df = df_list[0]
# title_desc_embeddings = get_title_desc_embedding(df, tokenizer, bert_model, max_length, batch_size=32)
# title_desc_embeddings = title_desc_embeddings.unsqueeze(0).to('cuda')
# processor = TitleDescProcessor(d=768)
# processor = processor.to('cuda')
# embedding = processor(title_desc_embeddings)
# print(embedding.shape)
# # torch.Size([1, 1536])


# In[20]:


# サブクラス3: 上位j個の類似フレームを取得
class GetJFrames(nn.Module):
    def __init__(self, d=768, j=0.1, batch_size=1, frame_batch_size=32):
        super(GetJFrames, self).__init__()
        self.j = j
        self.batch_size = batch_size
        self.frame_batch_size = frame_batch_size
        self.vgg_19 = initialize_vgg_19(d=d)  # vgg_19の初期化

    def forward(self, common_ids_list):
        self.vgg_19 = self.vgg_19.to('cuda')
        dataset = ThumbFrameDataset(common_ids_list)
        data_loader = DataLoader(dataset, self.batch_size)
        top_j_sim_video_embeddings_list = cal_cos_sim_video_embedding(data_loader, self.vgg_19, self.j, self.frame_batch_size)
        
        return top_j_sim_video_embeddings_list


# In[21]:


# with open('pickle/top_j_sim_video_embeddings_list.pkl', 'wb') as f:
#     pickle.dump(top_j_sim_video_embeddings_list, f)


# In[22]:


with open('pickle/top_j_sim_video_embeddings_list.pkl', 'rb') as f:
    top_j_sim_video_embeddings_list = pickle.load(f)


# In[23]:


# サブクラス4: ビデオの処理
class VideoProcessor(nn.Module):
    def __init__(self, video_batch_size=64, d=768, num_heads=8):
        super(VideoProcessor, self).__init__()
        self.attention = MultiheadAttention(embed_dim=d*2, num_heads=num_heads, batch_first=True)
        self.video_batch_size = video_batch_size
        self.video_fc = nn.Linear(2*d, 2*d)

    def forward(self, top_j_sim_video_embeddings_list):
        self.attention = self.attention.to('cuda')
        weighted_avg_video_embeddings = cal_attn_weight_embedding(self.attention, top_j_sim_video_embeddings_list)
        video_output = self.video_fc(weighted_avg_video_embeddings)
        video_output_avg = torch.mean(video_output, dim=1)

        return video_output_avg


# In[24]:


##### 使用例 #####
with open('pickle/top_j_sim_video_embeddings_list.pkl', 'rb') as f:
    top_j_sim_video_embeddings_list = pickle.load(f)
processor = VideoProcessor(d=768, num_heads=8)
processor = processor.to('cuda')
weighted_avg_video_embeddings = processor(top_j_sim_video_embeddings_list)
print(weighted_avg_video_embeddings.shape)


# In[25]:


from torch import float16


class FakeNewsDetector(pl.LightningModule):
    def __init__(self, tokenizer, bert_model, random_state, max_length, batch_size, num_workers, lr, n_split, dropout_rate, lstm_dropout, input_size, lstm_hidden_size, hidden_dim, num_layers, bidirectional, num_heads, max_epochs, patience, fig_save_name, name, weight_decay, d=768):
        super().__init__()
        self.save_hyperparameters(ignore=['tokenizer', 'bert_model'])

        self.validation_step_outputs = []
        self.d = d

        self.video_fc = nn.Linear(2*d, 2*d)

        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(max_length * 2*d, 1024)
        self.bn1 = nn.BatchNorm1d(1024)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(1024, 512)
        self.bn2 = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.fc4 = nn.Linear(128, 1)

        self.comment_weight = nn.Parameter(torch.randn(1))
        self.title_desc_weight = nn.Parameter(torch.randn(1))
        self.video_weight = nn.Parameter(torch.randn(1))
        print('=====ハイパーパラメータなど完了=====')

        self.bilstm_model = BiLSTM(input_size=int(input_size), hidden_size=int(lstm_hidden_size),
                                    num_layers=int(num_layers), dropout=float(lstm_dropout))
        self.bilstm_model = self.bilstm_model.to('cuda')
        print('=====Bi-LSTM完了=====')

        self.comment_processor = CommentProcessor(d, num_layers, lstm_dropout, lstm_batch_size, lstm_hidden_size=768//2)
        self.title_desc_processor = TitleDescProcessor(d)
        self.get_j_frames = GetJFrames()
        self.video_processor = VideoProcessor()
        print('=====サブクラスなど完了=====')

    def forward(self, comment_embeddings, masks_stack, hit_likes, title_desc_embedding, video_output_stack):
       

        comment_output_avg = self.comment_processor(comment_embeddings)
        # shape: (batch_size, 2*d)
        

        title_desc_output_avg = self.title_desc_processor(title_desc_embedding)
        # shape: (batch_size, 2*d)
        

        top_j_sim_video_embeddings_list = self.get_j_frames(common_ids_list)
        video_output_avg = self.video_processor(top_j_sim_video_embeddings_list)
        # shape: (1, 2*d)


        weights = F.softmax(torch.stack([self.comment_weight, self.title_desc_weight, self.video_weight]), dim=0)


        combined_output = weights[0] * comment_output_avg + weights[1] * title_desc_output_avg + weights[2] * video_output_avg


        x = self.flatten(combined_output)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.bn1(x)
        x = self.dropout(x)
        
        x = self.fc2(x)
        x = F.relu(x)
        x = self.bn2(x)
        x = self.dropout(x)
        
        x = self.fc3(x)
        x = F.relu(x)
        x = self.bn3(x)
        x = self.dropout(x)
        
        x = self.fc4(x)
        x = torch.sigmoid(x)
        x = x.squeeze()  # 不要な次元を削除して形状を(batch_size,)にする
        return x
    
    
    def training_step(self, batch, batch_idx):
        comment_embeddings, masks, hit_likes, title_desc_embedding, weighted_avg_video_embedding, label = batch
        output = self(comment_embeddings, masks, hit_likes, title_desc_embedding, weighted_avg_video_embedding)
        loss = F.binary_cross_entropy(output, label)
        self.log('train_loss', loss)

        return loss
    
    def validation_step(self, batch, batch_idx):
        comment_embeddings, masks, hit_likes, title_desc_embedding, weighted_avg_video_embedding, label = batch
        output = self(comment_embeddings, masks, hit_likes, title_desc_embedding, weighted_avg_video_embedding)
        
        loss = F.binary_cross_entropy(output, label)
        self.log('val_loss', loss)

        label_predicted = output

        label_predicted = label_predicted.cpu().numpy()
        label = label.cpu().numpy()
        logits = torch.logit(output).cpu().numpy()

        self.validation_step_outputs.append({'label': label, 'label_predicted': label_predicted, 'logits': logits})

        accuracy = accuracy_score(label, label_predicted)
        f1 = f1_score(label, label_predicted)
        precision = precision_score(label, label_predicted)
        recall = recall_score(label, label_predicted)

        self.log('val_acc', accuracy)
        self.log('val_f1', f1)
        self.log('val_precision', precision)
        self.log('val_recall', recall)

        return loss

    def on_validation_epoch_end(self):
        all_label = []
        all_preds = []
        all_pred_probs = []

        for output in self.validation_step_outputs:
            all_label.extend(output['label'])
            all_preds.extend(output['logits'])
            all_pred_probs.extend(output['label_predicted'])
        
        cm = confusion_matrix(all_label, all_preds)
        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                self.log(f'val_cm_{i}_{j}', float(cm[i, j]))

        auc = roc_auc_score(all_label, all_pred_probs)
        self.log('val_AUC', auc)

        fpr, tpr, thresholds = roc_curve(all_label, all_pred_probs)

        plot_roc_curve(fpr, tpr, random_state, self.batch_size, self.max_length, self.fig_save_path, fig_save_name='test_roc')
        plot_fpr_threshold(thresholds, fpr, random_state, self.batch_size, self.max_length, self.fig_save_path, fig_save_name='test_fpr')
        plot_tpr_threshold(thresholds, tpr, random_state, self.batch_size, self.max_length, self.fig_save_path, fig_save_name='test_tpr')

        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        return optimizer

        


# In[26]:


def make_trainer(max_epochs, logger, name, patience):

    early_stop_callback = EarlyStopping(
    monitor='val_loss',    # 監視する値
    min_delta=0.00,        # 変化の最小量
    patience=patience,     # 改善が見られないエポック数
    verbose=False,
    mode='min'             # 'min' は値の減少を監視
    )

    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        mode='min',
        save_top_k=1,
        dirpath='model_checkpoints/',
        filename=name + '-{epoch:02d}-{val_loss:.2f}'
    )

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        devices=1,
        accelerator='gpu',
        logger=logger,
        callbacks=[early_stop_callback, checkpoint_callback],
        enable_progress_bar=True,
        precision='16-mixed'
    )
    return trainer


# In[27]:


def make_kfold_dataloaders(df_drop, dataset, n_splits, batch_size, num_workers, pin_memory=True):
    kfold = StratifiedKFold(n_splits=n_splits, shuffle=True)
    labels = df_drop['label'].to_list()
    dataloaders_list = []

    for train_indices, val_indices in kfold.split(X=np.zeros(len(dataset)), y=labels):
        # Subsetを使って訓練セットと検証セットを作成
        train_subset = Subset(dataset, train_indices)
        val_subset = Subset(dataset, val_indices)

        # DataLoaderを作成
        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, persistent_workers=True, num_workers=num_workers, pin_memory=pin_memory, collate_fn=collate_fn)
        val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, persistent_workers=True, num_workers=num_workers, pin_memory=pin_memory, collate_fn=collate_fn)

        # このイテレーションのデータローダーをリストに追加
        dataloaders_list.append((train_loader, val_loader))
    
    return dataloaders_list


# In[28]:


val_losses = []
accuracies, f1_scores, precisions, recalls = [], [], [], []
cms_0_0, cms_0_1, cms_1_0, cms_1_1 = [], [], [], []
aucs = []


# In[29]:


skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)


# In[30]:


dataset = VideoDataset(df_list, df, tokenizer, bert_model, max_length, comment_batch_size, num_layers, lstm_dropout, lstm_batch_size, lstm_hidden_size, j, frame_batch_size, num_heads, video_batch_size, d=768)


# In[31]:


dataloaders_list = make_kfold_dataloaders(df_drop, dataset, n_splits, batch_size, num_workers)


# In[32]:


for fold, (train_loader, val_loader) in enumerate(dataloaders_list):
    print('========')
    print(f"Fold {fold + 1}")
    print('========')

    model = FakeNewsDetector(tokenizer, bert_model, random_state, max_length, batch_size, num_workers, lr, n_splits, dropout_rate, lstm_dropout, input_size, hidden_dim, lstm_hidden_size, num_layers, bidirectional, num_heads, max_epochs, patience, fig_save_name, name, weight_decay, d=768)

    logger = TensorBoardLogger(
        save_dir="lightning_logs",
        name=name,
        version=f"Fold_{fold+1}"
        )
    
    trainer = make_trainer(max_epochs, logger, name, patience)

    trainer.fit(model, train_loader, val_loader)

    val_results = trainer.validate(model, val_loader)

    val_losses.append(val_results[0]['val_loss'])

    accuracies.append(val_results[0]['val_acc'])
    f1_scores.append(val_results[0]['val_f1'])
    precisions.append(val_results[0]['val_precision'])
    recalls.append(val_results[0]['val_recall'])

    cms_0_0.append(val_results[0]['val_cm_0_0'])
    cms_0_1.append(val_results[0]['val_cm_0_1'])
    cms_1_0.append(val_results[0]['val_cm_1_0'])
    cms_1_1.append(val_results[0]['val_cm_1_1'])

    aucs.append(val_results[0]['val_AUC'])


video_dataset.py

from make_comment_bilstm import BiLSTM, create_batches
from make_embedding import get_comment_embedding, get_title_desc_embedding, cal_cos_sim_video_embedding, initialize_vgg_19, cal_attn_weight_embedding, ThumbFrameDataset
from torch.utils.data import Dataset
import torch
import torch.nn as nn
from torch.nn import MultiheadAttention
from torch.utils.data import DataLoader
import pytorch_lightning as pl

class CommentProcessor(pl.LightningModule):
    def __init__(self, d, num_layers, lstm_dropout, lstm_batch_size, lstm_hidden_size=768//2):
        super(CommentProcessor, self).__init__()
        self.d = d
        # BiLSTMモデルの定義。BiLSTMは双方向のため、隠れ層サイズは2倍になる点に注意。
        self.bilstm = BiLSTM(input_size=768, hidden_size=lstm_hidden_size, num_layers=num_layers, dropout=lstm_dropout)
        self.lstm_batch_size = lstm_batch_size
        # LSTMの出力をさらに変換するための全結合層。2倍の隠れ層サイズから2*d次元へ変換。
        self.comment_fc = nn.Linear(2*lstm_hidden_size, 2*d)
        # このモデルで使用するデバイスをCUDAに設定。全てのサブモジュールもCUDAへ移動される。
    
    def forward(self, comment_embeddings, masks):
        # 入力データとマスクをCUDAデバイスへ明示的に移動。
        comment_embeddings = comment_embeddings.to('cuda')
        masks = masks.to('cuda')
        
        # バッチ処理用に入力データをバッチに分割。
        comment_batches, mask_batches = create_batches(comment_embeddings, masks, self.lstm_batch_size)
        # 出力を格納するリストを初期化。
        output = []
        # 分割したバッチごとに処理を実行。
        for comment_batch, mask_batch in zip(comment_batches, mask_batches):
            # マスクを使用して、各コメントの有効な長さを計算。
            lengths = mask_batch.sum(dim=1).long()
            # BiLSTMにバッチを入力し、出力を取得。入力は適切な形状に変形される。
            lstm_out = self.bilstm(comment_batch, lengths)
            # 出力リストに結果を追加。
            output.append(lstm_out)

        # 全バッチの出力を結合。
        comment_output = torch.cat(output, dim=0)
        # 全結合層を通じて出力の次元を調整。
        comment_output = self.comment_fc(comment_output)
        # コメントごとに平均値を計算し、特徴ベクトルを得る。
        comment_output_avg = torch.mean(comment_output, dim=0)

        return comment_output_avg
        # shape: (batch_size*num_comments*max_length, 2*d)

class TitleDescProcessor(nn.Module):
    def __init__(self, d, df, tokenizer, bert_model, max_length, batch_size):
        super(TitleDescProcessor, self).__init__()
        self.fc = nn.Linear(d, 2*d)
        self.df = df
        self.tokenizer = tokenizer
        self.bert_model = bert_model
        self.max_length = max_length
        self.batch_size = batch_size

    def forward(self, x):
        x = get_title_desc_embedding(self.df, self.tokenizer, self.bert_model, self.max_length, self.batch_size)
        # x shape: (batch_size, num_titles, max_length, embedding_dim)
        # 平均値プーリング
        x = torch.mean(x, dim=2)
        # x shape: (batch_size, num_titles, embedding_dim)
        x = self.fc(x)
        # x shape: (batch_size, num_titles, embedding_dim*2)
        # 平均値プーリング
        x = torch.mean(x, dim=1)

        return x
    
class GetJFrames(nn.Module):
    def __init__(self, frame_batch_size, j, video_batch_size, d):
        super(GetJFrames, self).__init__()
        self.j = j
        self.video_batch_size = video_batch_size
        self.frame_batch_size = frame_batch_size
        self.vgg_19 = initialize_vgg_19(d=d)  # vgg_19の初期化

    def forward(self, common_ids_list):
        self.vgg_19 = self.vgg_19.to('cuda')
        dataset = ThumbFrameDataset(common_ids_list)
        data_loader = DataLoader(dataset, self.video_batch_size)
        top_j_sim_video_embeddings_list = cal_cos_sim_video_embedding(data_loader, self.vgg_19, self.j, self.frame_batch_size)
        
        return top_j_sim_video_embeddings_list
    
class VideoProcessor(nn.Module):
    def __init__(self, video_batch_size, num_heads, d):
        super(VideoProcessor, self).__init__()
        self.attention = MultiheadAttention(embed_dim=d*2, num_heads=num_heads, batch_first=True)
        self.video_batch_size = video_batch_size
        self.video_fc = nn.Linear(2*d, 2*d)

    def forward(self, top_j_sim_video_embeddings_list):
        self.attention = self.attention.to('cuda')
        weighted_avg_video_embeddings = cal_attn_weight_embedding(self.attention, top_j_sim_video_embeddings_list)
        video_output = self.video_fc(weighted_avg_video_embeddings)
        video_output_avg = torch.mean(video_output, dim=1)

        return video_output_avg




class VideoDataset(Dataset):
    def __init__(self, df_list, df, tokenizer, bert_model, max_length, comment_batch_size, num_layers, lstm_dropout, lstm_batch_size, lstm_hidden_size, j, frame_batch_size, num_heads, video_batch_size, d=768):
        self.comment_processor = CommentProcessor(d, num_layers, lstm_dropout,  lstm_batch_size, lstm_hidden_size)  
        self.title_desc_processor = TitleDescProcessor(d, df, tokenizer, bert_model, max_length, batch_size=32)  
        self.get_j_frames = GetJFrames(j, frame_batch_size, video_batch_size=1, d=768)  
        self.video_processor = VideoProcessor(d, num_heads, video_batch_size)

        self.df_list = df_list
        self.tokenizer = tokenizer
        self.bert_model = bert_model
        self.max_length = max_length
        self.comment_batch_size = comment_batch_size

    def __len__(self):
        return len(self.df_list)
    
    def __getitem__(self, idx):
        # 各データをstackして返す
        df = self.df_list[idx]
        # 動画によってコメント数が違う→バッチサイズが異なる→スタックできない→padding
        comment_embeddings = get_comment_embedding(df, self.tokenizer, self.bert_model, self.max_length, self.comment_batch_size)
        comment_output_avg = self.comment_processor(comment_embeddings)
        hit_likes = torch.tensor(df['like_count'].values, dtype=torch.float16)
        title_desc_output_avg = self.title_desc_processor()
        top_j_sim_video_embeddings_list = self.get_j_frames()
        video_output_avg = self.video_processor(top_j_sim_video_embeddings_list)
        
        label = df['label'].values
        label = torch.tensor(label, dtype=torch.float16)
        # 自動的にlabelは(batch_size,)の形状にして渡される

        return comment_output_avg, hit_likes, title_desc_output_avg, video_output_avg, label

    
def collate_fn(batch):
    # バッチ内の全ての要素からコメントテンソルを取得し、最大のコメント数を計算
    max_comments = max([comments.size(0) for comments, _, _, _, _ in batch])
    padded_comments = []
    masks = []

    # バッチ内の各要素に対してパディングとマスク処理を実施
    for comments, hit_likes, title_desc_embeddings, video_output, label in batch:
        pad_size = max_comments - comments.size(0)
        mask = torch.ones(comments.size(0), dtype=torch.bool)
        if pad_size > 0:
            pad_tensor = torch.zeros(pad_size, comments.size(1), comments.size(2), dtype=comments.dtype)
            comments = torch.cat([comments, pad_tensor], dim=0)
            pad_mask = torch.zeros(pad_size, dtype=torch.bool)
            mask = torch.cat([mask, pad_mask], dim=0)
        padded_comments.append(comments)
        masks.append(mask)

    # リストをTensorに変換
    padded_comments_stack = torch.stack(padded_comments, dim=0)
    masks_stack = torch.stack(masks, dim=0)
    hit_likes = torch.stack([hit_likes for _, hit_likes, _, _, _ in batch], dim=0)
    title_desc_embeddings = torch.stack([title_desc_embeddings for _, _, title_desc_embeddings, _, _ in batch], dim=0)
    video_output_stack = torch.stack([video_output for _, _, _, video_output, _ in batch], dim=0)
    labels = torch.stack([label for _, _, _, _, label in batch], dim=0)
    
    return padded_comments_stack, masks_stack, hit_likes, title_desc_embeddings, video_output_stack, labels

この記事が気に入ったらサポートをしてみませんか?