Tengo dos variables, x
y theta
. Estoy tratando de minimizar mi pérdida con respecto a theta
sólo, sino como parte de mi función de pérdida necesito la derivada de una función diferente (f
) con respecto a x
. Este derivado de la misma no es relevante para la minimización, sólo su salida. Sin embargo, cuando la aplicación esta en PyTorch estoy recibiendo un error en tiempo de ejecución.
Un ejemplo mínimo es el siguiente:
# minimal example of two different autograds
import torch
from torch.autograd.functional import jacobian
def f(theta, x):
return torch.sum(theta * x ** 2)
def df(theta, x):
J = jacobian(lambda x: f(theta, x), x)
return J
# example evaluations of the autograd gradient
x = torch.tensor([1., 2.])
theta = torch.tensor([1., 1.], requires_grad = True)
# derivative should be 2*theta*x (same as an analytical)
with torch.no_grad():
print(df(theta, x))
print(2*theta*x)
tensor([2., 4.])
tensor([2., 4.])
# define some arbitrary loss as a fn of theta
loss = torch.sum(df(theta, x)**2)
loss.backward()
da el siguiente error
RuntimeError: el elemento 0 de los tensores no requiere de posgrado y no tiene un grad_fn
Si me ofrecen una analítica de derivados (2*theta*x
), funciona bien:
loss = torch.sum((2*theta*x)**2)
loss.backward()
Es allí una manera de hacer esto en PyTorch? O estoy limitado de alguna manera?
Déjeme saber si alguien necesita más detalles.
PS
Me estoy imaginando la solución es algo similar a la forma en que JAX no autograd, como que es lo que estoy más familiarizado. Lo que quiero decir aquí es que en JAX creo que usted acaba de hacer:
from jax import grad
df = grad(lambda x: f(theta, x))
y, a continuación, df
sólo sería una función que puede ser llamado en cualquier momento. Pero es PyTorch la misma? O es que hay algún conflicto dentro de .backward()
que la causa de este error?
create_graph
argumento, porque no quiero que sea incluido en mi.backward()
de la llamada. En este caso, ¿por qué me da un error? No entiendo el mensaje de error.