Commit Graph

47 Commits

Author SHA1 Message Date
Zheng Li 4a52aeb437
bert attention mask (#1934)
* bert attention mask

* Allow for using None as a mask.

* Revert part of the changes so that the proper default mask applies.

* Cosmetic change.

* Another cosmetic tweak.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-08-01 08:26:19 +02:00
Juarez Bochi 9bd94c1ffa
Speed up bert with approx gelu (#1410) 2023-12-06 17:46:37 +01:00
Laurent Mazare 8a82d623e5
Handle LongStorage in pytorch checkpoints. (#1152) 2023-10-22 18:34:36 +01:00
Laurent Mazare bb3471ea31
Adapt more examples to the updated safetensor api. (#947)
* Simplify the safetensor usage.

* Convert more examples.

* Move more examples.

* Adapt stable-diffusion.
2023-09-23 21:26:03 +01:00
Laurent Mazare d3f05eae8c
Move some models to candle-transformers so that it's easier to re-use. (#794)
* Move some models to candle-transformers so that they can be shared.

* Also move falcon.

* Move Llama.

* Move whisper (partial).
2023-09-10 09:40:27 +01:00
Nicolas Patry 1aca6fa291 Upgrading hf-hub. 2023-08-29 14:18:54 +02:00
Laurent Mazare a1812f934f
Add a yolo-v3 example. (#528)
* Add a couple functions required for yolo.

* Add the yolo-v3 example.

* Add minimum and maximum.

* Use the newly introduced maximum.

* Cuda support for min/max + add some testing.

* Allow for more tests to work with accelerate.

* Fix a typo.
2023-08-20 18:19:37 +01:00
Laurent Mazare c84883ecf2
Add a cuda kernel for upsampling. (#441)
* Add a cuda kernel for upsampling.

* Update for the latest tokenizers version.
2023-08-14 13:12:17 +01:00
Laurent Mazare 385f0d261c
Normalize embeddings in the bert example. (#390) 2023-08-10 13:05:55 +01:00
Nicolas Patry ca479a873e Upgrading hf-hub to `0.2.0` (Modified API to not pass the Repo around
all the time)
2023-07-27 20:05:02 +02:00
Laurent Mazare 43c7223292
Rename the .r functions to .dims so as to be a bit more explicit. (#220) 2023-07-22 10:39:27 +01:00
Nicolas Patry 439321745a Removing `candle-hub` internal to extract into `hf-hub` standalone. 2023-07-19 15:04:38 +02:00
Laurent Mazare ff61a42ad7
Use mkl to accelerate binary ops. (#190)
* Vectorized binary ops with mkl.

* Improve the binary op mkl support.

* Push the support for mkl binary ops.

* Proper vectorization of binary ops.

* Proper mkl'isation when broadcasting binary ops.
2023-07-18 12:04:39 +01:00
Laurent Mazare f0cccd08f0
Bert tracing (#184)
* Add some tracing to bert.

* More tracing.

* Add a flag for tracing.
2023-07-17 19:40:42 +01:00
Laurent Mazare 66750f9827
Add some 'cuda-if-available' helper function. (#172) 2023-07-15 08:25:15 +01:00
Nicolas Patry 4ed56d7861 Removing cuda default.
Seems very important for a lot of exploring users usually on laptop
without GPUs.

Adding more README instructions in a follow up.
2023-07-14 16:52:15 +02:00
Laurent Mazare a2f72edc0d
Simplify the parameters used by sum and sum_keepdim. (#165) 2023-07-14 08:22:08 +01:00
Laurent Mazare 2bfa791336
Use the same default as pytorch for sum. (#164) 2023-07-13 21:32:32 +01:00
Laurent Mazare 50b0946a2d
Tensor mutability (#154)
* Working towards tensor mutability.

* Use a ref-cell to provide tensor mutability.
2023-07-13 11:04:40 +01:00
Laurent Mazare 674eb35e10
Remove some dead-code pragmas. (#137) 2023-07-11 09:33:59 +01:00
Laurent Mazare b46c28a2ac
VarBuilder path creation (#131)
* Use a struct for the safetensor+routing.

* Group the path and the var-builder together.

* Fix for the empty path case.
2023-07-10 22:37:34 +01:00
Laurent Mazare 1aa7fbbc33
Move the var-builder in a central place. (#130) 2023-07-10 20:49:50 +01:00
Laurent Mazare b06e1a7e54
[nn] Move the Embedding and Activation parts. (#116)
* Share the Embedding and Activation parts.

* Tweak some activations.
2023-07-10 10:24:52 +01:00
Laurent Mazare 9ce0f1c010
Sketch the candle-nn crate. (#115)
* Sketch the candle-nn crate.

* Tweak the cuda dependencies.

* More cuda tweaks.
2023-07-10 08:50:09 +01:00
Nicolas Patry 0a2c82e301
Merge pull request #92 from LaurentMazare/sync_hub
Creating new sync Api for `candle-hub`.
2023-07-07 00:10:47 +02:00
Nicolas Patry 115629fe08 Creating new sync Api for `candle-hub`.
- `api::Api` -> `api::tokio::api` (And created new `api::sync::Api`).
- Remove `tokio` from all our examples.
- Using similar codebase for now instead of ureq (for simplicity).
2023-07-06 15:15:25 +02:00
Nicolas Patry 3f291bdf9d Enabling `roberta` for the example (it's the same model as Bert, with
just different naming.)
2023-07-06 13:25:21 +02:00
Laurent Mazare c297a50960
Add mkl support for matrix multiply. (#86)
* Fix some rebase issues.

* Use mkl instead.

* Use mkl in bert.

* Add the optional mkl feature.

* Conditional compilation based on the mkl feature.

* Add more mkl support.
2023-07-06 11:05:05 +01:00
laurent 2c3d871b2e Add a simpler way to specify the dim index for some ops. 2023-07-05 20:22:43 +01:00
laurent 174e57d216 Use avg pooling before the cosine similarity. 2023-07-05 17:05:50 +01:00
laurent 914e84deec Add some sentence similarity comparision to the bert example. 2023-07-05 16:49:57 +01:00
Nicolas Patry d8f75ceeaa Some polish. 2023-07-05 07:41:14 +00:00
Nicolas Patry 963c75cb89 Adding offline mode. 2023-07-05 07:19:57 +00:00
Nicolas Patry 43a007cba4 Upgrading bert example to work with `bert-base-uncased`.
- Always take weights from the hub
- Optional `model_id` + `revision` to use safetensors version
  potentially
- Optional loading for `bert-base-uncased` (`weight` vs `gamma`).
- Take the config from the hub.
2023-07-04 14:12:14 +00:00
laurent a57b314780 Add a batch dimension on the bert example. 2023-07-04 06:10:52 +01:00
laurent b6d179cc1c Allow for batch dimensions in the embedding layer. 2023-07-03 18:37:40 +01:00
laurent 9784d1ed9f Minor tweaks. 2023-07-03 18:31:55 +01:00
laurent 5524ca29cc Remove the fixed length hack. 2023-07-03 17:13:23 +01:00
laurent 1ea6690557 Bugfix for transpose. 2023-07-03 17:06:23 +01:00
laurent a7f03a7bb6 Fix the layer norm to properly handle bias. 2023-07-03 16:45:03 +01:00
laurent f379b8feae Get some embeddings out. 2023-07-03 16:11:16 +01:00
laurent 54850e7525 Get the tensors to be loaded properly. 2023-07-03 15:53:31 +01:00
laurent ad52b0377c Add the varbuilder + check shapes. 2023-07-03 15:32:20 +01:00
laurent f74bddca31 Model creation. 2023-07-03 14:09:46 +01:00
laurent 12ac9e1460 Complete (?) the forward pass. 2023-07-03 13:33:32 +01:00
laurent d796945ad8 Add more to the forward pass. 2023-07-03 13:04:41 +01:00
laurent 2309c5fac5 Boilerplate code for Bert. 2023-07-03 12:17:06 +01:00