taichi.ad
- class taichi.ad.FwdMode(loss, param, seed=None, clear_gradients=True)
  - clear_seed(self)
  - insert(self, func)
  - recover_kernels(self)
- class taichi.ad.Tape(loss=None, clear_gradients=True, validation=False, grad_check=None)
  - grad(self)
  - insert(self, func, args)
- taichi.ad.clear_all_gradients(gradient_type=SNodeGradType.ADJOINT)
Sets the gradients of all fields to zero.
- taichi.ad.grad_for(primal)
Generates a decorator for the customized gradient function of primal. See grad_replaced() for examples.
- Parameters:
primal (Callable) – The primal function, which must be decorated by grad_replaced().
- Returns:
The decorator used to decorate the customized gradient function.
- Return type:
Callable
- taichi.ad.grad_replaced(func)
A decorator for a Python function that customizes its gradient within Taichi’s autodiff system, e.g. ti.ad.Tape() and kernel.grad().
This decorator forces Taichi’s autodiff system to use a user-defined gradient function for the decorated function. The customized gradient function must be decorated by grad_for().
- Parameters:
func (Callable) – The Python function to be decorated.
- Returns:
The decorated function.
- Return type:
Callable
Example:
>>> @ti.kernel
>>> def multiply(a: ti.float32):
>>>     for I in ti.grouped(x):
>>>         y[I] = x[I] * a
>>>
>>> @ti.kernel
>>> def multiply_grad(a: ti.float32):
>>>     for I in ti.grouped(x):
>>>         x.grad[I] = y.grad[I] / a
>>>
>>> @ti.ad.grad_replaced
>>> def foo(a):
>>>     multiply(a)
>>>
>>> @ti.ad.grad_for(foo)
>>> def foo_grad(a):
>>>     multiply_grad(a)
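To make the pairing of grad_replaced() and grad_for() more concrete, the following is a simplified pure-Python model of the idea: a tape records each decorated call and, on exit, replays the registered gradient functions in reverse order. It is only a sketch under stated assumptions and not Taichi's implementation. MiniTape, the dictionary-based x and y, and the local grad_replaced and grad_for definitions are all hypothetical stand-ins that need no Taichi installation.

```python
# Simplified pure-Python model of a tape with user-defined gradients.
# All names below are hypothetical stand-ins, not Taichi's implementation.
_active_tape = None  # the tape currently recording, if any


class MiniTape:
    """Records decorated calls, then replays their gradients in reverse."""

    def __enter__(self):
        global _active_tape
        self.calls = []
        _active_tape = self
        return self

    def __exit__(self, exc_type, exc, tb):
        global _active_tape
        _active_tape = None
        if exc_type is None:
            # Reverse mode: the last forward call is differentiated first.
            for func, args in reversed(self.calls):
                func.grad(*args)
        return False


def grad_replaced(func):
    """Wrap func so an active tape records it instead of differentiating it."""
    def decorated(*args):
        if _active_tape is not None:
            _active_tape.calls.append((decorated, args))
        return func(*args)

    decorated.grad = None  # filled in later by grad_for
    return decorated


def grad_for(primal):
    """Return a decorator that registers a custom gradient for primal."""
    def decorator(grad_func):
        primal.grad = grad_func
        return grad_func

    return decorator


# Plain dictionaries stand in for fields with a value and an adjoint.
x = {"value": 3.0, "grad": 0.0}
y = {"value": 0.0, "grad": 0.0}


@grad_replaced
def scale(a):
    y["value"] = x["value"] * a


@grad_for(scale)
def scale_grad(a):
    # For y = x * a, the adjoint of x accumulates y.grad * a.
    x["grad"] += y["grad"] * a


y["grad"] = 1.0  # seed the output adjoint
with MiniTape():
    scale(2.0)

print(y["value"], x["grad"])
```

In this model the tape never inspects the body of scale; it only calls whatever function grad_for attached, which is the behavior the description above attributes to the real decorators.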
- taichi.ad.no_grad(func)
A decorator for a Python function that skips gradient calculation within Taichi’s autodiff system, e.g. ti.ad.Tape() and kernel.grad(). This decorator forces Taichi’s autodiff system to use an empty gradient function for the decorated function.
- Parameters:
func (Callable) – The Python function to be decorated.
- Returns:
The decorated function.
- Return type:
Callable
Example:
>>> @ti.kernel
>>> def multiply(a: ti.float32):
>>>     for I in ti.grouped(x):
>>>         y[I] = x[I] * a
>>>
>>> @ti.ad.no_grad
>>> def foo(a):
>>>     multiply(a)
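The phrase "empty gradient function" can be illustrated with a small pure-Python sketch. It assumes, purely for illustration, that a decorated function carries its gradient in a grad attribute. The no_grad defined here and the state dictionary are hypothetical stand-ins and do not reproduce Taichi's code.

```python
import functools


def no_grad(func):
    """Hypothetical stand-in: attach a gradient function that does nothing."""
    @functools.wraps(func)
    def decorated(*args, **kwargs):
        return func(*args, **kwargs)

    def empty_grad(*args, **kwargs):
        # Intentionally empty: nothing is propagated backward.
        return None

    decorated.grad = empty_grad
    return decorated


# A plain dictionary stands in for fields and their adjoints.
state = {"x": 3.0, "y": 0.0, "x_grad": 0.0, "y_grad": 1.0}


@no_grad
def scale(a):
    state["y"] = state["x"] * a


scale(2.0)       # the forward computation still runs
scale.grad(2.0)  # the backward step leaves every adjoint untouched

print(state["y"], state["x_grad"])
```

The forward result is computed as usual, while the adjoint of x stays at its initial value because the attached gradient function has no effect.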