Estoy tratando de construir un modelo para la Quora preguntas par de conjunto de datos en el que la salida es binaria, ya sea 1 o 0, pero me sale este error. Sé que la salida de la forma de mi modelo es diferente de la de entrada de la forma, pero no sé cómo solucionarlo. El tamaño del lote es de 16
class Bert_model (nn.Module):
def __init__(self) :
super(Bert_model,self).__init__()
self.bert = BertModel.from_pretrained('bert-base-uncased', return_dict=False)
self.drop_layer = nn.Dropout(.25)
self.output = nn.Linear(self.bert.config.hidden_size,1)
def forward(self,input_ids,attention_mask):
_,o2 = self.bert (input_ids =input_ids , attention_mask = attention_mask )
o2 = self.drop_layer(o2)
return self.output(o2)
model = Bert_model()
loss_fn = nn.BCELoss().to(device)
def train_epoch(
model,
data_loader,
loss_fn,
optimizer,
device,
n_examples
):
model = model.train()
losses = []
correct_predictions = 0
for d in data_loader:
input_ids = d["input_ids"].to(device)
attention_mask = d["attention_mask"].to(device)
targets = d["target"].to(device)
input_ids = input_ids.view(BATCH_SIZE,-1)
attention_mask = attention_mask.view(BATCH_SIZE,-1)
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask
)
_, preds = torch.max(outputs, dim=1)
targets = targets.unsqueeze(-1)
loss = loss_fn(F.softmax(outputs,dim=1), targets)
correct_predictions += torch.sum(preds == targets)
losses.append(loss.item())
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
return correct_predictions.double() / n_examples, np.mean(losses)
El error:
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in
binary_cross_entropy(input, target, weight, size_average, reduce,
reduction) 2913 weight = weight.expand(new_size) 2914
-> 2915 return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum) 2916 2917 ValueError: Using a target
size (torch.Size([2, 1])) that is different to the input size
(torch.Size([16, 1])) is deprecated
BCELoss
conCrossEntropyLoss
que se supone que es para problemas de clasificación binaria.