forked from jittor/jittor
Merge branch 'develop' of https://github.com/Jittor/jittor
This commit is contained in:
commit
6e761f89da
|
@ -39,11 +39,17 @@ VarPtr CublasBatchedMatmulOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
// a [b,n,m] b [b,m,k], c[b,n,k]
|
||||
// c = a*b
|
||||
if (v_index == 0) {
|
||||
// da = dc*b^T
|
||||
return make_cublas_batched_matmul(dout, b, trans_a^0, trans_b^1);
|
||||
if (trans_a)
|
||||
return make_cublas_batched_matmul(b, dout, trans_b, 1);
|
||||
else
|
||||
// da = dc*b^T
|
||||
return make_cublas_batched_matmul(dout, b, 0, trans_b^1);
|
||||
} else {
|
||||
// db = a^T*dc
|
||||
return make_cublas_batched_matmul(a, dout, trans_a^1, trans_b^0);
|
||||
if (trans_b)
|
||||
return make_cublas_batched_matmul(dout, a, 1, trans_a);
|
||||
else
|
||||
// db = a^T*dc
|
||||
return make_cublas_batched_matmul(a, dout, trans_a^1, 0);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -58,7 +58,7 @@
|
|||
|
||||
// CUDA and CUBLAS functions
|
||||
#include <helper_functions.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
|
||||
#ifndef min
|
||||
#define min(a,b) ((a < b) ? a : b)
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cublas.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
|
||||
#ifdef CUBLAS_API_H_
|
||||
// cuBLAS API errors
|
||||
|
|
|
@ -65,7 +65,7 @@
|
|||
#include <assert.h>
|
||||
|
||||
#include <cudnn.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
#include "fp16_dev.h"
|
||||
#include "fp16_emu.h"
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
#include <cudnn.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
|
||||
const char *_cudaGetErrorEnum(cudnnStatus_t error) {
|
||||
return cudnnGetErrorString(error);
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
#include "init.h"
|
||||
#include <cuda_runtime.h>
|
||||
#include <curand.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
#include "curand_random_op.h"
|
||||
#include "curand_warper.h"
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cublas.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
#include <curand.h>
|
||||
|
||||
// cuRAND API errors
|
||||
|
|
|
@ -6,16 +6,12 @@
|
|||
#include "var.h"
|
||||
#include "cutt_transpose_op.h"
|
||||
#include "ops/op_register.h"
|
||||
#include <iostream>
|
||||
|
||||
#ifdef JIT
|
||||
#include "cutt.h"
|
||||
#endif
|
||||
#include "cutt_warper.h"
|
||||
#include "misc/stack_vector.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
static auto make_transpose = get_op_info("cutt_transpose")
|
||||
.get_constructor<VarPtr, Var*, NanoVector>();
|
||||
|
||||
|
@ -58,52 +54,49 @@ VarPtr CuttTransposeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
return make_transpose(dout, reverse);
|
||||
}
|
||||
|
||||
void CuttTransposeOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][DIM=") << JK::hex1(axes.size());
|
||||
for (uint i=0; i<axes.size(); i++)
|
||||
jk << _CS("][AXES") << JK::hex1(axes[i]) << '=' << JK::hex1(i);
|
||||
jk << ']';
|
||||
}
|
||||
unordered_map<string, unsigned int> cutt_plan_cache;
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cuda
|
||||
|
||||
extern unordered_map<string, unsigned int> cutt_plan_cache;
|
||||
|
||||
void CuttTransposeOp::jit_run() {
|
||||
auto* __restrict__ xp = x->ptr<Tx>();
|
||||
auto* __restrict__ yp = y->ptr<Tx>();
|
||||
vector<int> permutation, permutation2;
|
||||
vector<int> y_shape;
|
||||
vector<int> x_shape;
|
||||
@for(i, 0, DIM, permutation.push_back(DIM-1-AXES@i);)
|
||||
@for(i, 0, DIM, permutation2.push_back(permutation[DIM-1-@i@@]);)
|
||||
std::vector<int> reverse;
|
||||
reverse.reserve(permutation2.size());
|
||||
for (uint i=0; i<permutation2.size(); i++)
|
||||
reverse[permutation2[i]] = i;
|
||||
|
||||
@for(i, 0, DIM, x_shape.push_back(x->shape[DIM-1-@i@@]);)
|
||||
|
||||
void CuttTransposeOp::run() {
|
||||
auto* __restrict__ xp = x->mem_ptr;
|
||||
auto* __restrict__ yp = y->mem_ptr;
|
||||
StackVector<int> x_shape;
|
||||
StackVector<int> new_shape, new_axes, trans, reverse;
|
||||
int dim = x->shape.size();
|
||||
for (int i=0; i<dim; i++) {
|
||||
trans[i] = new_shape.size();
|
||||
if (x->shape[i] != 1)
|
||||
new_shape.push_back(x->shape[i]);
|
||||
}
|
||||
for (int i = 0; i < dim; ++i) {
|
||||
if (x->shape[axes[i]] != 1) {
|
||||
new_axes.push_back(trans[axes[i]]);
|
||||
}
|
||||
}
|
||||
dim = new_shape.size();
|
||||
for (int i=0; i<dim; i++)
|
||||
reverse[i] = dim-1-new_axes[dim-1-i];
|
||||
for (int i=0; i<dim; i++)
|
||||
x_shape[i] = new_shape[dim-1-i];
|
||||
if (dim == 1) {
|
||||
checkCudaErrors(cudaMemcpyAsync(yp, xp, x->size, cudaMemcpyDefault, 0));
|
||||
return;
|
||||
}
|
||||
jk.clear();
|
||||
jk << @DIM << ",";
|
||||
for (uint i=0; i<@DIM; i++) jk << x_shape[i] << ",";
|
||||
for (uint i=0; i<@DIM; i++) jk << reverse[i] << ",";
|
||||
jk << sizeof(Tx) << ".";
|
||||
jk << dim << ',';
|
||||
for (int i=0; i<dim; i++) jk << x_shape[i] << ',';
|
||||
for (int i=0; i<dim; i++) jk << reverse[i] << ',';
|
||||
jk << x->dtype().dsize() << '.';
|
||||
auto iter = cutt_plan_cache.find(jk.to_string());
|
||||
LOGvvv << "Run cutt_transpose with key:" << jk.to_string();
|
||||
|
||||
if (iter!=cutt_plan_cache.end()){
|
||||
cuttExecute(iter->second, xp, yp);
|
||||
} else {
|
||||
cuttHandle plan;
|
||||
cuttPlan(&plan, @DIM, x_shape.data(), reverse.data(), sizeof(Tx), 0);
|
||||
cuttPlan(&plan, dim, x_shape.data(), reverse.data(), x->dtype().dsize(), 0);
|
||||
cutt_plan_cache[jk.to_string()] = plan;
|
||||
cuttExecute(plan, xp, yp);
|
||||
}
|
||||
}
|
||||
#endif // JIT_cuda
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -19,7 +19,7 @@ struct CuttTransposeOp : Op {
|
|||
const char* name() const override { return "cutt_transpose"; }
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
void infer_shape() override;
|
||||
DECLARE_jit_run;
|
||||
void run() override;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -101,11 +101,17 @@ const char *_cudaGetErrorEnum(NppStatus error);
|
|||
#endif
|
||||
#endif
|
||||
|
||||
namespace jittor {
|
||||
extern bool peek_logged;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void peek(T result, char const *const func, const char *const file,
|
||||
int const line) {
|
||||
if (result) {
|
||||
// DEVICE_RESET
|
||||
if (jittor::peek_logged) return;
|
||||
jittor::peek_logged = 1;
|
||||
LOGe << "Peek CUDA error at" << file >> ":" >> line << " code="
|
||||
>> static_cast<unsigned int>(result) >> "(" << _cudaGetErrorEnum(result) << ")"
|
||||
<< func;
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
|
||||
#include <cuda_runtime.h>
|
||||
#include <nccl.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
|
||||
#include <nccl.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
#include "nccl_warper.h"
|
||||
#include "ops/op_register.h"
|
||||
namespace jittor {
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
|
||||
#include <nccl.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
#include "nccl_warper.h"
|
||||
#include "ops/op_register.h"
|
||||
namespace jittor {
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
|
||||
#include <nccl.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
#include "nccl_warper.h"
|
||||
#include "ops/op_register.h"
|
||||
namespace jittor {
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
// These are CUDA Helper functions for initialization and error checking
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
|
||||
|
||||
#ifdef _CUFFT_H_
|
||||
|
|
|
@ -884,6 +884,11 @@ Var.__int__ = to_int
|
|||
Var.__float__ = to_float
|
||||
Var.__bool__ = to_bool
|
||||
|
||||
def format(v, spec):
|
||||
return v.item().__format__(spec)
|
||||
Var.__format__ = format
|
||||
|
||||
|
||||
int = int32
|
||||
Var.int = Var.int32
|
||||
float = float32
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
import jittor as jt
|
||||
import numpy as np
|
||||
from jittor import pool
|
||||
from collections.abc import Sequence
|
||||
|
||||
def argmax_pool(x, size, stride, padding=0):
|
||||
return pool.pool(x, size, 'maximum', padding, stride)
|
||||
|
@ -180,8 +181,14 @@ jt.Var.__setitem__ = setitem
|
|||
def getitem(x, slices):
|
||||
if isinstance(slices, jt.Var) and slices.dtype == "bool":
|
||||
return getitem(x, slices.where())
|
||||
if isinstance(slices, list):
|
||||
slices = tuple(slices)
|
||||
if isinstance(slices, Sequence):
|
||||
ss = []
|
||||
for s in slices:
|
||||
if isinstance(s, jt.Var) and s.dtype == "bool":
|
||||
ss.extend(s.where())
|
||||
else:
|
||||
ss.append(s)
|
||||
slices = tuple(ss)
|
||||
return x.getitem(slices)
|
||||
|
||||
def setitem(x, slices, value):
|
||||
|
|
|
@ -12,6 +12,35 @@ import numpy as np
|
|||
import math
|
||||
from collections.abc import Sequence,Iterable
|
||||
|
||||
def __copy__(x):
|
||||
return x.copy().detach()
|
||||
jt.Var.__copy__ = __copy__
|
||||
|
||||
def __deepcopy__(x,memo):
|
||||
result = x.copy().detach()
|
||||
memo[id(x)]=result
|
||||
return result
|
||||
jt.Var.__deepcopy__ = __deepcopy__
|
||||
|
||||
def __len__(x):
|
||||
return x.shape[0]
|
||||
jt.Var.__len__ = __len__
|
||||
|
||||
def __iter__(x):
|
||||
result = []
|
||||
for i in range(x.shape[0]):
|
||||
result.append(x[i])
|
||||
return result.__iter__()
|
||||
jt.Var.__iter__ = __iter__
|
||||
|
||||
def all(x,dim):
|
||||
return x.all_(dim).bool()
|
||||
jt.Var.all = all
|
||||
|
||||
def any(x,dim):
|
||||
return x.any_(dim).bool()
|
||||
jt.Var.any = any
|
||||
|
||||
|
||||
def repeat(x, *shape):
|
||||
r'''
|
||||
|
@ -47,10 +76,24 @@ def repeat(x, *shape):
|
|||
x = x.broadcast(x_shape)
|
||||
elif len_x_shape > len_shape:
|
||||
rep_shape = (len_x_shape - len_shape) * [1] + shape
|
||||
|
||||
reshape_shape = []
|
||||
broadcast_shape = []
|
||||
for x_s,r_s in zip(x_shape,rep_shape):
|
||||
reshape_shape.append(1)
|
||||
reshape_shape.append(x_s)
|
||||
|
||||
broadcast_shape.append(r_s)
|
||||
broadcast_shape.append(1)
|
||||
|
||||
x = x.reshape(reshape_shape)
|
||||
x = x.broadcast(broadcast_shape)
|
||||
|
||||
tar_shape = (np.array(x_shape) * np.array(rep_shape)).tolist()
|
||||
dims = []
|
||||
for i in range(len(tar_shape)): dims.append(f"i{i}%{x_shape[i]}")
|
||||
return x.reindex(tar_shape, dims)
|
||||
|
||||
x = x.reshape(tar_shape)
|
||||
return x
|
||||
|
||||
jt.Var.repeat = repeat
|
||||
|
||||
def chunk(x, chunks, dim=0):
|
||||
|
@ -326,9 +369,8 @@ def unique(x):
|
|||
'''
|
||||
x = x.reshape(-1)
|
||||
_,x = jt.argsort(x)
|
||||
index2 = [i for i in range(1,x.shape[0])]
|
||||
index1 = [i for i in range(x.shape[0]-1)]
|
||||
y = x[1:][x[index2] != x[index1]]
|
||||
index,= jt.index((x.shape[0],))
|
||||
y = x[1:][x[index[1:]] != x[index[:-1]]]
|
||||
x = jt.contrib.concat([x[:1],y],dim=0)
|
||||
return x
|
||||
|
||||
|
@ -401,12 +443,6 @@ def log2(x):
|
|||
|
||||
jt.Var.log2 = log2
|
||||
|
||||
def item(x):
|
||||
assert x.ndim==1 and x.shape[0]==1
|
||||
return x.numpy().item()
|
||||
|
||||
jt.Var.item = item
|
||||
|
||||
def meshgrid(*tensors):
|
||||
r'''
|
||||
Take N tensors, each of which can be 1-dimensional vector, and create N n-dimensional grids,
|
||||
|
|
|
@ -264,17 +264,29 @@ class L1Loss(Module):
|
|||
def execute(self, output, target):
|
||||
return l1_loss(output, target)
|
||||
|
||||
class BCEWithLogitsLoss(Module):
|
||||
def __init__(self, weight=None, size_average=True):
|
||||
self.sigmoid = Sigmoid()
|
||||
self.bce = BCELoss(weight, size_average)
|
||||
def execute(self, output, target):
|
||||
output = self.sigmoid(output)
|
||||
output = self.bce(output, target)
|
||||
return output
|
||||
def binary_cross_entropy_with_logits(output, target, weight=None, pos_weight=None, size_average=True):
|
||||
max_val = jt.clamp(-output,min_v=0)
|
||||
if pos_weight is not None:
|
||||
log_weight = (pos_weight-1)*target + 1
|
||||
loss = (1-target)*output+(log_weight*(((-max_val).exp()+(-output - max_val).exp()).log()+max_val))
|
||||
else:
|
||||
loss = (1-target)*output+max_val+((-max_val).exp()+(-output -max_val).exp()).log()
|
||||
if weight is not None:
|
||||
loss *=weight
|
||||
|
||||
def binary_cross_entropy_with_logits(input, target, weight=None, size_average=True):
|
||||
return BCEWithLogitsLoss(weight, size_average)(input, target)
|
||||
if size_average:
|
||||
return loss.mean()
|
||||
else:
|
||||
return loss.sum()
|
||||
|
||||
class BCEWithLogitsLoss(Module):
|
||||
def __init__(self, weight=None, pos_weight=None, size_average=True):
|
||||
self.pos_weight = pos_weight
|
||||
self.weight = weight
|
||||
self.size_average = size_average
|
||||
|
||||
def execute(self, output, target):
|
||||
return binary_cross_entropy_with_logits(output,target,self.weight,self.pos_weight,self.size_average)
|
||||
|
||||
def softmax(x, dim = None):
|
||||
if dim is None:
|
||||
|
@ -408,13 +420,14 @@ class LayerNorm(Module):
|
|||
|
||||
def execute(self, x):
|
||||
dims = [-i for i in range(len(self.normalized_shape), 0, -1)]
|
||||
xmean = jt.mean(x, dims=dims)
|
||||
x2mean = jt.mean(x*x, dims=dims)
|
||||
xmean = jt.mean(x, dims=dims, keepdims=1)
|
||||
x2mean = jt.mean(x*x, dims=dims, keepdims=1)
|
||||
|
||||
xvar = (x2mean-xmean*xmean).maximum(0.0)
|
||||
w = self.weight / jt.sqrt(xvar+self.eps)
|
||||
b = self.bias - xmean * w
|
||||
return x * w.broadcast(x, dims) + b.broadcast(x, dims)
|
||||
return x * w + b
|
||||
|
||||
|
||||
LayerNorm2d = LayerNorm1d = LayerNorm
|
||||
|
||||
|
|
|
@ -210,3 +210,64 @@ class Adam(Optimizer):
|
|||
v.update(b1 * v + (1-b1) * g * g)
|
||||
step_size = lr * jt.sqrt(1-b1**n) / (1-b0 ** n)
|
||||
p.update(p - m * step_size / (jt.sqrt(v) + eps))
|
||||
|
||||
|
||||
class LRScheduler:
|
||||
def __init__(self,optimizer, last_epoch=-1):
|
||||
assert isinstance(optimizer,Optimizer)
|
||||
self.optimizer = optimizer
|
||||
|
||||
if last_epoch==-1:
|
||||
for gp in optimizer.param_groups:
|
||||
gp.setdefault('initial_lr',gp.get('lr',optimizer.lr))
|
||||
else:
|
||||
for gp in optimizer.param_groups:
|
||||
assert 'initial_lr' in gp
|
||||
|
||||
self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
|
||||
self.last_epoch = last_epoch
|
||||
self.optimizer._step_count = 0
|
||||
self._step_count = 0
|
||||
self.step()
|
||||
|
||||
def get_lr(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_last_lr(self):
|
||||
return self._last_lr
|
||||
|
||||
def step(self,epoch=None):
|
||||
self._step_count += 1
|
||||
|
||||
if epoch is None:
|
||||
self.last_epoch += 1
|
||||
values = self.get_lr()
|
||||
else:
|
||||
self.last_epoch = epoch
|
||||
values = self.get_lr()
|
||||
|
||||
for i, data in enumerate(zip(self.optimizer.param_groups, values)):
|
||||
param_group, lr = data
|
||||
param_group['lr'] = lr
|
||||
|
||||
self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
|
||||
|
||||
|
||||
class LambdaLR(LRScheduler):
|
||||
|
||||
def __init__(self, optimizer, lr_lambda, last_epoch=-1):
|
||||
if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
|
||||
self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
|
||||
else:
|
||||
if len(lr_lambda) != len(optimizer.param_groups):
|
||||
raise ValueError("Expected {} lr_lambdas, but got {}".format(len(optimizer.param_groups), len(lr_lambda)))
|
||||
|
||||
self.lr_lambdas = list(lr_lambda)
|
||||
|
||||
super(LambdaLR, self).__init__(optimizer, last_epoch)
|
||||
|
||||
|
||||
|
||||
def get_lr(self):
|
||||
return [base_lr * lmbda(self.last_epoch)
|
||||
for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
|
|
@ -30,7 +30,7 @@ class TestCuttTransposeOp(unittest.TestCase):
|
|||
for perm in perms:
|
||||
with jt.log_capture_scope(
|
||||
log_silent=1,
|
||||
log_v=0, log_vprefix="op.cc=100"
|
||||
log_v=0, log_vprefix="cutt=100"
|
||||
) as raw_log:
|
||||
if perm:
|
||||
x = np.transpose(a, perm)
|
||||
|
@ -39,7 +39,7 @@ class TestCuttTransposeOp(unittest.TestCase):
|
|||
x = np.transpose(a)
|
||||
y = jt.transpose(a).data
|
||||
self.assertEqual(x.shape, y.shape)
|
||||
logs = find_log_with_re(raw_log, "(Jit op key (not )?found: " + "cutt_transpose" + ".*)")
|
||||
logs = find_log_with_re(raw_log, "(Run cutt_transpose with key.*)")
|
||||
if perm is None:
|
||||
continue
|
||||
last = -1
|
||||
|
@ -53,7 +53,7 @@ class TestCuttTransposeOp(unittest.TestCase):
|
|||
last = perm[i]
|
||||
if not in_order:
|
||||
assert len(logs)==1
|
||||
assert (x==y).all(), f"\n{x}\n{y}"
|
||||
assert (x==y).all(), f"\n{x}\n{y}\n{perm}\n{a.shape}"
|
||||
|
||||
ia = [gen_data([5, 7]), gen_data([2,2,2]), gen_data([2,3,4,5]), gen_data([5,3]), gen_data([3,1,5,3,1])]
|
||||
for a in ia: check(a)
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include <functional>
|
||||
#ifdef HAS_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
#include "mem/allocator/cuda_dual_allocator.h"
|
||||
#include "event_queue.h"
|
||||
#endif
|
||||
|
@ -446,6 +446,8 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
|||
// record trace data
|
||||
if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var==2)) {
|
||||
trace_data.record_execution(op, is_fused_op, jkl);
|
||||
checkCudaErrors(cudaDeviceSynchronize());
|
||||
|
||||
}
|
||||
LOGvvv << "Finished Op(" >> op->name() << rid >>
|
||||
"/" >> queue.size() >> ") output:" << op->outputs();
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
#include "mem/allocator.h"
|
||||
#ifdef HAS_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
#endif
|
||||
|
||||
namespace jittor {
|
||||
|
|
|
@ -177,7 +177,8 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
|
|||
Var* dout = grads[id];
|
||||
trace_grad_op = op;
|
||||
VarPtr dvar = make_grad(op, out, dout, var, index);
|
||||
if (dvar && dvar->num>=0 && var->num)
|
||||
if (dvar && dvar->num>=0 && var->num>0)
|
||||
// var->num == 0 represents a any match var
|
||||
ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size())
|
||||
<< "dvar" << dvar << "var" << var;
|
||||
if (!grad)
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
// ***************************************************************
|
||||
#ifdef HAS_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
#endif
|
||||
#include <random>
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
// ***************************************************************
|
||||
#ifdef HAS_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
#include "mem/allocator/cuda_device_allocator.h"
|
||||
|
||||
namespace jittor {
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
#include <mutex>
|
||||
#include <cstring>
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
#include "misc/cuda_flags.h"
|
||||
#include "var.h"
|
||||
#include "mem/allocator.h"
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
// ***************************************************************
|
||||
#ifdef HAS_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
#include "mem/allocator/cuda_host_allocator.h"
|
||||
|
||||
namespace jittor {
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
// ***************************************************************
|
||||
#ifdef HAS_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
#include "mem/allocator/cuda_managed_allocator.h"
|
||||
|
||||
namespace jittor {
|
||||
|
|
|
@ -17,6 +17,7 @@ struct StackVector {
|
|||
inline T& front() { return a[0]; }
|
||||
inline T& back() { return a[n-1]; }
|
||||
inline int size() { return n;}
|
||||
inline T* data() { return a;}
|
||||
inline StackVector(int n=0) : n(n) {}
|
||||
|
||||
struct Iter {
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
// ***************************************************************
|
||||
#ifdef HAS_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
#include "mem/allocator.h"
|
||||
#include "mem/allocator/cuda_dual_allocator.h"
|
||||
#include "event_queue.h"
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
// ***************************************************************
|
||||
#ifdef HAS_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
#endif
|
||||
#include <algorithm>
|
||||
#include "var.h"
|
||||
|
|
|
@ -10,7 +10,8 @@
|
|||
#include "ops/copy_op.h"
|
||||
#ifdef HAS_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
#include "misc/cuda_flags.h"
|
||||
#endif
|
||||
|
||||
namespace jittor {
|
||||
|
@ -36,14 +37,14 @@ void CopyOp::run() {
|
|||
auto size = x->size;
|
||||
auto x_ptr = x->mem_ptr;
|
||||
auto y_ptr = outputs().front()->mem_ptr;
|
||||
if (flags.get(NodeFlags::_cpu)) {
|
||||
#ifdef HAS_CUDA
|
||||
if (flags.get(NodeFlags::_cuda)) {
|
||||
checkCudaErrors(cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDefault, 0));
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
std::memcpy(y_ptr, x_ptr, size);
|
||||
}
|
||||
#ifdef HAS_CUDA
|
||||
else {
|
||||
checkCudaErrors(cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDefault, 0));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
// ***************************************************************
|
||||
#ifdef HAS_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
#include <mutex>
|
||||
#include "misc/cuda_flags.h"
|
||||
#include "mem/allocator/sfrl_allocator.h"
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
#include "ops/op_register.h"
|
||||
#ifdef JIT_cuda
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
#endif
|
||||
#ifndef JIT
|
||||
#include "misc/stack_vector.h"
|
||||
|
|
|
@ -34,9 +34,9 @@ unordered_set<string> reduce_ops = {
|
|||
"add",
|
||||
// @pybind(prod, product, reduce_multiply)
|
||||
"multiply",
|
||||
// @pybind(reduce_logical_and, all)
|
||||
// @pybind(reduce_logical_and, all_)
|
||||
"logical_and",
|
||||
// @pybind(reduce_logical_or, any)
|
||||
// @pybind(reduce_logical_or, any_)
|
||||
"logical_or",
|
||||
"logical_xor",
|
||||
"bitwise_and",
|
||||
|
@ -65,7 +65,8 @@ ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims)
|
|||
reduce_mask |= 1<<dim;
|
||||
}
|
||||
}
|
||||
if (x->dtype() == ns_bool && ns == ns_add)
|
||||
// if (x->dtype() == ns_bool && ns == ns_add)
|
||||
if (x->dtype() == ns_bool)
|
||||
y = create_output(nullptr, ns_int32);
|
||||
else
|
||||
y = create_output(nullptr, binary_dtype_infer(ns, x, x));
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
#include "ops/binary_op_defs.h"
|
||||
#ifdef JIT_cuda
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
#endif
|
||||
#else
|
||||
#include "ops/op_register.h"
|
||||
|
@ -69,7 +69,7 @@ void SetitemOp::infer_shape() {
|
|||
for (int i=0; i<data_dim; i++) {
|
||||
int j = i - data_dim + out_shape.size();
|
||||
if (!(data_shape[i]==1 && out_shape[j]!=-1)) {
|
||||
CHECK(data_shape[i]<0 || data_shape[i]==out_shape[j])
|
||||
CHECK(data_shape[i]<0 || out_shape[j]<0 || data_shape[i]==out_shape[j])
|
||||
<< "Data shape not match" << data_shape << out_shape;
|
||||
bmask |= 1<<j;
|
||||
}
|
||||
|
|
|
@ -40,38 +40,8 @@ TransposeOp::TransposeOp(Var* x, NanoVector axes_) : x(x), axes(axes_) {
|
|||
.get_constructor<VarPtr, Var*, NanoVector>();
|
||||
}
|
||||
if (cutt_transpose) {
|
||||
bool need_reshape = false;
|
||||
int dims = x->shape.size();
|
||||
vector<int64> in_axes;
|
||||
vector<int64> in_shape;
|
||||
vector<int64> out_shape;
|
||||
vector<int64> trans;
|
||||
int cnt = 0;
|
||||
for (int i = 0; i < dims; ++i) {
|
||||
if (x->shape[i] == 1) {
|
||||
need_reshape = true;
|
||||
trans.push_back(-1);
|
||||
} else {
|
||||
trans.push_back(cnt);
|
||||
cnt += 1;
|
||||
in_shape.push_back(x->shape[i]);
|
||||
}
|
||||
out_shape.push_back(x->shape[axes[i]]);
|
||||
}
|
||||
for (int i = 0; i < dims; ++i) {
|
||||
if (x->shape[axes[i]] != 1) {
|
||||
in_axes.push_back(trans[axes[i]]);
|
||||
}
|
||||
}
|
||||
if (need_reshape) {
|
||||
auto x1 = make_reshape(x, NanoVector(in_shape));
|
||||
auto x2 = cutt_transpose(x1, in_axes);
|
||||
auto x3 = make_reshape(x2, NanoVector(out_shape));
|
||||
forward(x3);
|
||||
} else {
|
||||
auto var = cutt_transpose(x, axes);
|
||||
forward(var);
|
||||
}
|
||||
auto var = cutt_transpose(x, axes);
|
||||
forward(var);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
#include <dlfcn.h>
|
||||
#ifdef HAS_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
#endif
|
||||
#include "misc/cuda_flags.h"
|
||||
#include "profiler/profiler.h"
|
||||
|
|
|
@ -164,6 +164,19 @@ static vector<Stack> get_stack_info() {
|
|||
(int)PyFrame_GetLineNumber(prev_f)});
|
||||
}
|
||||
}
|
||||
if (stacks.size() == 0) {
|
||||
auto m = std::min(3,n);
|
||||
for (int i=0; i<m; i++) {
|
||||
auto f = frames[n-m+i];
|
||||
auto s = to_string(f->f_code->co_filename);
|
||||
auto num = (int)PyFrame_GetLineNumber(f);
|
||||
stacks.emplace_back(Stack{
|
||||
s+":"+S(num),
|
||||
"",
|
||||
s,
|
||||
num});
|
||||
}
|
||||
}
|
||||
return stacks;
|
||||
}
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
// ***************************************************************
|
||||
#ifdef HAS_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
#include "mem/allocator.h"
|
||||
#include "mem/allocator/cuda_dual_allocator.h"
|
||||
#include "event_queue.h"
|
||||
|
|
|
@ -23,7 +23,7 @@ static void push_py_object_pickle(RingBuffer* rb, PyObject* obj, uint64& __restr
|
|||
ASSERT(0 == PyBytes_AsStringAndSize(ret.obj, &s, &size));
|
||||
rb->push_t<int64>(size, offset);
|
||||
rb->push(size, offset);
|
||||
LOGir << string(rb->get_ptr(size, offset), size);
|
||||
// LOGir << string(rb->get_ptr(size, offset), size);
|
||||
std::memcpy(rb->get_ptr(size, offset), s, size);
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -13,7 +13,8 @@
|
|||
#include "utils/mwsr_list.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
||||
bool peek_logged = 0;
|
||||
typedef uint32_t uint;
|
||||
using string = std::string;
|
||||
using stringstream = std::stringstream;
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
#include <sstream>
|
||||
#ifdef HAS_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "helper_cuda.h"
|
||||
#endif
|
||||
#include "var_holder.h"
|
||||
#include "var.h"
|
||||
|
|
Loading…
Reference in New Issue