diff --git a/.buildkite/documentation.yml b/.buildkite/documentation.yml index 9895b6ef15..2a4398a12a 100644 --- a/.buildkite/documentation.yml +++ b/.buildkite/documentation.yml @@ -2,9 +2,9 @@ steps: - group: ":open_book: Build & Deploy Documentation" if: build.message !~ /\[skip docs\]/ && !build.pull_request.draft steps: - - label: "Tutorial Build [%N/%t]" - key: "tutorial-build" - parallelism: 6 + - label: "Tutorial Build [%N/%t] CUDA Runners" + key: "tutorial-build-cuda" + parallelism: 4 plugins: - JuliaCI/julia#v1: version: "1" @@ -13,7 +13,9 @@ steps: dirs: - src - ext - command: julia --code-coverage=user --color=yes --project=docs docs/tutorials.jl + command: julia --code-coverage=user --color=yes --project=docs --threads=auto docs/tutorials.jl + env: + TUTORIAL_BACKEND_GROUP: "CUDA" agents: queue: "juliagpu" cuda: "*" @@ -22,10 +24,40 @@ steps: - "docs/src/tutorials/intermediate/**/*" - "docs/src/tutorials/advanced/**/*" - "tutorial_deps/*" + - "**/*.cov" + timeout_in_minutes: 60 + + - label: "Tutorial Build [%N/%t] CPU Runners" + if: build.message !~ /\[skip docs\]/ && !build.pull_request.draft + key: "tutorial-build-cpu" + parallelism: 4 + plugins: + - JuliaCI/julia#v1: + version: "1" + # - JuliaCI/julia-coverage#v1: + # codecov: true + # dirs: + # - src + # - ext + command: julia --code-coverage=user --color=yes --project=docs --threads=auto docs/tutorials.jl + env: + TUTORIAL_BACKEND_GROUP: "CPU" + agents: + queue: "juliaecosystem" + os: "linux" + arch: "x86_64" + artifact_paths: + - "docs/src/tutorials/beginner/**/*" + - "docs/src/tutorials/intermediate/**/*" + - "docs/src/tutorials/advanced/**/*" + - "tutorial_deps/*" + - "**/*.cov" timeout_in_minutes: 60 - label: "Final Documentation Build" - depends_on: [tutorial-build] + depends_on: + - "tutorial-build-cuda" + - "tutorial-build-cpu" plugins: - JuliaCI/julia#v1: version: "1" @@ -65,4 +97,4 @@ env: JULIA_NUM_THREADS: 4 GKSwstype: "100" # https://discourse.julialang.org/t/generation-of-documentation-fails-qt-qpa-xcb-could-not-connect-to-display/60988 SECRET_CODECOV_TOKEN: "jQ0BMTQgyZx7QGyU0Q2Ec7qB9mtE2q/tDu0FsfxvEG7/zOAGvXkyXrzIFFOQxvDoFcP+K2+hYZKMxicYdNqzr5wcxu505aNGN2GM3wyegAr+hO6q12bCFYx6qXzU9FLCCdeqINqn9gUSSOlGtWNFrbAlrTyz/D4Yo66TqBDzvaLL63FMnhCLaXW/zJt3hNuEAJaPY2O6Ze1rX2WZ3Y+i+s3uQ8aLImtoCJhPe8CRx+OhuYiTzGhynFfGntZ0738/1RN4gNM0S/hTC4gLE7XMVBanJpGh32rFaiDwW4zAyXKBrDkL3QA3MS1RvLTJxGJ085S16hCk0C4ddAhZCvIM9Q==;U2FsdGVkX1+bXdFeKMs5G79catOCyby2n07A2fg0FjVAvrjQLZ0yfvDS4paJiFikLkodho0khz2YALKb2Y0K6w==" - SECRET_DOCUMENTER_KEY: 
"iRC4P/r5o9pARB670eK9jPlKQKgkTMDAyvp2GbLG8WwLuT8T1VcWx/o4+ofGlzbTh5Z+LuFgPXfgqkjGuoWLcocHNm78xQMNMywB4rcLB2shqp8xG2vhglgnTBBS4EiyPAtVqGyi5AKmfF95PfkJvnI0Lqg5P/RWQvNGywLAR0Ikgr/lqocm2CvkFGbpMzpGxGvj76JYOusVeKvGAp698TXqPabSZR2oZQLfYnEZnaO8ivkqvMGQSXfgzoIMjCOrN1rSa84SWeI9BDeBslzDHwaYGlvjpfCyviiLtKj4t5Acl1gVE0qxxZxWuALIU6z+C1W8TbW7ZDCBUFs6UTIT+Q==;U2FsdGVkX1+/HSgg1skLszz835vSO6mEtXMhG62ohQQUc5opdo7kEIAG2wCoJPQrqGyaF9kKDVvrN5G2MdjUyaLBYlv90RzXhjTiMNFdgI3M4K500xKq3itY/aEL7hUSMRKxTos8u4xhdbRboY4rPcqgtCJ2LHEjNxmml/NfEo/8lk291rGoEYQLTvKP9cuo4enmEVVRhqmabBzt1MDz0m4c8RufJWW2Ni4osaKRkYPjl/ijJ38wvRUZIiyCX7uofh+3iCKWn0111q5xFhn256Pm79Cx2ZP+yTp9sMsVNMJZ3UJ5r18F3H+zFHWWQSoiWpHn2WNB/2VUEyt0Lp1LnogKru96P2oYkXi6kqrA+qlLISUUU7R7ggJU0IRS6MjSGDyVzlaZG8m+RmY0bmQKrDwSeq1JMGkBpjwPY1o4yOnFRB7Rj1bzToLtd2IFSa8x0a2dUSyL5pBlyWklzZCxPp05R53RNSOi2KfhNfdZU2H7xEj5+z2aV5OidzowXIyYH8FlusMdk3NAOsvTbmBGiwvN4Zub9Exli06ZwARu/oJHLRh+hgOErIJ7DoX6nPrAtofSy6Etydpt+c4HkVZtGPWFSTMNWIGNx2NB1IfveOTU60H5emQ7zow5grXz4VTczqvCIh2hoQdSR4Oplr6+tDDLhtcGGHchHt473o2ygQ1m1tg7oSvMN7jmkUV1N6GniQofmlbr8d5LK4i/QtfC5GHCKIg3ohRlDvuvvKzvVWofgHX3NhXFTKK/CWAIp76iOaCWJcI562SpKyn+pFqYKpatJ42WfF3VbNpJYVMYMai5BwAE2RyZ6FhHbsaHq/NXO/dRJwHeDm4Pc/LFlGFdzpdbuf+w2DoePc56PlNmKsLNlZVlwbWcExKttI8nz3Th3aHNNtbIbD9awf1RdDspudQrTPWkyEopDVm7TkOj/J891U5p24PF5dasIJR19Tqpic3LVJuBXYRbL/Z79VRjeE3wBGLTDdhzJMA8TrS+yMSCF80bIw/F44o4WbA3Ya425mph9MIt/a137osRKATYqbustmVW/LfIyVhuHCOCRQsqTyFU+ff6Tp0EE2i1du90wosr+UutXiubYphCmuKkZONPbiXjpW1CAi40iAwxfgOVqAl13y4FlUp4EiGS7hPBUbvvEXMqT3ssfL+mlideH/v08PQCRcyG03zcCjCTmjXCggqHd+eEXhnsNZ4PFKCKiN+znR5SW+/p+kJTaBrX2e/kMU6kzjwb4NyNmZie0hHSneVtwJ1FuXJk/Zph4quv5KugCCx21xb5pePqxfKRW5jtW6r2Rc7OSNN4BHjwAcj8fOVV+12Ak7//o8mRh0aveYfoEvjCdaI8OPfjduDGfmzPUvXiqV9kGpovdlDUATyoVa3l1CowJ5r8KDOD6Ps89OG7TV2c7Wzxq2FQVjMFXxv/4wMZR1F/0zyH+ofPLVZjK3039z35GD4uoOW9Uc7WSr4FbxxuCDwOXWgstuk3rk6ASZFSe7RIwE/Y16d/aqzI+LG8pHqaEdhg6o6Y6JxBYNQo/JoglUOHwD+N5g5n9vfBNzf0xTlE/r0yjO3LCHyWzCnWr3QdKgzm6EDyL8GO+yQIbtXtw6lRQB/UEZ+ayt175r08Yhey95IsPwLVDFRRlG6pYwmzTlQOEwvqDI8SDMWboU+jp6a5jrbaAmqiIkaoiIzrV1QDp1x+Sqj0veqN+RtcpXLawJevz8dm76H+Mmp1br61nwvGcBaOKukICVj3iLeeu5tV5NoEJznWPwveHrcarZtKvOOeJbydmNAz286i0F1ocX337dt17jIkRv9sHbfqAVapob+eT7F3N/UY99GWGDVbXzaruQwsuPPR6MbLolG6buHQaKX3OZ/zJqGWfEAHw5yJKoKNe8aSgY2DsoITqPlbNRQQmOIMuF8ffD8L1stD/P5Ohth5Nql2W+l6y87/nqxkJ9y4FFS4QzrMrl9ztugfsRoYyeSWRydLUHlTCv155VsGAxjCMBQg1rP99Smfd02EbCFlWlypIw/zem0LZ1zVuz/Wjb03n+dzi2GIKRlTrt6YMrGGAcKI+3Pf1D0rsDhXNkdFUjOeofUkDbBr/splYCKLucDHFVdN88XyaQoj2fBymNJ4BqvK64TVOLwPGAQvh/rHZ5PkJR3lMI4fg+Kxdl9/5xDjkD9aV+yRvfqVGodNW/qofq34nrdb3co1tZ4BxtSANKdJg3Fv6U0I4DOMVsJTeOn/918M31rif0rKAwnHAkeyQVbZyEsFoqxvE8gUFs1zTRwZJWlmY0xnuVcM8pOh6hULeYGiF57ZlbvymygYqObe58YgrChRnF4NhKIIYzuz7mOSKRXqF3Cr0LNYHcktUH9wrqISxiHbaUQceYZ1D0q8UfiayeK9yppMkltcDUL9M93xjTGJK8pVzARXn6ETuEsNTtLvbU/KMDY7bnVc7n08suLCk1YeJB/sn0wuTbPt+27NeYIG1YXBEE0dsgJW4z64489h71v4xws856gFOHZx0L/nkW7l328HA3jltbgJFl52mQHAJwUZrt5sJef/k7gsTdX1zQtjKN8lFjo4qpvJUpenmO9nT+Wty5cjohlETBos8CdSqj4SjEu7/UhDt52evt33EayoWJ8TjKd4VRFYCXnM6eGnSMDqUU5f7DxVjrwHnT26jtq9ijKTiAxls7fYjN8TGT/S3CHZZAK1u5gSbWfkFOcE+mioboNwDvuvysjL6de+bsc7r35w4hLFnPmKemcde4pNQfEnuelBFJqwYZbcAkhN8AmtqIWPXBw9n3eUx/TJgMFEIoB/frNDRbB0WJKdBkjdE1NVvAUl3jDnZbWjG6rqE+6UvyGqKBpd0FRYAfg3ss3hVB70uluULKUBVazlNIQlqX+qYEMBXaDIkxcftre8KYebQyJnxiOB5V+eELvm6L28bK4Xh2tpXzJL7aDlQnL8dRNvQdZgDL62EXYhrc3mz0I/p7br3KMcnei/LaPRAgcsW7WKLwzE5id6JnpOJj4VXdkX7IUB4xQjDRsGKxhjbklMVFA8g/801khNlwzU/IoXsHBgTs7yZoFX/oo4Jyp514hwqPlvJEgci0OHiSA6Mx3le2nUh0SQH+AzFJ2vi7Bn1a4psiuqd+vJJ1iuNw5CBCZlV+GO8sG93BBGnLzZDoRvkIMbzwESFP3JYZ/lKs29CB2Adobl9YbwP3he0I9cD0A/R
PC70gzTdVEfL6T4iPUhBr1Bn3YlUPeC2QvCTbpKkxDsfzchuq/y0xlmL4E7Rdb+4TSMlViXfnc6aoD9vvPMWLJFF2qrxRLKhUTse5V6RoE+EVmHSiX0Vd7sd/bYp7asOC0b1xL+zjfJ5DSrtMA/P8L1p+CoLNXgVfgzCB3sCa+GLSLS2INsL1Qtnfkl8IGaMDeV+VAyHjY0HCj0l1X99f/RzD6TYrZAkLS8h1EM/JjomglhVG9/HTKS20BBJeos5ifrVd38rhONJy0HCP28pn4rCIyIE4bNG+1tEsHAg4FDYgh/OYuBsaGYgha9TGV5lGIxmVCECq3IPpkPN1CsLqv3KuDvNeH6XOOAzVtFj4VoIV6QgRLP8+94ZiiEDaPQxQ7BZoqrqFYrxWHDtEuon46VtQ3Nfq/1Rq/HvszJv6JE77w7qvKlxG9sXgxzCDRqNrG83cwY2hpDBr8U0hPMrEx977Weja1aG/rG6uirNBcY5qAAOLDo+9RvV1xqvWFF8SkT97tzNUHbzw8tuUlCT9m4rshCG+jBw59rpUZwW+eR1ih9qU7Nyr3oNgi/zmkORF1duym8VSfW5dxtRBIqxxM0oSWoHti+HSd0VLdHw8jRpbQddMBr1sjD1jIgp3w2dU4oEthzStKCPY2/lAWBm+1Es1okGhEM3I939DRcYOjfJnTCtJLJ9DTKycVDMerXvHnCgImZ0Oh4mtLF+63hn+9wUc56owFeNqs+NJHqmBBFX2uNr/Rj9mzYkRRPsYYSyCB7jIS+Z8Zall6W3dwLcsE3uw/oPKx5bJDAhnp7kZgzLC0zlS2D0ZcNZuW2uUtwhZJM6OOyV+FUFgizmpIQAQ8Nm6n/1yk0asB4jZFf221a9ZmzvUfWKmmIR7OxX3qBH9x2uMMhemv9LZdEHMcjTeIXRYciMLWUNeWagYhDgV1cRBGCDTh2EhHvYX7ZXfpsHjLOR+sAEr7uR3siitf/mRkiLfT2YBgTACKKoj05UuC8aknEV4T5bWiye+gKGioml5G/fWYHyHow37g6D84n0cBTWmI0oPlg+rqpeRLOeYaTeCXOtM/7M1FHuGvzmBnag2vhKY2tpjVrg2nI3p4SRlzTyoQkyMfRXN87v5nAheVcLgrYtkv9aX7R6VMZ1UIsxn62ZHFa2IR6skB/xw7RRuJY5r5FIWs1LqIQDaon5L4C4v9rnBxMYoUM" \ No newline at end of file + SECRET_DOCUMENTER_KEY: "iRC4P/r5o9pARB670eK9jPlKQKgkTMDAyvp2GbLG8WwLuT8T1VcWx/o4+ofGlzbTh5Z+LuFgPXfgqkjGuoWLcocHNm78xQMNMywB4rcLB2shqp8xG2vhglgnTBBS4EiyPAtVqGyi5AKmfF95PfkJvnI0Lqg5P/RWQvNGywLAR0Ikgr/lqocm2CvkFGbpMzpGxGvj76JYOusVeKvGAp698TXqPabSZR2oZQLfYnEZnaO8ivkqvMGQSXfgzoIMjCOrN1rSa84SWeI9BDeBslzDHwaYGlvjpfCyviiLtKj4t5Acl1gVE0qxxZxWuALIU6z+C1W8TbW7ZDCBUFs6UTIT+Q==;U2FsdGVkX1+/HSgg1skLszz835vSO6mEtXMhG62ohQQUc5opdo7kEIAG2wCoJPQrqGyaF9kKDVvrN5G2MdjUyaLBYlv90RzXhjTiMNFdgI3M4K500xKq3itY/aEL7hUSMRKxTos8u4xhdbRboY4rPcqgtCJ2LHEjNxmml/NfEo/8lk291rGoEYQLTvKP9cuo4enmEVVRhqmabBzt1MDz0m4c8RufJWW2Ni4osaKRkYPjl/ijJ38wvRUZIiyCX7uofh+3iCKWn0111q5xFhn256Pm79Cx2ZP+yTp9sMsVNMJZ3UJ5r18F3H+zFHWWQSoiWpHn2WNB/2VUEyt0Lp1LnogKru96P2oYkXi6kqrA+qlLISUUU7R7ggJU0IRS6MjSGDyVzlaZG8m+RmY0bmQKrDwSeq1JMGkBpjwPY1o4yOnFRB7Rj1bzToLtd2IFSa8x0a2dUSyL5pBlyWklzZCxPp05R53RNSOi2KfhNfdZU2H7xEj5+z2aV5OidzowXIyYH8FlusMdk3NAOsvTbmBGiwvN4Zub9Exli06ZwARu/oJHLRh+hgOErIJ7DoX6nPrAtofSy6Etydpt+c4HkVZtGPWFSTMNWIGNx2NB1IfveOTU60H5emQ7zow5grXz4VTczqvCIh2hoQdSR4Oplr6+tDDLhtcGGHchHt473o2ygQ1m1tg7oSvMN7jmkUV1N6GniQofmlbr8d5LK4i/QtfC5GHCKIg3ohRlDvuvvKzvVWofgHX3NhXFTKK/CWAIp76iOaCWJcI562SpKyn+pFqYKpatJ42WfF3VbNpJYVMYMai5BwAE2RyZ6FhHbsaHq/NXO/dRJwHeDm4Pc/LFlGFdzpdbuf+w2DoePc56PlNmKsLNlZVlwbWcExKttI8nz3Th3aHNNtbIbD9awf1RdDspudQrTPWkyEopDVm7TkOj/J891U5p24PF5dasIJR19Tqpic3LVJuBXYRbL/Z79VRjeE3wBGLTDdhzJMA8TrS+yMSCF80bIw/F44o4WbA3Ya425mph9MIt/a137osRKATYqbustmVW/LfIyVhuHCOCRQsqTyFU+ff6Tp0EE2i1du90wosr+UutXiubYphCmuKkZONPbiXjpW1CAi40iAwxfgOVqAl13y4FlUp4EiGS7hPBUbvvEXMqT3ssfL+mlideH/v08PQCRcyG03zcCjCTmjXCggqHd+eEXhnsNZ4PFKCKiN+znR5SW+/p+kJTaBrX2e/kMU6kzjwb4NyNmZie0hHSneVtwJ1FuXJk/Zph4quv5KugCCx21xb5pePqxfKRW5jtW6r2Rc7OSNN4BHjwAcj8fOVV+12Ak7//o8mRh0aveYfoEvjCdaI8OPfjduDGfmzPUvXiqV9kGpovdlDUATyoVa3l1CowJ5r8KDOD6Ps89OG7TV2c7Wzxq2FQVjMFXxv/4wMZR1F/0zyH+ofPLVZjK3039z35GD4uoOW9Uc7WSr4FbxxuCDwOXWgstuk3rk6ASZFSe7RIwE/Y16d/aqzI+LG8pHqaEdhg6o6Y6JxBYNQo/JoglUOHwD+N5g5n9vfBNzf0xTlE/r0yjO3LCHyWzCnWr3QdKgzm6EDyL8GO+yQIbtXtw6lRQB/UEZ+ayt175r08Yhey95IsPwLVDFRRlG6pYwmzTlQOEwvqDI8SDMWboU+jp6a5jrbaAmqiIkaoiIzrV1QDp1x+Sqj0veqN+RtcpXLawJevz8dm76H+Mmp1br61nwvGcBaOKukICVj3iLeeu5tV5NoEJznWPwveHrcarZtKvOOeJbydmNAz286i0F1ocX337dt17jIkRv9sHbfqAVapob+eT7F3N/UY99GWGDVbXzaruQwsuPPR6MbLolG6buHQaKX3OZ/zJqGWfEAHw5yJKoKNe8aSgY2DsoITqPlbNRQQmOIMuF8ffD8L1stD/P5Ohth5Nql2W+l6y87
/nqxkJ9y4FFS4QzrMrl9ztugfsRoYyeSWRydLUHlTCv155VsGAxjCMBQg1rP99Smfd02EbCFlWlypIw/zem0LZ1zVuz/Wjb03n+dzi2GIKRlTrt6YMrGGAcKI+3Pf1D0rsDhXNkdFUjOeofUkDbBr/splYCKLucDHFVdN88XyaQoj2fBymNJ4BqvK64TVOLwPGAQvh/rHZ5PkJR3lMI4fg+Kxdl9/5xDjkD9aV+yRvfqVGodNW/qofq34nrdb3co1tZ4BxtSANKdJg3Fv6U0I4DOMVsJTeOn/918M31rif0rKAwnHAkeyQVbZyEsFoqxvE8gUFs1zTRwZJWlmY0xnuVcM8pOh6hULeYGiF57ZlbvymygYqObe58YgrChRnF4NhKIIYzuz7mOSKRXqF3Cr0LNYHcktUH9wrqISxiHbaUQceYZ1D0q8UfiayeK9yppMkltcDUL9M93xjTGJK8pVzARXn6ETuEsNTtLvbU/KMDY7bnVc7n08suLCk1YeJB/sn0wuTbPt+27NeYIG1YXBEE0dsgJW4z64489h71v4xws856gFOHZx0L/nkW7l328HA3jltbgJFl52mQHAJwUZrt5sJef/k7gsTdX1zQtjKN8lFjo4qpvJUpenmO9nT+Wty5cjohlETBos8CdSqj4SjEu7/UhDt52evt33EayoWJ8TjKd4VRFYCXnM6eGnSMDqUU5f7DxVjrwHnT26jtq9ijKTiAxls7fYjN8TGT/S3CHZZAK1u5gSbWfkFOcE+mioboNwDvuvysjL6de+bsc7r35w4hLFnPmKemcde4pNQfEnuelBFJqwYZbcAkhN8AmtqIWPXBw9n3eUx/TJgMFEIoB/frNDRbB0WJKdBkjdE1NVvAUl3jDnZbWjG6rqE+6UvyGqKBpd0FRYAfg3ss3hVB70uluULKUBVazlNIQlqX+qYEMBXaDIkxcftre8KYebQyJnxiOB5V+eELvm6L28bK4Xh2tpXzJL7aDlQnL8dRNvQdZgDL62EXYhrc3mz0I/p7br3KMcnei/LaPRAgcsW7WKLwzE5id6JnpOJj4VXdkX7IUB4xQjDRsGKxhjbklMVFA8g/801khNlwzU/IoXsHBgTs7yZoFX/oo4Jyp514hwqPlvJEgci0OHiSA6Mx3le2nUh0SQH+AzFJ2vi7Bn1a4psiuqd+vJJ1iuNw5CBCZlV+GO8sG93BBGnLzZDoRvkIMbzwESFP3JYZ/lKs29CB2Adobl9YbwP3he0I9cD0A/RPC70gzTdVEfL6T4iPUhBr1Bn3YlUPeC2QvCTbpKkxDsfzchuq/y0xlmL4E7Rdb+4TSMlViXfnc6aoD9vvPMWLJFF2qrxRLKhUTse5V6RoE+EVmHSiX0Vd7sd/bYp7asOC0b1xL+zjfJ5DSrtMA/P8L1p+CoLNXgVfgzCB3sCa+GLSLS2INsL1Qtnfkl8IGaMDeV+VAyHjY0HCj0l1X99f/RzD6TYrZAkLS8h1EM/JjomglhVG9/HTKS20BBJeos5ifrVd38rhONJy0HCP28pn4rCIyIE4bNG+1tEsHAg4FDYgh/OYuBsaGYgha9TGV5lGIxmVCECq3IPpkPN1CsLqv3KuDvNeH6XOOAzVtFj4VoIV6QgRLP8+94ZiiEDaPQxQ7BZoqrqFYrxWHDtEuon46VtQ3Nfq/1Rq/HvszJv6JE77w7qvKlxG9sXgxzCDRqNrG83cwY2hpDBr8U0hPMrEx977Weja1aG/rG6uirNBcY5qAAOLDo+9RvV1xqvWFF8SkT97tzNUHbzw8tuUlCT9m4rshCG+jBw59rpUZwW+eR1ih9qU7Nyr3oNgi/zmkORF1duym8VSfW5dxtRBIqxxM0oSWoHti+HSd0VLdHw8jRpbQddMBr1sjD1jIgp3w2dU4oEthzStKCPY2/lAWBm+1Es1okGhEM3I939DRcYOjfJnTCtJLJ9DTKycVDMerXvHnCgImZ0Oh4mtLF+63hn+9wUc56owFeNqs+NJHqmBBFX2uNr/Rj9mzYkRRPsYYSyCB7jIS+Z8Zall6W3dwLcsE3uw/oPKx5bJDAhnp7kZgzLC0zlS2D0ZcNZuW2uUtwhZJM6OOyV+FUFgizmpIQAQ8Nm6n/1yk0asB4jZFf221a9ZmzvUfWKmmIR7OxX3qBH9x2uMMhemv9LZdEHMcjTeIXRYciMLWUNeWagYhDgV1cRBGCDTh2EhHvYX7ZXfpsHjLOR+sAEr7uR3siitf/mRkiLfT2YBgTACKKoj05UuC8aknEV4T5bWiye+gKGioml5G/fWYHyHow37g6D84n0cBTWmI0oPlg+rqpeRLOeYaTeCXOtM/7M1FHuGvzmBnag2vhKY2tpjVrg2nI3p4SRlzTyoQkyMfRXN87v5nAheVcLgrYtkv9aX7R6VMZ1UIsxn62ZHFa2IR6skB/xw7RRuJY5r5FIWs1LqIQDaon5L4C4v9rnBxMYoUM" diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index a2acf191fa..237293a0f3 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -1,6 +1,6 @@ steps: - label: "Triggering Pipelines (Pull Request)" - if: "build.pull_request.base_branch == 'main'" + if: build.branch != "main" || build.tag == null agents: queue: "juliagpu" plugins: diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f008abdef8..9790c8d09d 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -154,8 +154,6 @@ jobs: with: version: ${{ matrix.version }} - uses: julia-actions/julia-downgrade-compat@v1 - with: - skip: 'AMDGPU' - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 env: diff --git a/Project.toml b/Project.toml index 55fd497790..537c8574ea 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "0.5.68" +version = "1.0.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -20,7 
+20,6 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" -LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" @@ -42,7 +41,6 @@ WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" [weakdeps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" -DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" @@ -56,7 +54,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] LuxComponentArraysExt = "ComponentArrays" -LuxDynamicExpressionsExt = "DynamicExpressions" LuxEnzymeExt = "Enzyme" LuxFluxExt = "Flux" LuxMLUtilsExt = "MLUtils" @@ -70,7 +67,7 @@ LuxZygoteExt = "Zygote" [compat] ADTypes = "1.5" Adapt = "4" -ArgCheck = "2.1" +ArgCheck = "2.3" ArrayInterface = "7.9" CUDA = "5.3.2" ChainRulesCore = "1.24" @@ -78,7 +75,6 @@ Compat = "4.15" ComponentArrays = "0.15.16" ConcreteStructs = "0.2.3" DispatchDoctor = "0.4.12" -DynamicExpressions = "0.16, 0.17, 0.18, 0.19" Enzyme = "0.12.26" EnzymeCore = "0.7.7" FastClosures = "0.3.2" @@ -89,9 +85,8 @@ Functors = "0.4.12" GPUArraysCore = "0.1.6" LinearAlgebra = "1.10" LossFunctions = "0.11.1" -LuxCore = "0.1.24" -LuxDeviceUtils = "0.1.26" -LuxLib = "0.3.42" +LuxCore = "1" +LuxLib = "1.2" MLDataDevices = "1.1" MLUtils = "0.4.4" MPI = "0.20.19" @@ -104,7 +99,7 @@ Preferences = "1.4.3" Random = "1.10" Reexport = "1.2.2" ReverseDiff = "1.15" -SIMDTypes = "0.1.0" +SIMDTypes = "0.1" Setfield = "1.1.1" SimpleChains = "0.4.7" Static = "1.1.1" @@ -113,6 +108,6 @@ Statistics = "1.10" Tracker = "0.2.34" UnrolledUtilities = "0.1.2" VectorizationBase = "0.21.70" -WeightInitializers = "0.1.5, 1" +WeightInitializers = "1" Zygote = "0.6.70" julia = "1.10" diff --git a/README.md b/README.md index bd747cb396..033be3ebdd 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,9 @@ import Pkg Pkg.add("Lux") ``` +> [!TIP] +> If you are using a pre-v1 version of Lux.jl, please see the [Updating to v1 section](https://lux.csail.mit.edu/dev/introduction/updating_to_v1/) for instructions on how to update. + ## 🤸 Quickstart ```julia @@ -74,7 +77,43 @@ st_opt, ps = Optimisers.update(st_opt, ps, gs) Look in the [examples](/examples/) directory for self-contained usage examples. The [documentation](https://lux.csail.mit.edu) has examples sorted into proper categories. -## 🧪 Testing +## 🆘 Getting Help + +For usage related questions, please use [Github Discussions](https://github.com/orgs/LuxDL/discussions) which allows questions and answers to be indexed. To report bugs use [github issues](https://github.com/LuxDL/Lux.jl/issues) or even better send in a [pull request](https://github.com/LuxDL/Lux.jl/pulls). 
+ +## 🧑‍🔬 Citation + +If you found this library to be useful in academic work, then please cite: + +```bibtex +@software{pal2023lux, + author = {Pal, Avik}, + title = {{Lux: Explicit Parameterization of Deep Neural Networks in Julia}}, + month = apr, + year = 2023, + note = {If you use this software, please cite it as below.}, + publisher = {Zenodo}, + version = {v0.5.0}, + doi = {10.5281/zenodo.7808904}, + url = {https://doi.org/10.5281/zenodo.7808904} +} + +@thesis{pal2023efficient, + title = {{On Efficient Training \& Inference of Neural Differential Equations}}, + author = {Pal, Avik}, + year = {2023}, + school = {Massachusetts Institute of Technology} +} +``` + +Also consider starring [our github repo](https://github.com/LuxDL/Lux.jl/). + +## 🧑‍💻 Contributing + +This section is somewhat incomplete. You can contribute by contributing to finishing this +section 😜. + +### 🧪 Testing The full test of `Lux.jl` takes a long time, here's how to test a portion of the code. @@ -122,36 +161,5 @@ ReTestItems.runtests("tests/"; name = "NAME OF THE TEST") For the `SkipConnection` tests that would be: ```julia -ReTestItems.runtests("tests/"; name = SkipConnection) -``` - -## 🆘 Getting Help - -For usage related questions, please use [Github Discussions](https://github.com/orgs/LuxDL/discussions) which allows questions and answers to be indexed. To report bugs use [github issues](https://github.com/LuxDL/Lux.jl/issues) or even better send in a [pull request](https://github.com/LuxDL/Lux.jl/pulls). - -## 🧑‍🔬 Citation - -If you found this library to be useful in academic work, then please cite: - -```bibtex -@software{pal2023lux, - author = {Pal, Avik}, - title = {{Lux: Explicit Parameterization of Deep Neural Networks in Julia}}, - month = apr, - year = 2023, - note = {If you use this software, please cite it as below.}, - publisher = {Zenodo}, - version = {v0.5.0}, - doi = {10.5281/zenodo.7808904}, - url = {https://doi.org/10.5281/zenodo.7808904} -} - -@thesis{pal2023efficient, - title = {{On Efficient Training \& Inference of Neural Differential Equations}}, - author = {Pal, Avik}, - year = {2023}, - school = {Massachusetts Institute of Technology} -} +ReTestItems.runtests("tests/"; name = "SkipConnection") ``` - -Also consider starring [our github repo](https://github.com/LuxDL/Lux.jl/). 
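The README tip above points pre-v1 users to the "Updating to v1" guide added later in this diff. As a hedged sketch of the kind of change that guide covers (the layer sizes and data below are arbitrary illustrations, not part of the diff), a v0.5-style snippet and its v1 equivalent might look like:

```julia
using Lux, Random

# v0.5 style (removed in v1):
#   layer = CrossCor((3, 3), 3 => 16)   # `CrossCor` no longer exists
#   x_gpu = Lux.gpu(x)                  # `Lux.gpu` / `Lux.cpu` no longer exist

# v1 equivalents:
layer = Conv((3, 3), 3 => 16; cross_correlation=true)  # replaces `CrossCor`
ps, st = Lux.setup(Xoshiro(0), layer)

x = randn(Float32, 8, 8, 3, 4)
gdev = gpu_device()            # falls back to the CPU device if no GPU is available
y, _ = layer(x |> gdev, ps |> gdev, st |> gdev)
```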
diff --git a/benchmarks/setup.jl b/benchmarks/setup.jl index 2ed92f2e0f..e2d05bc889 100644 --- a/benchmarks/setup.jl +++ b/benchmarks/setup.jl @@ -1,6 +1,6 @@ using ADTypes: ADTypes, AutoEnzyme, AutoZygote using Adapt: adapt -using Lux: Lux, BatchNorm, Chain, Conv, CrossCor, Dense, Dropout, FlattenLayer, MaxPool +using Lux: Lux, BatchNorm, Chain, Conv, Dense, Dropout, FlattenLayer, MaxPool using MLDataDevices: AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice using NNlib: relu, gelu using Random: Random diff --git a/benchmarks/setups/models.jl b/benchmarks/setups/models.jl index c5f8146f0c..b4f3763039 100644 --- a/benchmarks/setups/models.jl +++ b/benchmarks/setups/models.jl @@ -25,10 +25,10 @@ function setup_vgg16_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, conv_bn((3, 3), 512 => 512, relu; pad=(1, 1), stride=(1, 1)), conv_bn((3, 3), 512 => 512, relu; pad=(1, 1), stride=(1, 1)), conv_bn((3, 3), 512 => 512, relu; pad=(1, 1), stride=(1, 1)), - MaxPool((2, 2)); disable_optimizations=true), + MaxPool((2, 2))), FlattenLayer(), Chain(Dense(512, 4096, relu), Dropout(0.5f0), Dense(4096, 4096, relu), - Dropout(0.5f0), Dense(4096, 10); name="Classifier"); disable_optimizations=true) + Dropout(0.5f0), Dense(4096, 10); name="Classifier")) for bsize in (32, 64, 128) setup_forward_pass_benchmark!(suite, "vgg16(32, 32, 3, $bsize)", diff --git a/docs/Project.toml b/docs/Project.toml index 2dab4173db..85ac205ae7 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -14,7 +14,6 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" -LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" @@ -22,6 +21,7 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -39,11 +39,10 @@ GPUArraysCore = "0.1" KernelAbstractions = "0.9" LinearAlgebra = "1.10" Literate = "2.18.0" -Lux = "0.5.62" +Lux = "1" LuxCUDA = "0.3.2" -LuxCore = "0.1.15" -LuxDeviceUtils = "0.1.21" -LuxLib = "0.3.42" +LuxCore = "1" +LuxLib = "1" LuxTestUtils = "1.1" MLDataDevices = "1.1" Optimisers = "0.3.3" @@ -51,6 +50,6 @@ Pkg = "1.10" Printf = "1.10" Random = "1.10" StaticArrays = "1" -WeightInitializers = "0.1.7, 1" +WeightInitializers = "1" Zygote = "0.6.70" julia = "1.10" diff --git a/docs/make.jl b/docs/make.jl index 56c51ae44e..a491ada64d 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,7 +1,6 @@ using Documenter, DocumenterVitepress, Pkg using Lux, LuxCore, LuxLib, WeightInitializers -using LuxTestUtils, LuxDeviceUtils -using MLDataDevices +using LuxTestUtils, MLDataDevices using LuxCUDA using Optimisers # for some docstrings @@ -15,6 +14,7 @@ pages = [ "Introduction" => "introduction/index.md", "Overview" => "introduction/overview.md", "Resources" => "introduction/resources.md", + "Updating to v1" => "introduction/updating_to_v1.md", "Citation" => "introduction/citation.md" ], "Tutorials" => [ @@ -31,8 +31,7 @@ pages = [ "tutorials/intermediate/3_HyperNet.md" ], "Advanced" => [ 
- "tutorials/advanced/1_GravitationalWaveForm.md", - "tutorials/advanced/2_SymbolicOptimalControl.md" + "tutorials/advanced/1_GravitationalWaveForm.md" ] ], "Manual" => [ @@ -56,7 +55,6 @@ pages = [ "api/Lux/distributed_utils.md", ], "Accelerator Support" => [ - "api/Accelerator_Support/LuxDeviceUtils.md", "api/Accelerator_Support/MLDataDevices.md" ], "Building Blocks" => [ @@ -80,8 +78,7 @@ makedocs(; sitename="Lux.jl Docs", authors="Avik Pal et al.", clean=true, doctest=false, # We test it in the CI, no need to run it here - modules=[Lux, LuxCore, LuxLib, WeightInitializers, - LuxTestUtils, LuxDeviceUtils, MLDataDevices], + modules=[Lux, LuxCore, LuxLib, WeightInitializers, LuxTestUtils, MLDataDevices], linkcheck=true, repo="https://github.com/LuxDL/Lux.jl/blob/{commit}{path}#{line}", format=DocumenterVitepress.MarkdownVitepress(; diff --git a/docs/run_single_tutorial.jl b/docs/run_single_tutorial.jl index 965f99b942..b163ee244d 100644 --- a/docs/run_single_tutorial.jl +++ b/docs/run_single_tutorial.jl @@ -24,13 +24,13 @@ function preprocess(path, str) using InteractiveUtils InteractiveUtils.versioninfo() - if @isdefined(LuxDeviceUtils) - if @isdefined(CUDA) && LuxDeviceUtils.functional(LuxCUDADevice) + if @isdefined(MLDataDevices) + if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice) println() CUDA.versioninfo() end - if @isdefined(AMDGPU) && LuxDeviceUtils.functional(LuxAMDGPUDevice) + if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice) println() AMDGPU.versioninfo() end diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index b111a3bf37..f6e58e14b3 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -6,7 +6,7 @@ import { transformerMetaWordHighlight } from '@shikijs/transformers'; // https://vitepress.dev/reference/site-config export default defineConfig({ - base: 'REPLACE_ME_DOCUMENTER_VITEPRESS',// TODO: replace this in makedocs! 
+ base: 'REPLACE_ME_DOCUMENTER_VITEPRESS', title: 'REPLACE_ME_DOCUMENTER_VITEPRESS', description: 'Documentation for LuxDL Repositories', cleanUrls: true, @@ -79,7 +79,6 @@ export default defineConfig({ }, { text: 'Accelerator Support', items: [ - { text: 'LuxDeviceUtils', link: '/api/Accelerator_Support/LuxDeviceUtils' }, { text: 'MLDataDevices', link: '/api/Accelerator_Support/MLDataDevices' } ] }, @@ -112,6 +111,7 @@ export default defineConfig({ { text: 'Introduction', link: '/introduction' }, { text: 'Overview', link: '/introduction/overview' }, { text: 'Resources', link: '/introduction/resources' }, + { text: 'Updating to v1', link: '/introduction/updating_to_v1' }, { text: 'Citation', link: '/introduction/citation' }] }, "/tutorials/": { @@ -132,8 +132,7 @@ export default defineConfig({ }, { text: 'Advanced', collapsed: false, items: [ - { text: 'Training a Neural ODE to Model Gravitational Waveforms', link: '/tutorials/advanced/1_GravitationalWaveForm' }, - { text: 'Solving Optimal Control Problems with Symbolic UDEs', link: '/tutorials/advanced/2_SymbolicOptimalControl' },] + { text: 'Training a Neural ODE to Model Gravitational Waveforms', link: '/tutorials/advanced/1_GravitationalWaveForm' },] }, { text: 'Large Models', collapsed: true, items: [ @@ -216,7 +215,6 @@ export default defineConfig({ }, { text: 'Accelerator Support', collapsed: false, items: [ - { text: 'LuxDeviceUtils', link: '/api/Accelerator_Support/LuxDeviceUtils' }, { text: 'MLDataDevices', link: '/api/Accelerator_Support/MLDataDevices' }] }, { diff --git a/docs/src/api/Accelerator_Support/LuxDeviceUtils.md b/docs/src/api/Accelerator_Support/LuxDeviceUtils.md deleted file mode 100644 index e4fac9ba01..0000000000 --- a/docs/src/api/Accelerator_Support/LuxDeviceUtils.md +++ /dev/null @@ -1,50 +0,0 @@ -# [LuxDeviceUtils](@id LuxDeviceUtils-API) - -`LuxDeviceUtils.jl` is a lightweight package defining rules for transferring data across -devices. Most users should directly use Lux.jl instead. - -!!! note "Transition to `MLDataDevices.jl`" - - Currently this package is in maintenance mode and won't receive any new features, - however, we will backport bug fixes till Lux `v1.0` is released. Post that this package - should be considered deprecated and users should switch to `MLDataDevices.jl`. - - For more information on `MLDataDevices.jl` checkout the - [MLDataDevices.jl Documentation](@ref MLDataDevices-API). - -## Index - -```@index -Pages = ["LuxDeviceUtils.md"] -``` - -## Preferences - -```@docs -LuxDeviceUtils.gpu_backend! -``` - -## Data Transfer - -```@docs -LuxDeviceUtils.cpu_device -LuxDeviceUtils.gpu_device -``` - -## Miscellaneous - -```@docs -LuxDeviceUtils.reset_gpu_device! -LuxDeviceUtils.supported_gpu_backends -LuxDeviceUtils.default_device_rng -LuxDeviceUtils.get_device -LuxDeviceUtils.get_device_type -LuxDeviceUtils.loaded -LuxDeviceUtils.functional -``` - -## Multi-GPU Support - -```@docs -LuxDeviceUtils.set_device! -``` diff --git a/docs/src/api/Accelerator_Support/MLDataDevices.md b/docs/src/api/Accelerator_Support/MLDataDevices.md index df15d913f1..c1c031e82c 100644 --- a/docs/src/api/Accelerator_Support/MLDataDevices.md +++ b/docs/src/api/Accelerator_Support/MLDataDevices.md @@ -3,7 +3,7 @@ `MLDataDevices.jl` is a lightweight package defining rules for transferring data across devices. Most users should directly use Lux.jl instead. -!!! note "Comparison to LuxDeviceUtils.jl" +!!! 
note "Transitioning from `LuxDeviceUtils.jl`" `LuxDeviceUtils.jl` was renamed to `MLDataDevices.jl` in v1.0 as a part of allowing these packages to have broader adoption outsize the Lux community. However, Lux diff --git a/docs/src/api/Building_Blocks/LuxCore.md b/docs/src/api/Building_Blocks/LuxCore.md index 894266ccde..3016597b62 100644 --- a/docs/src/api/Building_Blocks/LuxCore.md +++ b/docs/src/api/Building_Blocks/LuxCore.md @@ -14,8 +14,9 @@ Pages = ["LuxCore.md"] ## Abstract Types ```@docs -LuxCore.AbstractExplicitLayer -LuxCore.AbstractExplicitContainerLayer +LuxCore.AbstractLuxLayer +LuxCore.AbstractLuxWrapperLayer +LuxCore.AbstractLuxContainerLayer ``` ## General @@ -49,12 +50,6 @@ LuxCore.update_state ## Layer size -!!! warning - - These specifications have been added very recently and most layers currently do not - implement them. - ```@docs -LuxCore.inputsize LuxCore.outputsize ``` diff --git a/docs/src/api/Building_Blocks/LuxLib.md b/docs/src/api/Building_Blocks/LuxLib.md index 8075d83ce0..21bbe1510a 100644 --- a/docs/src/api/Building_Blocks/LuxLib.md +++ b/docs/src/api/Building_Blocks/LuxLib.md @@ -1,4 +1,4 @@ -# LuxLib +# [LuxLib](@id LuxLib-API) Backend for Lux.jl diff --git a/docs/src/api/Lux/contrib.md b/docs/src/api/Lux/contrib.md index a5143ae01f..93f412d3e7 100644 --- a/docs/src/api/Lux/contrib.md +++ b/docs/src/api/Lux/contrib.md @@ -8,12 +8,6 @@ All features listed on this page are **experimental** which means: experimental sooner. 3. None of the features are exported. -!!! warning - - Starting v"0.5.2" all Experimental features need to be accessed via - `Lux.Experimental.`. Direct access via `Lux.` will be removed in - v"0.6". - ## Index ```@index @@ -22,11 +16,6 @@ Pages = ["contrib.md"] ## Parameter Freezing -!!! info - - In the long term, this will be supported via - [Optimisers.jl](https://github.com/FluxML/Optimisers.jl/pull/49). - ```@docs Lux.Experimental.FrozenLayer Lux.Experimental.freeze @@ -39,7 +28,6 @@ For detailed usage example look at the [manual page](@ref freezing-model-paramet ```@docs Lux.Experimental.layer_map -Lux.Experimental.@layer_map ``` ## Debugging Functionality @@ -56,15 +44,3 @@ Lux.Experimental.DebugLayer ```@docs Lux.Experimental.share_parameters ``` - -## StatefulLuxLayer - -[`Lux.StatefulLuxLayer`](@ref) used to be part of experimental features, but has been -promoted to stable API. It is now available via `Lux.StatefulLuxLayer`. Change all uses of -`Lux.Experimental.StatefulLuxLayer` to `Lux.StatefulLuxLayer`. - -## Compact Layer API - -[`Lux.@compact`](@ref) used to be part of experimental features, but has been promoted to -stable API. It is now available via `Lux.@compact`. Change all uses of -`Lux.Experimental.@compact` to `Lux.@compact`. diff --git a/docs/src/api/Lux/interop.md b/docs/src/api/Lux/interop.md index 8dce085a3a..f5c81fc71a 100644 --- a/docs/src/api/Lux/interop.md +++ b/docs/src/api/Lux/interop.md @@ -37,20 +37,7 @@ preserving the [layer interface](@ref lux-interface). `using SimpleChains` must be present somewhere in the code for these to be used. ```@docs -Adapt.adapt(from::ToSimpleChainsAdaptor, L::AbstractExplicitLayer) +Adapt.adapt(from::ToSimpleChainsAdaptor, L::AbstractLuxLayer) ToSimpleChainsAdaptor SimpleChainsLayer ``` - -## Symbolic Expressions - -### Embedding DynamicExpressions.jl Node in Lux Layers - -!!! tip - - Accessing these functions require manually loading `DynamicExpressions`, i.e., - `using DynamicExpressions` must be present somewhere in the code for these to be used. 
- -```@docs -DynamicExpressionsLayer -``` diff --git a/docs/src/api/Lux/layers.md b/docs/src/api/Lux/layers.md index c6672b25dc..b591844aab 100644 --- a/docs/src/api/Lux/layers.md +++ b/docs/src/api/Lux/layers.md @@ -22,7 +22,6 @@ RepeatedLayer ```@docs Conv ConvTranspose -CrossCor ``` ## Dropout Layers @@ -36,10 +35,13 @@ VariationalHiddenDropout ``` ## Pooling Layers ```@docs +AdaptiveLPPool AdaptiveMaxPool AdaptiveMeanPool +GlobalLPPool GlobalMaxPool GlobalMeanPool +LPPool MaxPool MeanPool ``` @@ -92,9 +94,3 @@ WeightNorm PixelShuffle Upsample ``` - -## SciML Layers - -```@docs -PeriodicEmbedding -``` diff --git a/docs/src/api/Lux/utilities.md b/docs/src/api/Lux/utilities.md index 7d8a94f5b0..744624fa1e 100644 --- a/docs/src/api/Lux/utilities.md +++ b/docs/src/api/Lux/utilities.md @@ -120,26 +120,8 @@ StatefulLuxLayer @non_trainable ``` -## Preferences +## Miscellaneous ```@docs Lux.set_dispatch_doctor_preferences! ``` - -## Truncated Stacktraces (Deprecated) - -```@docs -Lux.disable_stacktrace_truncation! -``` - -## Device Management / Data Transfer (Deprecated) - -```@docs -Lux.cpu -Lux.gpu -``` - -!!! warning - - For detailed API documentation on Data Transfer check out the - [LuxDeviceUtils.jl](@ref LuxDeviceUtils-API) diff --git a/docs/src/introduction/index.md b/docs/src/introduction/index.md index 46a35f2839..43a1c8717d 100644 --- a/docs/src/introduction/index.md +++ b/docs/src/introduction/index.md @@ -11,6 +11,11 @@ import Pkg Pkg.add("Lux") ``` +!!! tip "Update to v1" + + If you are using a pre-v1 version of Lux.jl, please see the + [Updating to v1 section](@ref updating-to-v1) for instructions on how to update. + ## Quickstart !!! tip "Pre-Requisites" diff --git a/docs/src/introduction/updating_to_v1.md b/docs/src/introduction/updating_to_v1.md new file mode 100644 index 0000000000..665b4c0747 --- /dev/null +++ b/docs/src/introduction/updating_to_v1.md @@ -0,0 +1,146 @@ +# [Updating to Lux v1](@id updating-to-v1) + +Lux v1 is a Major Release, mostly to signify the stability of the API. On this page, we list +out a concrete set of changes that need to be made to your code to update to Lux v1. We also +list out some new exciting features that were added as part of this release. + +## `LuxLib.jl` + +### Breaking Changes + +- Old deprecated API with keyword arguments has been removed. See the new docs in [LuxLib + API](@ref LuxLib-API) for more details. +- Default for [`layernorm`](@ref) dims has been changed to exclude the batch dimension. + +### New Major Features + +- Dense layers now support CUDA backend for Enzyme (starting `v1.1`). Wider support for + other operations with Enzyme + CUDA is being actively worked on. + +## `LuxCore.jl` + +### Breaking Changes + +- `AbstractExplicitLayer` has been renamed to `AbstractLuxLayer`. +- `AbstractExplicitContainerLayer` behaviour + - This has been renamed to `AbstractLuxContainerLayer`. + - Previously, `AbstractExplicitContainerLayer{(:a,)}` (i.e. singleton containers) would + produce default initial parameters and states without wrapping them in a + `NamedTuple{(:a,)}`. This was inconsistent with non-singleton containers, and was a + source of confusion. With v1 we return `(; a = <parameters>)` and `(; a = <states>)` + by default. See [`AbstractLuxWrapperLayer`](@ref) for a replacement of this + functionality. +- `inputsize` has been removed since it was ambiguous and not used anywhere. +- Changes to `outputsize`: + - Single argument version has been removed.
See [LuxCore.jl Pull Request + 43](https://github.com/LuxDL/LuxCore.jl/pull/43#issuecomment-2254232817) for more + details on the rationale behind this change. + - Fallback implementation has been moved to `Lux.jl`. (i.e. users using Lux shouldn't + see a difference, but if `Lux.jl` isn't loaded, this function will error.) + - Internally this uses a `NilArray` that is able to compute sizes without actually + running the computation. +- `Functors` and `Setfield` have been made into optional dependencies. Certain `LuxCore` + functionality that relies on these functions will throw an error if these packages are not + loaded. + +### New Major Features + +- Introduction of [`AbstractLuxWrapperLayer`](@ref). This behaves exactly like the old + singleton container. For example, the old `AbstractExplicitContainerLayer{(:a,)}` is + equivalent to `AbstractLuxWrapperLayer{:a}`. + +## `WeightInitializers.jl` + +This was a major release to signify the stability of the API. There were no breaking +changes. We do support a wider range of RNG types, see +[Supported RNG Types](@ref Supported-RNG-Types-WeightInit) for more details. + +## `MLDataDevices.jl` + +This is the most aggressive change that was made. We renamed the `LuxDeviceUtils.jl` package +to `MLDataDevices.jl`, to allow for non-Lux packages to use this shared device management +abstraction. + +!!! warning "Deprecation of `LuxDeviceUtils.jl`" + + This also marks the deprecation of the `LuxDeviceUtils.jl` package. We won't be making + any updates to that package, including fixing any bugs. All users should switch to + `MLDataDevices.jl` instead. + +### Breaking Changes + +- `Lux(___)Device` objects have been renamed to `(___)Device`. For example, `LuxCUDADevice` + has been renamed to `CUDADevice`. +- `Lux(___)Adaptor` objects have been removed. The corresponding `Device` objects should be + used directly instead. + +### New Major Features + +- [`DeviceIterator`](@ref) provides a generalization of `CUDA.CuIterator` and works for all + backends and more data types (using `Functors.jl`). `MLUtils.DataLoader |> gdev` now + returns a `DeviceIterator` instead of being a no-op. + +## `Lux.jl` + +### Breaking Changes (Removed Functionality) + +- Direct reexport of `NNlib` has been removed. We reexport selected functionality from + `NNlib`. Directly load `NNlib` if you need to use the other functions. +- Flattening of [`Chain`](@ref) layers has been removed, and the corresponding + `disable_optimizations` kwarg has been removed. +- Some layers overloaded `Base.keys`; these have been removed. These were mostly + undocumented and weren't supposed to be used outside of the `Lux.jl` package. +- [`Training.TrainState`](@ref) construction with `rng` has been removed. +- Older versions of Preferences have been removed. +- `disable_stacktrace_truncation!` has been removed. From Julia 1.9 onwards, stacktrace + truncation is enabled by default. +- Certain Experimental features were present outside the `Lux.Experimental` module. These + have been removed; use them via `Lux.Experimental` instead. Run Julia with `depwarn` + set to `error` and Lux `v0.5` to see the deprecations. +- `Lux.Experimental.@layer_map` is no longer needed and has been removed. The name of the + variable prevents writing generic functions and is no longer pre-pended to the `KeyPath`. + See the docstring of [`Lux.Experimental.layer_map`](@ref) for more details. +- `allow_fast_activation` kwarg has been removed completely.
Pass an anonymous function + as the activation to prevent internal modifications to the activation function. + +### Breaking Changes (Moved Functionality) + +- `Lux.Experimental.Training` has been moved to `Lux.Training`. We guarantee SemVer + on this new module. +- `Lux.cpu` and `Lux.gpu` have been removed. Use [`cpu_device`](@ref) and + [`gpu_device`](@ref) instead. +- `Experimental.@compact` can be directly used via [`@compact`](@ref) now. +- `Experimental.StatefulLuxLayer` has been moved to [`Lux.StatefulLuxLayer`](@ref). +- `st_fixed_path` kwarg has been removed from [`Lux.StatefulLuxLayer`](@ref); instead use it + as `StatefulLuxLayer{st_fixed_path}(...)`. +- Strings as inputs to [`Lux.Experimental.layer_map`](@ref) and + [`Lux.Experimental.@debug_mode`](@ref) are removed; use `Functors.KeyPath` instead. +- `CrossCor` has been removed. Use `Conv(args...; kwargs..., cross_correlation=true)` + instead. + +### Breaking Changes (Changes in Defaults) + +- [`Conv`](@ref) and [`ConvTranspose`](@ref) use an initialization based on the activation + function, taken from Pytorch. Pytorch assumes the activation function is `leakyrelu` to + compute the gain; however, we compute the gain based on the activation function passed in + to the layer. +- [`Upsample`](@ref) now has an `align_corners` keyword argument, which defaults to `false`. + Previously this was always `true`. +- [`Dense`](@ref) and [`Bilinear`](@ref) have updated default initializations to align with + the defaults from Pytorch. See the documentation for more details. +- [`InstanceNorm`](@ref) now defaults to `affine=false` instead of `affine=true`. +- [`Embedding`](@ref) now defaults to `init_weight=rand32` instead of `init_weight=randn32`. +- Recurrent Cells - [`RNNCell`](@ref), [`LSTMCell`](@ref), and [`GRUCell`](@ref) now have + different default initializations. See the documentation for more details. + +### New Features + +- [`InstanceNorm`](@ref) now supports tracking statistics. +- [`RNNCell`](@ref) and [`LSTMCell`](@ref) add `bias_ih` and `bias_hh` to the parameters to + align with Pytorch. Both are controlled using `init_bias` and `use_bias`. +- [`ConvTranspose`](@ref) allows `flipkernel=true` via `cross_correlation=true`. This makes + it efficient for MIOpen. +- [`ConvTranspose`](@ref) now has an `outpad` keyword argument, which is used to increase + the size of the output in the desired dimensions. +- Pooling Layers based on lpnorm have been added -- [`LPPool`](@ref), + [`GlobalLPPool`](@ref), and [`AdaptiveLPPool`](@ref). diff --git a/docs/src/manual/debugging.md b/docs/src/manual/debugging.md index 642c1fb925..3c3395a7d5 100644 --- a/docs/src/manual/debugging.md +++ b/docs/src/manual/debugging.md @@ -21,8 +21,7 @@ will see how easy it is to pin-point the problematic layer. ```@example manual_debugging using Lux, Random -model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 3), Dense(1 => 1)), - BatchNorm(1); disable_optimizations=true) +model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 3), Dense(1 => 1)), BatchNorm(1)) model_debug = Lux.Experimental.@debug_mode model ``` @@ -63,12 +62,12 @@ model = Chain(Dense(1 => 16, relu), Dense(16 => 3), # [!code --] Dense(16 => 1), # [!code ++] Dense(1 => 1)), - BatchNorm(1); disable_optimizations=true) + BatchNorm(1)) ``` ```@example manual_debugging model_fixed = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), - BatchNorm(1); disable_optimizations=true) + BatchNorm(1)) ps, st = Lux.setup(rng, model_fixed) @@ -88,7 +87,7 @@ debug model.
(or even disable it by setting it to `:none`) ```@example manual_debugging model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), - BatchNorm(1); disable_optimizations=true) + BatchNorm(1)) ps, st = Lux.setup(rng, model) @@ -131,8 +130,7 @@ offending_layer(x) = 2 .* x ``` ```@example manual_debugging -model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), offending_layer), - BatchNorm(1); disable_optimizations=true) +model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), offending_layer), BatchNorm(1)) ps, st = Lux.setup(rng, model) diff --git a/docs/src/manual/dispatch_custom_input.md b/docs/src/manual/dispatch_custom_input.md index 3495703d97..67d1a8b6af 100644 --- a/docs/src/manual/dispatch_custom_input.md +++ b/docs/src/manual/dispatch_custom_input.md @@ -5,10 +5,10 @@ * Defining a dispatch on `(::Layer)(x::MyInputType, ps, st::NamedTuple)` is inconvenient, since it requires the user to define a new method for every layer type. -* `(::AbstractExplicitLayer)(x::MyInputType, ps, st::NamedTuple)` doesn't work. +* `(::AbstractLuxLayer)(x::MyInputType, ps, st::NamedTuple)` doesn't work. * Instead, we need to define the dispatch on - `Lux.apply(::AbstractExplicitLayer, x::MyInputType, ps, st::NamedTuple)`. + `Lux.apply(::AbstractLuxLayer, x::MyInputType, ps, st::NamedTuple)`. ## Concrete Example @@ -22,7 +22,7 @@ define a time dependent version of [`Chain`](@ref). ```@example dispatch using Lux, Random -struct TDChain{L <: NamedTuple} <: Lux.AbstractExplicitContainerLayer{(:layers,)} +struct TDChain{L <: NamedTuple} <: Lux.AbstractLuxWrapperLayer{:layers} layers::L end @@ -66,10 +66,10 @@ struct ArrayAndTime{A <: AbstractArray, T <: Real} end ``` -* Define the dispatch on `Lux.apply(::AbstractExplicitLayer, x::ArrayAndTime, ps, st::NamedTuple)`. +* Define the dispatch on `Lux.apply(::AbstractLuxLayer, x::ArrayAndTime, ps, st::NamedTuple)`. ```@example dispatch -function Lux.apply(layer::Lux.AbstractExplicitLayer, x::ArrayAndTime, ps, st::NamedTuple) +function Lux.apply(layer::Lux.AbstractLuxLayer, x::ArrayAndTime, ps, st::NamedTuple) y, st = layer(x.array, ps, st) return ArrayAndTime(y, x.time), st end diff --git a/docs/src/manual/distributed_utils.md b/docs/src/manual/distributed_utils.md index 2f9c4be36f..677e473777 100644 --- a/docs/src/manual/distributed_utils.md +++ b/docs/src/manual/distributed_utils.md @@ -87,11 +87,11 @@ And that's pretty much it! as input. 3. We don't automatically determine if the MPI Implementation is CUDA or ROCM aware. See [GPU-aware MPI](@ref gpu-aware-mpi-preferences) for more information. -4. Older [`Lux.gpu`](@ref) implementations used to "just work" with `FluxMPI.jl`. We expect - [`LuxDeviceUtils.gpu_device`](@ref) to continue working as expected, however, we - recommend using [`LuxDeviceUtils.gpu_device`](@ref) after calling - [`DistributedUtils.initialize`](@ref) to avoid any mismatch between the device set - via `DistributedUtils` and the device stores in `LuxCUDADevice` or `LuxAMDGPUDevice` +4. Older (now non-existent) `Lux.gpu` implementations used to "just work" with `FluxMPI.jl`. + We expect [`gpu_device`](@ref) to continue working as expected, however, we recommend + using [`gpu_device`](@ref) after calling [`DistributedUtils.initialize`](@ref) to avoid + any mismatch between the device set via `DistributedUtils` and the device stores in + `CUDADevice` or `AMDGPUDevice`. 
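Item 4 above recommends querying the device only after the distributed backend has been initialized. A minimal sketch of that ordering, assuming an NCCL setup (with `MPIBackend` the calls are analogous); this is illustrative and not part of the diff:

```julia
using Lux, MLDataDevices, LuxCUDA, NCCL, Random

DistributedUtils.initialize(NCCLBackend)
backend = DistributedUtils.get_distributed_backend(NCCLBackend)

gdev = gpu_device()   # called after initialize, so each rank sees its assigned GPU

model = Dense(4 => 2)
ps, st = Lux.setup(Random.default_rng(), model) |> gdev
ps = DistributedUtils.synchronize!!(backend, ps)  # keep parameters consistent across ranks
```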
## Known Shortcomings diff --git a/docs/src/manual/freezing_model_parameters.md b/docs/src/manual/freezing_model_parameters.md index 9b5c8ffb99..5f2f4055e6 100644 --- a/docs/src/manual/freezing_model_parameters.md +++ b/docs/src/manual/freezing_model_parameters.md @@ -9,17 +9,15 @@ In this manual entry, we will go over how to freeze certain parameters in a mode ## Freezing Layers of a Particular Kind To freeze a particular kind of layer, let's say [`Dense`](@ref) in the following example. -We can use [`Lux.Experimental.@layer_map`](@ref) and freeze layers if they are of type +We can use [`Lux.Experimental.layer_map`](@ref) and freeze layers if they are of type `Dense`. -```@example +```@example freezing_model_parameters using Lux, Random -rng = Random.default_rng() -Random.seed!(rng, 0) +rng = Xoshiro(0) -model = Chain(Dense(3, 4), Chain(Dense(4, 4), Dropout(0.5f0), BatchNorm(4)), - Dense(4, 1); disable_optimizations=true) +model = Chain(Dense(3, 4), Chain(Dense(4, 4), Dropout(0.5f0), BatchNorm(4)), Dense(4, 1)) ps, st = Lux.setup(rng, model) @@ -27,12 +25,12 @@ x = randn(rng, Float32, 3, 2) model(x, ps, st) -function freeze_dense(d::Lux.Dense, ps, st, ::String) - return Lux.freeze(d, ps, st, (:weight, :bias)) +function freeze_dense(d::Lux.Dense, ps, st, path) + return Lux.Experimental.freeze(d, ps, st, (:weight, :bias)) end -freeze_dense(l, ps, st, name) = (l, ps, st) +freeze_dense(l, ps, st, path) = (l, ps, st) -model_frozen, ps_frozen, st_frozen = Lux.Experimental.@layer_map freeze_dense model ps st +model_frozen, ps_frozen, st_frozen = Lux.Experimental.layer_map(freeze_dense, model, ps, st) model_frozen(x, ps_frozen, st_frozen) ``` @@ -41,25 +39,23 @@ model_frozen(x, ps_frozen, st_frozen) When the function in `layer_map` is called, the 4th argument is the name of the layer. For example, if you want to freeze the 1st layer inside the inner Chain. The name for this -would be `.layer_2.layer_1`. +would be `layer_2.layer_1`. :::code-group ```julia [Freezing by Layer Name] -function freeze_by_name(d, ps, st, name::String) - if name == "model.layer_2.layer_1" +function freeze_by_name(d, ps, st, name::KeyPath) + name == KeyPath(:layer_2, :layer_1) && return Lux.Experimental.freeze(d, ps, st, (:weight, :bias)) - else - return d, ps, st - end + return d, ps, st end ``` ```julia [Freezing by Layer Type] -function freeze_dense(d::Dense, ps, st, ::String) +function freeze_dense(d::Dense, ps, st, _) return Lux.Experimental.freeze(d, ps, st, (:weight, :bias)) end freeze_dense(l, ps, st, _) = (l, ps, st) @@ -77,24 +73,20 @@ the `weight` parameter while training the `bias` parameter. ```julia [Freezing Some Parameters of a Layer] -function freeze_by_name(d, ps, st, name::String) - if name == "model.layer_2.layer_1" - return Lux.freeze(d, ps, st, (:weight,)) - else - return d, ps, st - end +function freeze_by_name(d, ps, st, name::KeyPath) + name == KeyPath(:layer_2, :layer_1) && + return Lux.Experimental.freeze(d, ps, st, (:weight,)) + return d, ps, st end ``` ```julia [Freezing All Parameters of a Layer] -function freeze_by_name(d, ps, st, name::String) - if name == "model.layer_2.layer_1" - return Lux.freeze(d, ps, st, (:weight, :bias)) - else - return d, ps, st - end +function freeze_by_name(d, ps, st, name::KeyPath) + name == KeyPath(:layer_2, :layer_1) && + return Lux.Experimental.freeze(d, ps, st, (:weight, :bias)) + return d, ps, st end ``` @@ -103,10 +95,7 @@ end ## Freezing Part of a Chain -Starting `v0.4.22`, we can directly index into a `Chain`. 
So freezing a part of a `Chain`, -is extremely easy. - -```@example +```@example freezing_model_parameters using Lux, Random rng = Random.default_rng() @@ -114,7 +103,7 @@ Random.seed!(rng, 0) model = Chain(Dense(3, 4), Dense(4, 4), Dropout(0.5f0), BatchNorm(4), Dense(4, 1)) -model_frozen = Chain(model[1:2], Lux.freeze(model[3:4]), model[5]) +model_frozen = Chain(model[1:2], Lux.Experimental.freeze(model[3:4]), model[5]) ps, st = Lux.setup(rng, model_frozen) x = randn(rng, Float32, 3, 2) diff --git a/docs/src/manual/gpu_management.md b/docs/src/manual/gpu_management.md index b879b12387..b6f578d259 100644 --- a/docs/src/manual/gpu_management.md +++ b/docs/src/manual/gpu_management.md @@ -24,8 +24,7 @@ supported_gpu_backends() Automatic Backend Management is done by two simple functions: `cpu_device` and `gpu_device`. -* [`LuxDeviceUtils.cpu_device`](@ref): This is a simple function and just returns a - `LuxCPUDevice` object. +* [`cpu_device`](@ref): This is a simple function and just returns a `CPUDevice` object. ```@example gpu_management cdev = cpu_device() @@ -35,9 +34,9 @@ cdev = cpu_device() x_cpu = randn(Float32, 3, 2) ``` -* [`LuxDeviceUtils.gpu_device`](@ref): This function performs automatic GPU device selection - and returns an object. - 1. If no GPU is available, it returns a `LuxCPUDevice` object. +* [`gpu_device`](@ref): This function performs automatic GPU device selection and returns + an object. + 1. If no GPU is available, it returns a `CPUDevice` object. 2. If a LocalPreferences file is present, then the backend specified in the file is used. To set a backend, use `Lux.gpu_backend!()`. (a) If the trigger package corresponding to the device is not loaded, then a warning is displayed. (b) If no @@ -57,19 +56,19 @@ x_gpu = x_cpu |> gdev ## Manual Backend Management -Automatic Device Selection can be circumvented by directly using `LuxCPUDevice` and -`AbstractLuxGPUDevice` objects. +Automatic Device Selection can be circumvented by directly using `CPUDevice` and +`AbstractGPUDevice` objects. ```@example gpu_management cdev = cpu_device() x_cpu = randn(Float32, 3, 2) -if LuxDeviceUtils.functional(LuxCUDADevice) - gdev = LuxCUDADevice() +if MLDataDevices.functional(CUDADevice) + gdev = CUDADevice() x_gpu = x_cpu |> gdev -elseif LuxDeviceUtils.functional(LuxAMDGPUDevice) - gdev = LuxAMDGPUDevice() +elseif MLDataDevices.functional(AMDGPUDevice) + gdev = AMDGPUDevice() x_gpu = x_cpu |> gdev else @info "No GPU is available. Using CPU." diff --git a/docs/src/manual/interface.md b/docs/src/manual/interface.md index 79d1db0f09..37e1cfb056 100644 --- a/docs/src/manual/interface.md +++ b/docs/src/manual/interface.md @@ -22,7 +22,7 @@ framework. ### Singular Layer If the layer doesn't contain any other Lux layer, then it is a `Singular Layer`. This means -it should optionally subtype `Lux.AbstractExplicitLayer` but mandatorily define +it should optionally subtype `Lux.AbstractLuxLayer` but mandatorily define all the necessary functions mentioned in the docstrings. Consider a simplified version of [`Dense`](@ref) called `Linear`. @@ -38,7 +38,7 @@ architecture cannot change. ```@example layer_interface using LuxCore, Random, WeightInitializers # Importing `Lux` also gives you access to `LuxCore` -struct Linear{F1, F2} <: LuxCore.AbstractExplicitLayer +struct Linear{F1, F2} <: LuxCore.AbstractLuxLayer in_dims::Int out_dims::Int init_weight::F1 @@ -120,13 +120,21 @@ LuxCore.apply(l, x, ps, st) # or `l(x, ps, st)` If your layer comprises of other Lux layers, then it is a `Container Layer`. 
Note that you could treat it as a [`Singular Layer`](#singular-layer), and it is still fine. FWIW, if you -cannot subtype your layer with `LuxCore.AbstractExplicitContainerLayer` then you +cannot subtype your layer with `LuxCore.AbstractLuxContainerLayer` then you should go down the [`Singular Layer`](#singular-layer) route. But subtyping allows us to bypass some of these common definitions. Let us now define a layer, which is basically a composition of two linear layers. +!!! tip "Wrapper Layer" + + If you are defining a layer that is a wrapper around another layer, then you should + subtype `LuxCore.AbstractLuxWrapperLayer` instead of + `LuxCore.AbstractLuxContainerLayer`. The only difference from a container layer is that + it can wrap a single layer and the parameter/state structure is exactly the same as the + wrapped layer. + ```@example layer_interface -struct ComposedLinear{L1, L2} <: LuxCore.AbstractExplicitContainerLayer{(:linear_1, :linear_2)} +struct ComposedLinear{L1, L2} <: LuxCore.AbstractLuxContainerLayer{(:linear_1, :linear_2)} linear_1::L1 linear_2::L2 end diff --git a/docs/src/manual/migrate_from_flux.md b/docs/src/manual/migrate_from_flux.md index 3b4e892323..7eb4352f1f 100644 --- a/docs/src/manual/migrate_from_flux.md +++ b/docs/src/manual/migrate_from_flux.md @@ -49,9 +49,9 @@ should be implemented. A summary of the differences would be: * Lux relies on the user to define `Lux.initialparameters` and `Lux.initialstates` to distinguish between trainable parameters (called "parameters") and non-trainable parameters (called "states"). Additionally, Lux layers define the model architecture, - hence device transfer utilities like [`LuxDeviceUtils.gpu_device`](@ref), - [`LuxDeviceUtils.cpu_device`](@ref), etc. cannot be applied on Lux layers, instead they - need to be applied on the parameters and states. + hence device transfer utilities like [`gpu_device`](@ref), [`cpu_device`](@ref), etc. + cannot be applied on Lux layers, instead they need to be applied on the parameters and + states. Let's work through a concrete example to demonstrate this. We will implement a very simple layer that computes ``A \times B \times x`` where ``A`` is not trainable and ``B`` is @@ -62,7 +62,7 @@ trainable. ```julia [Lux] using Lux, Random, NNlib, Zygote -struct LuxLinear <: Lux.AbstractExplicitLayer +struct LuxLinear <: Lux.AbstractLuxLayer init_A init_B end diff --git a/docs/src/manual/nested_autodiff.md b/docs/src/manual/nested_autodiff.md index 0a5e074a47..497179c11d 100644 --- a/docs/src/manual/nested_autodiff.md +++ b/docs/src/manual/nested_autodiff.md @@ -22,7 +22,7 @@ Let's explore this using some questions that were posted on the [Julia Discourse forum](https://discourse.julialang.org/). 
```@example nested_ad -using ADTypes, Lux, LinearAlgebra, Zygote, ForwardDiff, Random +using ADTypes, Lux, LinearAlgebra, Zygote, ForwardDiff, Random, StableRNGs using ComponentArrays, FiniteDiff ``` @@ -70,15 +70,15 @@ function loss_function1(model, x, ps, st, y) loss_emp = sum(abs2, ŷ .- y) # You can use `Zygote.jacobian` as well but ForwardDiff tends to be more efficient here J = ForwardDiff.jacobian(smodel, x) - loss_reg = abs2(norm(J)) + loss_reg = abs2(norm(J .* 0.01f0)) return loss_emp + loss_reg end # Using Batchnorm to show that it is possible model = Chain(Dense(2 => 4, tanh), BatchNorm(4), Dense(4 => 2)) -ps, st = Lux.setup(Xoshiro(0), model) -x = rand(Xoshiro(0), Float32, 2, 10) -y = rand(Xoshiro(11), Float32, 2, 10) +ps, st = Lux.setup(StableRNG(0), model) +x = randn(StableRNG(0), Float32, 2, 10) +y = randn(StableRNG(11), Float32, 2, 10) loss_function1(model, x, ps, st, y) ``` @@ -97,9 +97,9 @@ Now let's verify the gradients using finite differences: ComponentArray(ps)) println("∞-norm(∂x - ∂x_fd): ", norm(∂x .- ∂x_fd, Inf)) -@assert norm(∂x .- ∂x_fd, Inf) < 1e-1 # hide +@assert norm(∂x .- ∂x_fd, Inf) < 1e-2 # hide println("∞-norm(∂ps - ∂ps_fd): ", norm(ComponentArray(∂ps) .- ∂ps_fd, Inf)) -@assert norm(ComponentArray(∂ps) .- ∂ps_fd, Inf) < 1e-1 # hide +@assert norm(ComponentArray(∂ps) .- ∂ps_fd, Inf) < 1e-2 # hide nothing; # hide ``` @@ -123,8 +123,8 @@ end model = Chain(Dense(1 => 12,tanh), Dense(12 => 12,tanh), Dense(12 => 12,tanh), Dense(12 => 1)) -ps, st = Lux.setup(Xoshiro(0), model) -t = rand(Xoshiro(0), Float32, 1, 16) +ps, st = Lux.setup(StableRNG(0), model) +t = rand(StableRNG(0), Float32, 1, 16) ``` Now the moment of truth: @@ -164,9 +164,9 @@ end model = Chain(Dense(1 => 12,tanh), Dense(12 => 12,tanh), Dense(12 => 12,tanh), Dense(12 => 1)) -ps, st = Lux.setup(Xoshiro(0), model) +ps, st = Lux.setup(StableRNG(0), model) ps = ComponentArray(ps) # needs to be an AbstractArray for most jacobian functions -x = rand(Xoshiro(0), Float32, 1, 16) +x = rand(StableRNG(0), Float32, 1, 16) ``` We can as usual compute the gradient/jacobian of the loss function: @@ -260,9 +260,9 @@ Now let's compute the trace and compare the results: ```@example nested_ad model = Chain(Dense(4 => 12,tanh), Dense(12 => 12,tanh), Dense(12 => 12,tanh), Dense(12 => 4)) -ps, st = Lux.setup(Xoshiro(0), model) -x = rand(Xoshiro(0), Float32, 4, 12) -v = (rand(Xoshiro(12), Float32, 4, 12) .> 0.5f0) * 2.0f0 .- 1.0f0 # rademacher sample +ps, st = Lux.setup(StableRNG(0), model) +x = rand(StableRNG(0), Float32, 4, 12) +v = (rand(StableRNG(12), Float32, 4, 12) .> 0.5f0) * 2.0f0 .- 1.0f0 # rademacher sample nothing; # hide ``` diff --git a/docs/src/manual/preferences.md b/docs/src/manual/preferences.md index eaea213ee6..357c77acb5 100644 --- a/docs/src/manual/preferences.md +++ b/docs/src/manual/preferences.md @@ -38,8 +38,8 @@ By default, both of these preferences are set to `false`. 1. `gpu_backend` - Set this to bypass the automatic backend selection and use a specific gpu backend. Valid options are "cuda", "rocm", "metal", and "oneapi". This preference - needs to be set for `LuxDeviceUtils` package. It is recommended to use - [`LuxDeviceUtils.gpu_backend!`](@ref) to set this preference. + needs to be set for `MLDataDevices` package. It is recommended to use + [`MLDataDevices.gpu_backend!`](@ref) to set this preference. 
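A small hedged example of setting this preference with the function named above (a fresh Julia session is needed afterwards for it to take effect):

```julia
using MLDataDevices

MLDataDevices.gpu_backend!("cuda")  # or "rocm", "metal", "oneapi"
# Restart Julia; gpu_device() will then prefer the chosen backend.
```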
## [Automatic Eltype Conversion](@id automatic-eltypes-preference) diff --git a/docs/tutorials.jl b/docs/tutorials.jl index 68597e6e26..ab49870308 100644 --- a/docs/tutorials.jl +++ b/docs/tutorials.jl @@ -1,37 +1,49 @@ #! format: off const BEGINNER_TUTORIALS = [ - "Basics/main.jl", - "PolynomialFitting/main.jl", - "SimpleRNN/main.jl", - "SimpleChains/main.jl" + "Basics/main.jl" => "CUDA", + "PolynomialFitting/main.jl" => "CUDA", + "SimpleRNN/main.jl" => "CUDA", + "SimpleChains/main.jl" => "CPU" ] const INTERMEDIATE_TUTORIALS = [ - "NeuralODE/main.jl", - "BayesianNN/main.jl", - "HyperNet/main.jl" + "NeuralODE/main.jl" => "CUDA", + "BayesianNN/main.jl" => "CPU", + "HyperNet/main.jl" => "CUDA", ] const ADVANCED_TUTORIALS = [ - "GravitationalWaveForm/main.jl", - "SymbolicOptimalControl/main.jl" + "GravitationalWaveForm/main.jl" => "CPU", ] const TUTORIALS = [ - collect(enumerate(Iterators.product(["beginner"], BEGINNER_TUTORIALS)))..., - collect(enumerate(Iterators.product(["intermediate"], INTERMEDIATE_TUTORIALS)))..., - collect(enumerate(Iterators.product(["advanced"], ADVANCED_TUTORIALS)))... + collect(enumerate(Iterators.product(["beginner"], first.(BEGINNER_TUTORIALS))))..., + collect(enumerate(Iterators.product(["intermediate"], first.(INTERMEDIATE_TUTORIALS))))..., + collect(enumerate(Iterators.product(["advanced"], first.(ADVANCED_TUTORIALS))))... ] +const BACKEND_LIST = lowercase.([ + last.(BEGINNER_TUTORIALS)..., + last.(INTERMEDIATE_TUTORIALS)..., + last.(ADVANCED_TUTORIALS)... +]) #! format: on +const BACKEND_GROUP = lowercase(get(ENV, "TUTORIAL_BACKEND_GROUP", "all")) + const BUILDKITE_PARALLEL_JOB_COUNT = parse( Int, get(ENV, "BUILDKITE_PARALLEL_JOB_COUNT", "-1")) +const TUTORIALS_WITH_BACKEND = if BACKEND_GROUP == "all" + TUTORIALS +else + TUTORIALS[BACKEND_LIST .== BACKEND_GROUP] +end + const TUTORIALS_BUILDING = if BUILDKITE_PARALLEL_JOB_COUNT > 0 id = parse(Int, ENV["BUILDKITE_PARALLEL_JOB"]) + 1 # Index starts from 0 - splits = collect(Iterators.partition( - TUTORIALS, cld(length(TUTORIALS), BUILDKITE_PARALLEL_JOB_COUNT))) + splits = collect(Iterators.partition(TUTORIALS_WITH_BACKEND, + cld(length(TUTORIALS_WITH_BACKEND), BUILDKITE_PARALLEL_JOB_COUNT))) id > length(splits) ? [] : splits[id] else - TUTORIALS + TUTORIALS_WITH_BACKEND end const NTASKS = min( diff --git a/examples/Basics/Project.toml b/examples/Basics/Project.toml index 9e0c4c2943..01b75c2e2a 100644 --- a/examples/Basics/Project.toml +++ b/examples/Basics/Project.toml @@ -14,7 +14,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" ComponentArrays = "0.15" ForwardDiff = "0.10" Literate = "2" -Lux = "0.5.56" -LuxCUDA = "0.2, 0.3" -Optimisers = "0.2, 0.3" +Lux = "1" +LuxCUDA = "0.3" +Optimisers = "0.3" Zygote = "0.6" diff --git a/examples/BayesianNN/Project.toml b/examples/BayesianNN/Project.toml index d4b30f07c3..8d6c24c2ef 100644 --- a/examples/BayesianNN/Project.toml +++ b/examples/BayesianNN/Project.toml @@ -15,8 +15,8 @@ CairoMakie = "0.12" Functors = "0.4" LinearAlgebra = "1" Literate = "2" -Lux = "0.5" +Lux = "1" Random = "1" Tracker = "0.2" -Turing = "0.30, 0.31, 0.32, 0.33, 0.34" +Turing = "0.34" Zygote = "0.6.69" diff --git a/examples/BayesianNN/main.jl b/examples/BayesianNN/main.jl index 62525f3f00..aa850d2ed8 100644 --- a/examples/BayesianNN/main.jl +++ b/examples/BayesianNN/main.jl @@ -110,7 +110,7 @@ end # To interface with external libraries it is often desirable to use the # [`StatefulLuxLayer`](@ref) to automatically handle the neural network states. 
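Before the change below, a small self-contained sketch (the model and sizes are made up) of the `StatefulLuxLayer{true}(model, ps, st)` constructor it introduces: passing `nothing` for the parameters means they are supplied at call time, which is the pattern used when an external library such as Turing owns the parameter vector.

```julia
using Lux, Random

nn = Chain(Dense(2 => 3, tanh), Dense(3 => 2))
ps, st = Lux.setup(Xoshiro(0), nn)

# Store the model and its states, but no parameters (`nothing`):
smodel = StatefulLuxLayer{true}(nn, nothing, st)

# Parameters are then passed on each call, e.g. by the sampler that owns them:
y = smodel(randn(Float32, 2, 5), ps)
```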
-const model = StatefulLuxLayer(nn, st) +const model = StatefulLuxLayer{true}(nn, nothing, st) ## Specify the probabilistic model. @model function bayes_nn(xs, ts) diff --git a/examples/ConvMixer/Project.toml b/examples/ConvMixer/Project.toml index 97da256d59..ca93123123 100644 --- a/examples/ConvMixer/Project.toml +++ b/examples/ConvMixer/Project.toml @@ -22,11 +22,11 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Comonicon = "1.0.8" ConcreteStructs = "0.2.3" -DataAugmentation = "0.2.12, 0.3" +DataAugmentation = "0.3" ImageCore = "0.10.2" ImageShow = "0.3.8" Interpolations = "0.15.1" -Lux = "0.5.53" +Lux = "1" LuxCUDA = "0.3.2" MLDatasets = "0.7.14" MLUtils = "0.4.4" diff --git a/examples/ConvMixer/README.md b/examples/ConvMixer/README.md index fbea290333..f16d8850db 100644 --- a/examples/ConvMixer/README.md +++ b/examples/ConvMixer/README.md @@ -78,6 +78,8 @@ Flags 1. Weight-Decay with Adam in Optimisers.jl works differently from `torch.optim.AdamW`, so you might need to adjust the value of `--weight-decay` to get the same results. + Pytorch multiplies the weight decay with the learning rate, whereas in Optimisers.jl + the learning rate is decoupled from the weight decay. 2. To match the results from the original repo, we need more augmentation strategies, that are currently not implemented in DataAugmentation.jl. 3. Don't compare the reported timings in that repo against the numbers here. They time the diff --git a/examples/ConvMixer/main.jl b/examples/ConvMixer/main.jl index 372259d817..56ca4115f1 100644 --- a/examples/ConvMixer/main.jl +++ b/examples/ConvMixer/main.jl @@ -55,34 +55,34 @@ function ConvMixer(; dim, depth, kernel_size=5, patch_size=2) #! format: on end -function accuracy(model, ps, st, dataloader; dev=gpu_device()) +function accuracy(model, ps, st, dataloader) total_correct, total = 0, 0 st = Lux.testmode(st) - cpu_dev = cpu_device() for (x, y) in dataloader target_class = onecold(y) - predicted_class = onecold(cpu_dev(first(model(dev(x), ps, st)))) + predicted_class = onecold(first(model(x, ps, st))) total_correct += sum(target_class .== predicted_class) total += length(target_class) end return total_correct / total end -@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, +Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5, clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01) rng = StableRNG(seed) gdev = gpu_device() - trainloader, testloader = get_dataloaders(batchsize) + trainloader, testloader = get_dataloaders(batchsize) .|> gdev model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size) + ps, st = Lux.setup(rng, model) |> gdev opt = AdamW(; eta=lr_max, lambda=weight_decay) clip_norm && (opt = OptimiserChain(ClipNorm(), opt)) train_state = Training.TrainState( - rng, model, AdamW(; eta=lr_max, lambda=weight_decay); transform_variables=gdev) + model, ps, st, AdamW(; eta=lr_max, lambda=weight_decay)) lr_schedule = linear_interpolation( [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0]) @@ -95,8 +95,6 @@ end for (i, (x, y)) in enumerate(trainloader) lr = lr_schedule((epoch - 1) + (i + 1) / length(trainloader)) train_state = Optimisers.adjust!(train_state, lr) - x = x |> gdev - y = y |> gdev (_, _, _, train_state) = Training.single_train_step!( AutoZygote(), loss, (x, y), train_state) end diff --git a/examples/DDIM/Project.toml b/examples/DDIM/Project.toml index 
21d9914cdf..461bf2222d 100644 --- a/examples/DDIM/Project.toml +++ b/examples/DDIM/Project.toml @@ -1,8 +1,6 @@ [deps] -AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e" @@ -25,19 +23,17 @@ TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -AMDGPU = "0.9.6, 1" ArgCheck = "2.3.0" CairoMakie = "0.12" -ChainRulesCore = "1.23" Comonicon = "1" ConcreteStructs = "0.2.3" -DataAugmentation = "0.2.12, 0.3" +DataAugmentation = "0.3" DataDeps = "0.7.13" FileIO = "1.16" ImageCore = "0.9, 0.10" ImageIO = "0.6" JLD2 = "0.4.48" -Lux = "0.5.52" +Lux = "1" LuxCUDA = "0.3" MLUtils = "0.4" Optimisers = " 0.3" diff --git a/examples/DDIM/main.jl b/examples/DDIM/main.jl index ec31991b5f..1a0039541f 100644 --- a/examples/DDIM/main.jl +++ b/examples/DDIM/main.jl @@ -6,11 +6,10 @@ # ## Package Imports -using ArgCheck, CairoMakie, ChainRulesCore, ConcreteStructs, Comonicon, DataAugmentation, - DataDeps, FileIO, ImageCore, JLD2, Lux, LuxCUDA, MLUtils, Optimisers, - ParameterSchedulers, ProgressBars, Random, Setfield, StableRNGs, Statistics, Zygote +using ArgCheck, CairoMakie, ConcreteStructs, Comonicon, DataAugmentation, DataDeps, FileIO, + ImageCore, JLD2, Lux, LuxCUDA, MLUtils, Optimisers, ParameterSchedulers, ProgressBars, + Random, Setfield, StableRNGs, Statistics, Zygote using TensorBoardLogger: TBLogger, log_value, log_images -const CRC = ChainRulesCore CUDA.allowscalar(false) @@ -130,24 +129,22 @@ function ddim(rng::AbstractRNG, args...; min_signal_rate=0.02f0, max_signal_rate, dispatch=:DDIM) do x::AbstractArray{<:Real, 4} images = bn(x) rng = Lux.replicate(rng) - T = eltype(x) - noises = CRC.@ignore_derivatives randn!(rng, similar(images, T, size(images)...)) - diffusion_times = CRC.@ignore_derivatives rand!( - rng, similar(images, T, 1, 1, 1, size(images, 4))) + noises = rand_like(rng, images) + diffusion_times = rand_like(rng, images, (1, 1, 1, size(images, 4))) - noise_rates, signal_rates = __diffusion_schedules( + noise_rates, signal_rates = diffusion_schedules( diffusion_times, min_signal_rate, max_signal_rate) noisy_images = @. signal_rates * images + noise_rates * noises - pred_noises, pred_images = __denoise(unet, noisy_images, noise_rates, signal_rates) + pred_noises, pred_images = denoise(unet, noisy_images, noise_rates, signal_rates) @return noises, images, pred_noises, pred_images end end -function __diffusion_schedules(diffusion_times::AbstractArray{T, 4}, min_signal_rate::T, +function diffusion_schedules(diffusion_times::AbstractArray{T, 4}, min_signal_rate::T, max_signal_rate::T) where {T <: Real} start_angle = acos(max_signal_rate) end_angle = acos(min_signal_rate) @@ -160,8 +157,7 @@ function __diffusion_schedules(diffusion_times::AbstractArray{T, 4}, min_signal_ return noise_rates, signal_rates end -function __denoise( - unet, noisy_images::AbstractArray{T, 4}, noise_rates::AbstractArray{T, 4}, +function denoise(unet, noisy_images::AbstractArray{T, 4}, noise_rates::AbstractArray{T, 4}, signal_rates::AbstractArray{T, 4}) where {T <: Real} pred_noises = unet((noisy_images, noise_rates .^ 2)) pred_images = @. 
(noisy_images - pred_noises * noise_rates) / signal_rates @@ -170,7 +166,7 @@ end # ## Helper Functions for Image Generation -function __reverse_diffusion( +function reverse_diffusion( model, initial_noise::AbstractArray{T, 4}, diffusion_steps::Int) where {T <: Real} num_images = size(initial_noise, 4) step_size = one(T) / diffusion_steps @@ -188,15 +184,15 @@ function __reverse_diffusion( # We start t = 1, and gradually decreases to t=0 diffusion_times = (ones(T, 1, 1, 1, num_images) .- step_size * step) |> dev - noise_rates, signal_rates = __diffusion_schedules( + noise_rates, signal_rates = diffusion_schedules( diffusion_times, min_signal_rate, max_signal_rate) - pred_noises, pred_images = __denoise( + pred_noises, pred_images = denoise( StatefulLuxLayer{true}(model.model.layers.unet, model.ps.unet, model.st.unet), noisy_images, noise_rates, signal_rates) next_diffusion_times = diffusion_times .- step_size - next_noisy_rates, next_signal_rates = __diffusion_schedules( + next_noisy_rates, next_signal_rates = diffusion_schedules( next_diffusion_times, min_signal_rate, max_signal_rate) next_noisy_images = next_signal_rates .* pred_images .+ @@ -206,14 +202,14 @@ function __reverse_diffusion( return pred_images end -function __denormalize(model::StatefulLuxLayer{true}, x::AbstractArray{<:Real, 4}) +function denormalize(model::StatefulLuxLayer, x::AbstractArray{<:Real, 4}) mean = reshape(model.st.bn.running_mean, 1, 1, 3, 1) var = reshape(model.st.bn.running_var, 1, 1, 3, 1) std = sqrt.(var .+ model.model.layers.bn.epsilon) return std .* x .+ mean end -function __save_images(output_dir, images::AbstractArray{<:Real, 4}) +function save_images(output_dir, images::AbstractArray{<:Real, 4}) imgs = Vector{Array{RGB, 2}}(undef, size(images, 4)) for i in axes(images, 4) img = @view images[:, :, :, i] @@ -224,7 +220,7 @@ function __save_images(output_dir, images::AbstractArray{<:Real, 4}) return imgs end -function __generate_and_save_image_grid(output_dir, imgs::Vector{<:AbstractArray{<:RGB, 2}}) +function generate_and_save_image_grid(output_dir, imgs::Vector{<:AbstractArray{<:RGB, 2}}) fig = Figure() nrows, ncols = 3, 4 for r in 1:nrows, c in 1:ncols @@ -238,11 +234,11 @@ function __generate_and_save_image_grid(output_dir, imgs::Vector{<:AbstractArray return end -function __generate( +function generate( model::StatefulLuxLayer, rng, image_size::NTuple{4, Int}, diffusion_steps::Int, dev) initial_noise = randn(rng, Float32, image_size...) 
|> dev - generated_images = __reverse_diffusion(model, initial_noise, diffusion_steps) - generated_images = __denormalize(model, generated_images) + generated_images = reverse_diffusion(model, initial_noise, diffusion_steps) + generated_images = denormalize(model, generated_images) return clamp01.(generated_images) end @@ -287,21 +283,23 @@ function Base.getindex(ds::FlowersDataset, i::Int) end function preprocess_image(image::Matrix{<:RGB}, image_size::Int) - return apply(CenterResizeCrop((image_size, image_size)), Image(image)) |> itemdata + return apply( + CenterResizeCrop((image_size, image_size)), DataAugmentation.Image(image)) |> + itemdata end const maeloss = MAELoss() function loss_function(model, ps, st, data) (noises, images, pred_noises, pred_images), st = Lux.apply(model, data, ps, st) - noise_loss = maeloss(noises, pred_noises) - image_loss = maeloss(images, pred_images) + noise_loss = maeloss(pred_noises, noises) + image_loss = maeloss(pred_images, images) return noise_loss, st, (; image_loss, noise_loss) end # ## Entry Point for our code -@main function main(; epochs::Int=100, image_size::Int=128, +Comonicon.@main function main(; epochs::Int=100, image_size::Int=128, batchsize::Int=128, learning_rate_start::Float32=1.0f-3, learning_rate_end::Float32=1.0f-5, weight_decay::Float32=1.0f-6, checkpoint_interval::Int=25, expt_dir=tempname(@__DIR__), @@ -316,7 +314,8 @@ end @info "Experiment directory: $(expt_dir)" - rng = StableRNG(1234) + rng = Random.default_rng() + Random.seed!(rng, 1234) image_dir = joinpath(expt_dir, "images") isdir(image_dir) || mkpath(image_dir) @@ -330,6 +329,7 @@ end @info "Building model" model = ddim(rng, (image_size, image_size); channels, block_depth, min_freq, max_freq, embedding_dims, min_signal_rate, max_signal_rate) + ps, st = Lux.setup(rng, model) |> gdev if inference_mode @argcheck saved_model_path!==nothing "`saved_model_path` must be specified for inference" @@ -338,28 +338,28 @@ end states = states |> gdev model = StatefulLuxLayer{true}(model, parameters, Lux.testmode(states)) - generated_images = __generate(model, StableRNG(generate_image_seed), + generated_images = generate(model, StableRNG(generate_image_seed), (image_size, image_size, 3, generate_n_images), diffusion_steps, gdev) |> cpu_device() path = joinpath(image_dir, "inference") @info "Saving generated images to $(path)" - imgs = __save_images(path, generated_images) - __generate_and_save_image_grid(path, imgs) + imgs = save_images(path, generated_images) + generate_and_save_image_grid(path, imgs) return end tb_dir = joinpath(expt_dir, "tb_logs") - @info "Logging Tensorboard logs to $(tb_dir). Run tensorboard with `tensorboard --logdir $(dirname(tb_dir))`" + @info "Tensorboard logs being saved to $(tb_dir). 
Run tensorboard with \ + `tensorboard --logdir $(dirname(tb_dir))`" tb_logger = TBLogger(tb_dir) tstate = Training.TrainState( - rng, model, AdamW(; eta=learning_rate_start, lambda=weight_decay); - transform_variables=gdev) + model, ps, st, AdamW(; eta=learning_rate_start, lambda=weight_decay)) @info "Preparing dataset" ds = FlowersDataset(x -> preprocess_image(x, image_size), true) - data_loader = DataLoader(ds; batchsize, collate=true, parallel=true) + data_loader = DataLoader(ds; batchsize, collate=true, parallel=true) |> gdev scheduler = CosAnneal(learning_rate_start, learning_rate_end, epochs) @@ -376,7 +376,6 @@ end for (i, data) in enumerate(data_loader) step += 1 - data = data |> gdev (_, _, stats, tstate) = Training.single_train_step!( AutoZygote(), loss_function, data, tstate) image_losses[i] = stats.image_loss @@ -394,13 +393,13 @@ end if epoch % generate_image_interval == 0 || epoch == epochs model_test = StatefulLuxLayer{true}( tstate.model, tstate.parameters, Lux.testmode(tstate.states)) - generated_images = __generate(model_test, StableRNG(generate_image_seed), + generated_images = generate(model_test, StableRNG(generate_image_seed), (image_size, image_size, 3, generate_n_images), diffusion_steps, gdev) |> cpu_device() path = joinpath(image_dir, "epoch_$(epoch)") @info "Saving generated images to $(path)" - imgs = __save_images(path, generated_images) + imgs = save_images(path, generated_images) log_images(tb_logger, "Generated Images", imgs; step) end diff --git a/examples/GravitationalWaveForm/Project.toml b/examples/GravitationalWaveForm/Project.toml index b60e84cd24..67e420f41c 100644 --- a/examples/GravitationalWaveForm/Project.toml +++ b/examples/GravitationalWaveForm/Project.toml @@ -5,8 +5,6 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" @@ -19,10 +17,8 @@ CairoMakie = "0.12" ComponentArrays = "0.15" LineSearches = "7" Literate = "2" -Lux = "0.5" -AMDGPU = "0.9.6, 1" -LuxCUDA = "0.3" +Lux = "1" Optimization = "3" -OptimizationOptimJL = "0.1, 0.2, 0.3" +OptimizationOptimJL = "0.3" OrdinaryDiffEq = "6" SciMLSensitivity = "7.57" diff --git a/examples/GravitationalWaveForm/main.jl b/examples/GravitationalWaveForm/main.jl index 718eefd462..56bbb23018 100644 --- a/examples/GravitationalWaveForm/main.jl +++ b/examples/GravitationalWaveForm/main.jl @@ -7,12 +7,10 @@ # ## Package Imports -using Lux, ComponentArrays, LineSearches, AMDGPU, LuxCUDA, OrdinaryDiffEq, Optimization, - OptimizationOptimJL, Printf, Random, SciMLSensitivity +using Lux, ComponentArrays, LineSearches, OrdinaryDiffEq, Optimization, OptimizationOptimJL, + Printf, Random, SciMLSensitivity using CairoMakie -CUDA.allowscalar(false) - # ## Define some Utility Functions # !!! 
tip @@ -234,7 +232,7 @@ ps, st = Lux.setup(Xoshiro(), nn) const params = ComponentArray{Float64}(ps) -const nn_model = StatefulLuxLayer(nn, st) +const nn_model = StatefulLuxLayer{true}(nn, nothing, st) # Now we define a system of odes which describes motion of point like particle with # Newtonian physics, uses diff --git a/examples/HyperNet/Project.toml b/examples/HyperNet/Project.toml index 161ecc3c89..354d535b55 100644 --- a/examples/HyperNet/Project.toml +++ b/examples/HyperNet/Project.toml @@ -4,7 +4,6 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" @@ -17,16 +16,15 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ADTypes = "0.2, 1" -ComponentArrays = "0.13, 0.14, 0.15" +ADTypes = "1" +ComponentArrays = "0.15" Literate = "2" -Lux = "0.5" -AMDGPU = "0.9.6, 1" -LuxCUDA = "0.2, 0.3" -MLDatasets = "0.5, 0.7" -MLUtils = "0.2, 0.3, 0.4" -OneHotArrays = "0.1, 0.2" -Optimisers = "0.2, 0.3" -Setfield = "0.8, 1" +Lux = "1" +LuxCUDA = "0.3" +MLDatasets = "0.7" +MLUtils = "0.4" +OneHotArrays = "0.2" +Optimisers = "0.3" +Setfield = "1" Statistics = "1" Zygote = "0.6" diff --git a/examples/HyperNet/main.jl b/examples/HyperNet/main.jl index 9045091c77..e0d96d4d72 100644 --- a/examples/HyperNet/main.jl +++ b/examples/HyperNet/main.jl @@ -2,8 +2,8 @@ # ## Package Imports -using Lux, ADTypes, ComponentArrays, AMDGPU, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, - Optimisers, Printf, Random, Setfield, Statistics, Zygote +using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers, + Printf, Random, Setfield, Statistics, Zygote CUDA.allowscalar(false) @@ -24,8 +24,8 @@ function load_datasets(n_train=1024, n_eval=32, batchsize=256) end # ## Implement a HyperNet Layer -function HyperNet(weight_generator::Lux.AbstractExplicitLayer, - core_network::Lux.AbstractExplicitLayer) +function HyperNet( + weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer) ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |> ComponentArray |> getaxes @@ -59,15 +59,12 @@ end # ## Define Utility Functions const loss = CrossEntropyLoss(; logits=Val(true)) -function accuracy(model, ps, st, dataloader, data_idx, gdev=gpu_device()) +function accuracy(model, ps, st, dataloader, data_idx) total_correct, total = 0, 0 st = Lux.testmode(st) - cpu_dev = cpu_device() for (x, y) in dataloader - x = x |> gdev - y = y |> gdev - target_class = onecold(cpu_dev(y)) - predicted_class = onecold(cpu_dev(model((data_idx, x), ps, st)[1])) + target_class = onecold(y) + predicted_class = onecold(first(model((data_idx, x), ps, st))) total_correct += sum(target_class .== predicted_class) total += length(target_class) end @@ -80,32 +77,30 @@ function train() dataloaders = load_datasets() dev = gpu_device() - rng = Xoshiro(0) + ps, st = Lux.setup(rng, model) |> dev - train_state = Training.TrainState(rng, model, Adam(3.0f-4); transform_variables=dev) + train_state = Training.TrainState(model, ps, st, Adam(3.0f-4)) ### Lets train the model - nepochs = 10 + nepochs = 25 for epoch in 1:nepochs, data_idx in 1:2 - train_dataloader, test_dataloader = dataloaders[data_idx] + 
train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev stime = time() for (x, y) in train_dataloader - x = x |> dev - y = y |> dev (_, _, _, train_state) = Training.single_train_step!( AutoZygote(), loss, ((data_idx, x), y), train_state) end ttime = time() - stime train_acc = round( - accuracy(model, train_state.parameters, train_state.states, - train_dataloader, data_idx, dev) * 100; + accuracy(model, train_state.parameters, + train_state.states, train_dataloader, data_idx) * 100; digits=2) test_acc = round( - accuracy(model, train_state.parameters, train_state.states, - test_dataloader, data_idx, dev) * 100; + accuracy(model, train_state.parameters, + train_state.states, test_dataloader, data_idx) * 100; digits=2) data_name = data_idx == 1 ? "MNIST" : "FashionMNIST" @@ -116,22 +111,27 @@ function train() println() + test_acc_list = [0.0, 0.0] for data_idx in 1:2 - train_dataloader, test_dataloader = dataloaders[data_idx] + train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev train_acc = round( - accuracy(model, train_state.parameters, train_state.states, - train_dataloader, data_idx, dev) * 100; + accuracy(model, train_state.parameters, + train_state.states, train_dataloader, data_idx) * 100; digits=2) test_acc = round( - accuracy(model, train_state.parameters, train_state.states, - test_dataloader, data_idx, dev) * 100; + accuracy(model, train_state.parameters, + train_state.states, test_dataloader, data_idx) * 100; digits=2) data_name = data_idx == 1 ? "MNIST" : "FashionMNIST" @printf "[FINAL] \t %12s \t Training Accuracy: %.2f%% \t Test Accuracy: \ %.2f%%\n" data_name train_acc test_acc + test_acc_list[data_idx] = test_acc end + return test_acc_list end -train() +test_acc_list = train() +@assert test_acc_list[1] > 0.90 && test_acc_list[2] > 0.70 #hide +nothing #hide diff --git a/examples/ImageNet/Project.toml b/examples/ImageNet/Project.toml index c4ecbdc4ee..f5dd3601ff 100644 --- a/examples/ImageNet/Project.toml +++ b/examples/ImageNet/Project.toml @@ -27,25 +27,25 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] +AMDGPU = "1" Augmentor = "0.6" Boltz = "0.1, 0.2, 0.3" Configurations = "0.17" FLoops = "0.2" FileIO = "1.16" Format = "1.3" -Functors = "0.2, 0.3, 0.4" +Functors = "0.4" Images = "0.26" JLD2 = "0.4.46" JpegTurbo = "0.1" -Lux = "0.5" -AMDGPU = "0.9.6, 1" -LuxCUDA = "0.2, 0.3" -MLUtils = "0.2.10, 0.3, 0.4" +Lux = "1" +LuxCUDA = "0.3" +MLUtils = "0.4" MPI = "0.20.19" Metalhead = "0.9" NCCL = "0.1.1" -OneHotArrays = "0.1, 0.2" -Optimisers = "0.2, 0.3" +OneHotArrays = "0.2" +Optimisers = "0.3" ParameterSchedulers = "0.4" Setfield = "1" SimpleConfig = "0.1" diff --git a/examples/ImageNet/main.jl b/examples/ImageNet/main.jl index 6332352950..5a8f225db6 100644 --- a/examples/ImageNet/main.jl +++ b/examples/ImageNet/main.jl @@ -5,7 +5,7 @@ using Augmentor, Configurations, Dates, FileIO, Functors, Images, MLUtils, OneHo import FLoops: ThreadedEx import Metalhead import MPI, NCCL -using AMDGPU, LuxCUDA +using LuxCUDA using Format # Distributed Training: NCCL for NVIDIA GPUs and MPI for anything else diff --git a/examples/ImageNet/utils.jl b/examples/ImageNet/utils.jl index ee5dcf74ff..44b1ef8ef3 100644 --- a/examples/ImageNet/utils.jl +++ b/examples/ImageNet/utils.jl @@ -2,13 +2,13 @@ CUDA.allowscalar(false) function unsafe_free! end -if LuxDeviceUtils.functional(LuxCUDADevice) +if MLDataDevices.functional(CUDADevice) function unsafe_free!(x) return hasmethod(CUDA.unsafe_free!, Tuple{typeof(x)}) ? 
CUDA.unsafe_free!(x) : nothing end unsafe_free!(x::OneHotArray) = CUDA.unsafe_free!(x.indices) -elseif LuxDeviceUtils.functional(LuxAMDGPUDevice) +elseif MLDataDevices.functional(AMDGPUDevice) function unsafe_free!(x) return hasmethod(AMDGPU.unsafe_free!, Tuple{typeof(x)}) ? AMDGPU.unsafe_free!(x) : nothing @@ -18,8 +18,8 @@ end function reclaim_all() GC.gc(true) - LuxDeviceUtils.functional(LuxCUDADevice) && CUDA.reclaim() - LuxDeviceUtils.functional(LuxAMDGPUDevice) && AMDGPU.reclaim() + MLDataDevices.functional(CUDADevice) && CUDA.reclaim() + MLDataDevices.functional(AMDGPUDevice) && AMDGPU.reclaim() return end @@ -147,7 +147,7 @@ end get_loggable_values(meter::ProgressMeter) = getproperty.(meter.meters, :average) # Optimisers State -function (dev::LuxDeviceUtils.AbstractLuxDevice)(l::Optimisers.Leaf) +function (dev::MLDataDevices.AbstractDevice)(l::Optimisers.Leaf) @set! l.state = dev(l.state) return l end diff --git a/examples/NeuralODE/Project.toml b/examples/NeuralODE/Project.toml index 3893288566..f586f60679 100644 --- a/examples/NeuralODE/Project.toml +++ b/examples/NeuralODE/Project.toml @@ -1,5 +1,4 @@ [deps] -AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" @@ -17,10 +16,9 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -AMDGPU = "0.9.6, 1" -ComponentArrays = "0.13, 0.14, 0.15" +ComponentArrays = "0.15" Literate = "2" -Lux = "0.5" +Lux = "1" LuxCUDA = "0.2, 0.3" MLDatasets = "0.5, 0.7" MLUtils = "0.2, 0.3, 0.4" diff --git a/examples/NeuralODE/main.jl b/examples/NeuralODE/main.jl index 55e0ae944d..085bcafedd 100644 --- a/examples/NeuralODE/main.jl +++ b/examples/NeuralODE/main.jl @@ -7,10 +7,10 @@ # ## Package Imports -using Lux, ComponentArrays, SciMLSensitivity, AMDGPU, LuxCUDA, Optimisers, OrdinaryDiffEq, - Random, Statistics, Zygote, OneHotArrays, InteractiveUtils, Printf -import MLDatasets: MNIST -import MLUtils: DataLoader, splitobs +using Lux, ComponentArrays, SciMLSensitivity, LuxCUDA, Optimisers, OrdinaryDiffEq, Random, + Statistics, Zygote, OneHotArrays, InteractiveUtils, Printf +using MLDatasets: MNIST +using MLUtils: DataLoader, splitobs CUDA.allowscalar(false) @@ -39,7 +39,7 @@ end # First we will use the [`@compact`](@ref) macro to define the Neural ODE Layer. function NeuralODECompact( - model::Lux.AbstractExplicitLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...) + model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...) return @compact(; model, solver, tspan, kwargs...) do x, p dudt(u, p, t) = vec(model(reshape(u, size(x)), p)) ## Note the `p.model` here @@ -54,8 +54,7 @@ end # The NeuralODE is a ContainerLayer, which stores a `model`. The parameters and states of # the NeuralODE are same as those of the underlying model. -struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, T, K} <: - Lux.AbstractExplicitContainerLayer{(:model,)} +struct NeuralODE{M <: Lux.AbstractLuxLayer, So, T, K} <: Lux.AbstractLuxWrapperLayer{:model} model::M solver::So tspan::T @@ -63,7 +62,7 @@ struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, T, K} <: end function NeuralODE( - model::Lux.AbstractExplicitLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...) + model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...) 
return NeuralODE(model, solver, tspan, kwargs) end @@ -107,13 +106,12 @@ end # ## Define Utility Functions const logitcrossentropy = CrossEntropyLoss(; logits=Val(true)) -function accuracy(model, ps, st, dataloader; dev=gpu_device()) +function accuracy(model, ps, st, dataloader) total_correct, total = 0, 0 st = Lux.testmode(st) - cpu_dev = cpu_device() for (x, y) in dataloader target_class = onecold(y) - predicted_class = onecold(cpu_dev(first(model(dev(x), ps, st)))) + predicted_class = onecold(first(model(x, ps, st))) total_correct += sum(target_class .== predicted_class) total += length(target_class) end @@ -126,7 +124,7 @@ function train(model_function; cpu::Bool=false, kwargs...) model, ps, st = create_model(model_function; dev, kwargs...) ## Training - train_dataloader, test_dataloader = loadmnist(128, 0.9) + train_dataloader, test_dataloader = loadmnist(128, 0.9) |> dev tstate = Training.TrainState(model, ps, st, Adam(0.001f0)) @@ -135,15 +133,13 @@ function train(model_function; cpu::Bool=false, kwargs...) for epoch in 1:nepochs stime = time() for (x, y) in train_dataloader - x = dev(x) - y = dev(y) _, _, _, tstate = Training.single_train_step!( AutoZygote(), logitcrossentropy, (x, y), tstate) end ttime = time() - stime - tr_acc = accuracy(model, tstate.parameters, tstate.states, train_dataloader; dev) - te_acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader; dev) + tr_acc = accuracy(model, tstate.parameters, tstate.states, train_dataloader) + te_acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) @printf "[%d/%d] \t Time %.2fs \t Training Accuracy: %.5f%% \t Test \ Accuracy: %.5f%%\n" epoch nepochs ttime tr_acc te_acc end @@ -177,8 +173,8 @@ train(NeuralODE; sensealg=ReverseDiffAdjoint(), cpu=true) # Starting `v0.5.5`, Lux provides a [`StatefulLuxLayer`](@ref) which can be used # to avoid the [`Box`ing of `st`](https://github.com/JuliaLang/julia/issues/15276). Using # the `@compact` API avoids this problem entirely. -struct StatefulNeuralODE{M <: Lux.AbstractExplicitLayer, So, T, K} <: - Lux.AbstractExplicitContainerLayer{(:model,)} +struct StatefulNeuralODE{M <: Lux.AbstractLuxLayer, So, T, K} <: + Lux.AbstractLuxWrapperLayer{:model} model::M solver::So tspan::T @@ -186,12 +182,12 @@ struct StatefulNeuralODE{M <: Lux.AbstractExplicitLayer, So, T, K} <: end function StatefulNeuralODE( - model::Lux.AbstractExplicitLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...) + model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...) 
return StatefulNeuralODE(model, solver, tspan, kwargs) end function (n::StatefulNeuralODE)(x, ps, st) - st_model = StatefulLuxLayer(n.model, ps, st) + st_model = StatefulLuxLayer{true}(n.model, ps, st) dudt(u, p, t) = st_model(u, p) prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps) return solve(prob, n.solver; n.kwargs...), st_model.st diff --git a/examples/PolynomialFitting/Project.toml b/examples/PolynomialFitting/Project.toml index 15eb039d58..a5c1183548 100644 --- a/examples/PolynomialFitting/Project.toml +++ b/examples/PolynomialFitting/Project.toml @@ -4,7 +4,6 @@ CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -13,11 +12,10 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ADTypes = "0.2, 1" +ADTypes = "1" CairoMakie = "0.12" Literate = "2" -Lux = "0.5" -AMDGPU = "0.9.6, 1" +Lux = "1" LuxCUDA = "0.3" Optimisers = "0.3" Statistics = "1" diff --git a/examples/PolynomialFitting/main.jl b/examples/PolynomialFitting/main.jl index 2db9b65916..50f32b447f 100644 --- a/examples/PolynomialFitting/main.jl +++ b/examples/PolynomialFitting/main.jl @@ -5,7 +5,7 @@ # ## Package Imports -using Lux, ADTypes, AMDGPU, LuxCUDA, Optimisers, Printf, Random, Statistics, Zygote +using Lux, ADTypes, LuxCUDA, Optimisers, Printf, Random, Statistics, Zygote using CairoMakie # ## Dataset @@ -55,12 +55,17 @@ opt = Adam(0.03f0) # functions provided by Lux. const loss_function = MSELoss() +const dev_cpu = cpu_device() +const dev_gpu = gpu_device() + +ps, st = Lux.setup(rng, model) |> dev_gpu + # ## Training # First we will create a [`Training.TrainState`](@ref) which is essentially a # convenience wrapper over parameters, states and optimizer states. -tstate = Training.TrainState(rng, model, opt) +tstate = Training.TrainState(model, ps, st, opt) # Now we will use Zygote for our AD requirements. 
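Taken together, the example updates in this patch converge on one training pattern: call `Lux.setup`, move the parameters and states to the device, and build the `Training.TrainState` from `(model, ps, st, optimizer)` instead of the old `(rng, model, optimizer; transform_variables)` form. A rough, self-contained sketch (toy model and data, not from the patch):

```julia
using Lux, Optimisers, Zygote, ADTypes, Random

rng = Xoshiro(0)
model = Dense(2 => 1)
dev = gpu_device()                      # typically falls back to CPU if no GPU is functional

ps, st = Lux.setup(rng, model) |> dev   # move parameters/states, not the model
tstate = Training.TrainState(model, ps, st, Adam(0.03f0))

x, y = dev(randn(rng, Float32, 2, 32)), dev(randn(rng, Float32, 1, 32))
_, loss, _, tstate = Training.single_train_step!(AutoZygote(), MSELoss(), (x, y), tstate)
```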
@@ -79,9 +84,6 @@ function main(tstate::Training.TrainState, vjp, data, epochs) return tstate end -dev_cpu = cpu_device() -dev_gpu = gpu_device() - tstate = main(tstate, vjp_rule, (x, y), 250) y_pred = dev_cpu(Lux.apply(tstate.model, dev_gpu(x), tstate.parameters, tstate.states)[1]) nothing #hide diff --git a/examples/SimpleChains/Project.toml b/examples/SimpleChains/Project.toml index 1ff7ce3a2a..009fd8dcad 100644 --- a/examples/SimpleChains/Project.toml +++ b/examples/SimpleChains/Project.toml @@ -13,9 +13,9 @@ SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ADTypes = "0.2, 1" +ADTypes = "1" Literate = "2" -Lux = "0.5.20" +Lux = "1" MLDatasets = "0.7.14" MLUtils = "0.4" OneHotArrays = "0.2.5" diff --git a/examples/SimpleChains/main.jl b/examples/SimpleChains/main.jl index 480865cd2f..6726800594 100644 --- a/examples/SimpleChains/main.jl +++ b/examples/SimpleChains/main.jl @@ -8,8 +8,8 @@ # ## Package Imports using Lux, ADTypes, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf -import MLDatasets: MNIST -import SimpleChains: static +using MLDatasets: MNIST +using SimpleChains: SimpleChains # ## Loading MNIST function loadmnist(batchsize, train_split) @@ -19,7 +19,7 @@ function loadmnist(batchsize, train_split) imgs = dataset.features[:, :, 1:N] labels_raw = dataset.targets[1:N] - ## Process images into (H,W,C,BS) batches + ## Process images into (H, W, C, BS) batches x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) y_data = onehotbatch(labels_raw, 0:9) (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split) @@ -40,7 +40,7 @@ lux_model = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)), # We now need to convert the lux_model to SimpleChains.jl. We need to do this by defining # the [`ToSimpleChainsAdaptor`](@ref) and providing the input dimensions. -adaptor = ToSimpleChainsAdaptor((static(28), static(28), static(1))) +adaptor = ToSimpleChainsAdaptor((28, 28, 1)) simple_chains_model = adaptor(lux_model) # ## Helper Functions @@ -61,9 +61,9 @@ end # ## Define the Training Loop function train(model; rng=Xoshiro(0), kwargs...) train_dataloader, test_dataloader = loadmnist(128, 0.9) + ps, st = Lux.setup(rng, model) - train_state = Training.TrainState( - rng, model, Adam(3.0f-4); transform_variables=identity) + train_state = Training.TrainState(model, ps, st, Adam(3.0f-4)) ### Warmup the model x_proto = randn(rng, Float32, 28, 28, 1, 1) @@ -72,10 +72,11 @@ function train(model; rng=Xoshiro(0), kwargs...) ### Lets train the model nepochs = 10 + tr_acc, te_acc = 0.0, 0.0 for epoch in 1:nepochs stime = time() for (x, y) in train_dataloader - (gs, _, _, train_state) = Training.single_train_step!( + gs, _, _, train_state = Training.single_train_step!( AutoZygote(), loss, (x, y), train_state) end ttime = time() - stime @@ -88,16 +89,20 @@ function train(model; rng=Xoshiro(0), kwargs...) @printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \ %.2f%%\n" epoch nepochs ttime tr_acc te_acc end + + return tr_acc, te_acc end # ## Finally Training the Model # First we will train the Lux model -train(lux_model) +tr_acc, te_acc = train(lux_model) +@assert tr_acc > 0.75 && te_acc > 0.75 #hide nothing #hide # Now we will train the SimpleChains model train(simple_chains_model) +@assert tr_acc > 0.75 && te_acc > 0.75 #hide nothing #hide # On my local machine we see a 3-4x speedup when using SimpleChains.jl. 
The conditions of diff --git a/examples/SimpleRNN/Project.toml b/examples/SimpleRNN/Project.toml index 0932fbdcc8..9917b042a1 100644 --- a/examples/SimpleRNN/Project.toml +++ b/examples/SimpleRNN/Project.toml @@ -4,7 +4,6 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" @@ -14,13 +13,12 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ADTypes = "0.2.6, 1" +ADTypes = "1" JLD2 = "0.4" Literate = "2" -Lux = "0.5" -AMDGPU = "0.9.6, 1" -LuxCUDA = "0.2, 0.3" -MLUtils = "0.2, 0.3, 0.4" -Optimisers = "0.2, 0.3" +Lux = "1" +LuxCUDA = "0.3" +MLUtils = "0.4" +Optimisers = "0.3" Statistics = "1" Zygote = "0.6" diff --git a/examples/SimpleRNN/main.jl b/examples/SimpleRNN/main.jl index ef120c5a08..e0ba547138 100644 --- a/examples/SimpleRNN/main.jl +++ b/examples/SimpleRNN/main.jl @@ -9,8 +9,7 @@ # ## Package Imports -using ADTypes, Lux, AMDGPU, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, - Statistics +using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics # ## Dataset @@ -42,7 +41,7 @@ end # ## Creating a Classifier -# We will be extending the `Lux.AbstractExplicitContainerLayer` type for our custom model +# We will be extending the `Lux.AbstractLuxContainerLayer` type for our custom model # since it will contain a lstm block and a classifier head. # We pass the fieldnames `lstm_cell` and `classifier` to the type to ensure that the @@ -52,8 +51,7 @@ end # To understand more about container layers, please look at # [Container Layer](@ref Container-Layer). 
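As a side note to the container-layer explanation above, the sketch below (with a hypothetical `LSTMHead` type) shows what the declared fieldnames buy you: `Lux.setup` returns parameter and state `NamedTuple`s keyed by exactly those names.

```julia
using Lux, Random

# Hypothetical container with two named sub-layers; the fieldname tuple in the
# supertype determines the keys of the parameter/state NamedTuples.
struct LSTMHead{L, C} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

m = LSTMHead(LSTMCell(2 => 8), Dense(8 => 1, sigmoid))
ps, st = Lux.setup(Xoshiro(0), m)
keys(ps)   # (:lstm_cell, :classifier)
```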
-struct SpiralClassifier{L, C} <: - Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)} +struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :classifier)} lstm_cell::L classifier::C end @@ -130,22 +128,21 @@ accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred) # ## Training the Model function main(model_type) + dev = gpu_device() + ## Get the dataloaders - (train_loader, val_loader) = get_dataloaders() + train_loader, val_loader = get_dataloaders() .|> dev ## Create the model model = model_type(2, 8, 1) rng = Xoshiro(0) + ps, st = Lux.setup(rng, model) |> dev - dev = gpu_device() - train_state = Training.TrainState(rng, model, Adam(0.01f0); transform_variables=dev) + train_state = Training.TrainState(model, ps, st, Adam(0.01f0)) for epoch in 1:25 ## Train the model for (x, y) in train_loader - x = x |> dev - y = y |> dev - (_, loss, _, train_state) = Training.single_train_step!( AutoZygote(), lossfn, (x, y), train_state) @@ -155,8 +152,6 @@ function main(model_type) ## Validate the model st_ = Lux.testmode(train_state.states) for (x, y) in val_loader - x = x |> dev - y = y |> dev ŷ, st_ = model(x, train_state.parameters, st_) loss = lossfn(ŷ, y) acc = accuracy(ŷ, y) diff --git a/examples/SymbolicOptimalControl/Project.toml b/examples/SymbolicOptimalControl/Project.toml deleted file mode 100644 index fd41919312..0000000000 --- a/examples/SymbolicOptimalControl/Project.toml +++ /dev/null @@ -1,9 +0,0 @@ -[deps] -InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" -Lux = "b2108857-7c20-44ae-9111-449ecde12c47" - -[compat] -InteractiveUtils = "<0.0.1, 1" -Literate = "2" -Lux = "0.5" diff --git a/examples/SymbolicOptimalControl/main.jl b/examples/SymbolicOptimalControl/main.jl deleted file mode 100644 index b23498f844..0000000000 --- a/examples/SymbolicOptimalControl/main.jl +++ /dev/null @@ -1,5 +0,0 @@ -# # Solving Optimal Control Problems with Symbolic Universal Differential Equations - -# This tutorial has been been moved to Boltz.jl documentation. Refer to the the -# [Symbolic Optimal Control](https://luxdl.github.io/Boltz.jl/stable/tutorials/2_SymbolicOptimalControl) -# tutorial for more details. 
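The same examples also switch how data reaches the accelerator: instead of `x |> dev` / `y |> dev` inside the training loop, the `DataLoader` itself (or a tuple of loaders, via `.|> dev`) is moved once up front. A minimal sketch of that pattern, assuming the MLDataDevices data-loader support these examples rely on:

```julia
using Lux, MLUtils

dev = gpu_device()
x, y = randn(Float32, 2, 256), randn(Float32, 1, 256)

# Move the loader once; every batch it yields then already lives on the device.
loader = DataLoader((x, y); batchsize=32, shuffle=true) |> dev

for (xb, yb) in loader
    # xb and yb are device arrays here; no per-batch `|> dev` needed
end
```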
diff --git a/ext/LuxDynamicExpressionsExt.jl b/ext/LuxDynamicExpressionsExt.jl deleted file mode 100644 index 2b28885c4e..0000000000 --- a/ext/LuxDynamicExpressionsExt.jl +++ /dev/null @@ -1,154 +0,0 @@ -module LuxDynamicExpressionsExt - -using ChainRulesCore: NoTangent -using DynamicExpressions: DynamicExpressions, Node, OperatorEnum, eval_grad_tree_array, - eval_tree_array -using FastClosures: @closure -using ForwardDiff: ForwardDiff - -using Lux: Lux, NAME_TYPE, Chain, Parallel, WrappedFunction, DynamicExpressionsLayer -using MLDataDevices: CPUDevice - -@static if pkgversion(DynamicExpressions) ≥ v"0.19" - using DynamicExpressions: EvalOptions - - const EvalOptionsTypes = Union{Missing, EvalOptions, NamedTuple} -else - const EvalOptionsTypes = Union{Missing, NamedTuple} -end - -Lux.is_extension_loaded(::Val{:DynamicExpressions}) = true - -function Lux.DynamicExpressionsLayer(operator_enum::OperatorEnum, expressions::Node...; - name::NAME_TYPE=nothing, eval_options::EvalOptionsTypes=missing, - turbo::Union{Bool, Val, Missing}=missing, - bumper::Union{Bool, Val, Missing}=missing) - eval_options = construct_eval_options( - eval_options, construct_eval_options(turbo, bumper)) - - length(expressions) == 1 && return Lux.DynamicExpressionsLayer( - operator_enum, first(expressions), name, eval_options) - name_fn = name === nothing ? Returns(nothing) : @closure(i->"$(name)_$(i)") - #! format: off - return Chain( - Parallel(nothing, - ntuple(i -> DynamicExpressionsLayer(operator_enum, expressions[i], - name_fn(i), eval_options), length(expressions))...), - WrappedFunction{:direct_call}(Lux.Utils.stack1); - name="DynamicExpressionsLayer") - #! format: on -end - -function Lux.DynamicExpressionsLayer( - operator_enum::OperatorEnum, expressions::AbstractVector{<:Node}; kwargs...) - return Lux.DynamicExpressionsLayer(operator_enum, expressions...; kwargs...) -end - -construct_eval_options(::Missing, ::Missing) = (; turbo=Val(false), bumper=Val(false)) -function construct_eval_options(turbo::Union{Bool, Val}, ::Missing) - return construct_eval_options(turbo, Val(false)) -end -function construct_eval_options(::Missing, bumper::Union{Bool, Val}) - return construct_eval_options(Val(false), bumper) -end -function construct_eval_options(turbo::Union{Bool, Val}, bumper::Union{Bool, Val}) - Base.depwarn("`bumper` and `turbo` are deprecated. Use `eval_options` instead.", - :DynamicExpressionsLayer) - return (; turbo, bumper) -end - -construct_eval_options(::Missing, eval_options::EvalOptionsTypes) = eval_options -construct_eval_options(eval_options::EvalOptionsTypes, ::Missing) = eval_options -function construct_eval_options(::EvalOptionsTypes, ::EvalOptionsTypes) - throw(ArgumentError("`eval_options`, `turbo` and `bumper` are mutually exclusive. 
\ - Don't specify `eval_options` if you are using `turbo` or \ - `bumper`.")) -end - -function Lux.apply_dynamic_expression_internal( - de::Lux.DynamicExpressionsLayer, expr, operator_enum, x, ps) - Lux.update_de_expression_constants!(expr, ps) - @static if pkgversion(DynamicExpressions) ≥ v"0.19" - eval_options = EvalOptions(; de.eval_options.turbo, de.eval_options.bumper) - return first(eval_tree_array(expr, x, operator_enum; eval_options)) - else - return first(eval_tree_array( - expr, x, operator_enum; de.eval_options.turbo, de.eval_options.bumper)) - end -end - -function Lux.∇apply_dynamic_expression( - de::Lux.DynamicExpressionsLayer, expr, operator_enum, x, ps) - Lux.update_de_expression_constants!(expr, ps) - _, Jₓ, _ = eval_grad_tree_array( - expr, x, operator_enum; variable=Val(true), de.eval_options.turbo) - y, Jₚ, _ = eval_grad_tree_array( - expr, x, operator_enum; variable=Val(false), de.eval_options.turbo) - ∇apply_dynamic_expression_internal = @closure Δ -> begin - ∂x = Jₓ .* reshape(Δ, 1, :) - ∂ps = Jₚ * Δ - return NoTangent(), NoTangent(), NoTangent(), NoTangent(), ∂x, ∂ps, NoTangent() - end - return y, ∇apply_dynamic_expression_internal -end - -# Forward Diff rules -function Lux.apply_dynamic_expression(de::DynamicExpressionsLayer, expr, operator_enum, - x::AbstractMatrix{<:ForwardDiff.Dual{Tag, T, N}}, - ps, ::CPUDevice) where {T, N, Tag} - value_fn(x) = ForwardDiff.value(Tag, x) - partials_fn(x, i) = ForwardDiff.partials(Tag, x, i) - - Lux.update_de_expression_constants!(expr, ps) - y, Jₓ, _ = eval_grad_tree_array( - expr, value_fn.(x), operator_enum; variable=Val(true), de.eval_options.turbo) - partials = ntuple( - @closure(i->dropdims(sum(partials_fn.(x, i) .* Jₓ; dims=1); dims=1)), N) - - fT = promote_type(eltype(y), T, eltype(Jₓ)) - partials_y = ForwardDiff.Partials{N, fT}.(tuple.(partials...)) - return ForwardDiff.Dual{Tag, fT, N}.(y, partials_y) -end - -function Lux.apply_dynamic_expression(de::DynamicExpressionsLayer, expr, operator_enum, x, - ps::AbstractVector{<:ForwardDiff.Dual{Tag, T, N}}, ::CPUDevice) where {T, N, Tag} - value_fn(x) = ForwardDiff.value(Tag, x) - partials_fn(x, i) = ForwardDiff.partials(Tag, x, i) - - Lux.update_de_expression_constants!(expr, value_fn.(ps)) - y, Jₚ, _ = eval_grad_tree_array( - expr, x, operator_enum; variable=Val(false), de.eval_options.turbo) - partials = ntuple( - @closure(i->dropdims(sum(partials_fn.(ps, i) .* Jₚ; dims=1); dims=1)), N) - - fT = promote_type(eltype(y), T, eltype(Jₚ)) - partials_y = ForwardDiff.Partials{N, fT}.(tuple.(partials...)) - return ForwardDiff.Dual{Tag, fT, N}.(y, partials_y) -end - -function Lux.apply_dynamic_expression(de::DynamicExpressionsLayer, expr, operator_enum, - x::AbstractMatrix{<:ForwardDiff.Dual{Tag, T1, N}}, - ps::AbstractVector{<:ForwardDiff.Dual{Tag, T2, N}}, - ::CPUDevice) where {T1, T2, N, Tag} - value_fn(x) = ForwardDiff.value(Tag, x) - partials_fn(x, i) = ForwardDiff.partials(Tag, x, i) - - ps_value = value_fn.(ps) - x_value = value_fn.(x) - - Lux.update_de_expression_constants!(expr, ps_value) - _, Jₓ, _ = eval_grad_tree_array( - expr, x_value, operator_enum; variable=Val(true), de.eval_options.turbo) - y, Jₚ, _ = eval_grad_tree_array( - expr, x_value, operator_enum; variable=Val(false), de.eval_options.turbo) - partials = ntuple( - @closure(i->dropdims(sum(partials_fn.(x, i) .* Jₓ; dims=1); dims=1) .+ - dropdims(sum(partials_fn.(ps, i) .* Jₚ; dims=1); dims=1)), - N) - - fT = promote_type(eltype(y), T1, T2, eltype(Jₓ), eltype(Jₚ)) - partials_y = ForwardDiff.Partials{N, 
fT}.(tuple.(partials...)) - return ForwardDiff.Dual{Tag, fT, N}.(y, partials_y) -end - -end diff --git a/ext/LuxFluxExt.jl b/ext/LuxFluxExt.jl index 54c9675854..d0f89b2b0d 100644 --- a/ext/LuxFluxExt.jl +++ b/ext/LuxFluxExt.jl @@ -17,22 +17,19 @@ function Lux.convert_flux_model(l::T; preserve_ps_st::Bool=false, kwargs...) whe return Lux.FluxLayer(l) end -Lux.convert_flux_model(l::Function; kwargs...) = Lux.WrappedFunction{:direct_call}(l) +Lux.convert_flux_model(l::Function; kwargs...) = Lux.WrappedFunction(l) function Lux.convert_flux_model(l::Flux.Chain; kwargs...) fn = x -> Lux.convert_flux_model(x; kwargs...) layers = map(fn, l.layers) - if layers isa NamedTuple - return Lux.Chain(layers; disable_optimizations=true) - else - return Lux.Chain(layers...; disable_optimizations=true) - end + layers isa NamedTuple && return Lux.Chain(layers) + return Lux.Chain(layers...) end function Lux.convert_flux_model(l::Flux.Dense; preserve_ps_st::Bool=false, kwargs...) out_dims, in_dims = size(l.weight) if preserve_ps_st - bias = l.bias isa Bool ? nothing : reshape(copy(l.bias), out_dims, 1) + bias = l.bias isa Bool ? nothing : copy(l.bias) return Lux.Dense(in_dims => out_dims, l.σ; init_weight=Returns(copy(l.weight)), init_bias=Returns(bias), use_bias=!(l.bias isa Bool)) else @@ -100,8 +97,7 @@ function Lux.convert_flux_model(l::Flux.Conv; preserve_ps_st::Bool=false, kwargs groups = l.groups pad = l.pad isa Flux.SamePad ? SamePad() : l.pad if preserve_ps_st - _bias = l.bias isa Bool ? nothing : - reshape(copy(l.bias), ntuple(_ -> 1, length(k))..., out_chs, 1) + _bias = l.bias isa Bool ? nothing : vec(copy(l.bias)) return Lux.Conv(k, in_chs * groups => out_chs, l.σ; l.stride, pad, l.dilation, groups, init_weight=Returns(Lux.maybe_flip_conv_weight(l.weight)), init_bias=Returns(_bias), use_bias=!(l.bias isa Bool)) @@ -117,16 +113,16 @@ function Lux.convert_flux_model( out_chs, in_chs = size(l.weight)[(end - 1):end] groups = l.groups pad = l.pad isa Flux.SamePad ? SamePad() : l.pad + outpad = hasfield(typeof(l), :outpad) ? l.outpad : 0 if preserve_ps_st - _bias = l.bias isa Bool ? nothing : - reshape(copy(l.bias), ntuple(_ -> 1, length(k))..., out_chs, 1) - return Lux.ConvTranspose(k, in_chs * groups => out_chs, l.σ; l.stride, - pad, l.dilation, groups, use_bias=!(l.bias isa Bool), + _bias = l.bias isa Bool ? nothing : vec(copy(l.bias)) + return Lux.ConvTranspose(k, in_chs * groups => out_chs, l.σ; l.stride, pad, + outpad, l.dilation, groups, use_bias=!(l.bias isa Bool), init_weight=Returns(Lux.maybe_flip_conv_weight(l.weight)), init_bias=Returns(_bias)) else return Lux.ConvTranspose(k, in_chs * groups => out_chs, l.σ; l.stride, pad, - l.dilation, groups, use_bias=!(l.bias isa Bool)) + outpad, l.dilation, groups, use_bias=!(l.bias isa Bool)) end end @@ -135,14 +131,13 @@ function Lux.convert_flux_model(l::Flux.CrossCor; preserve_ps_st::Bool=false, kw in_chs, out_chs = size(l.weight)[(end - 1):end] pad = l.pad isa Flux.SamePad ? SamePad() : l.pad if preserve_ps_st - _bias = l.bias isa Bool ? nothing : - reshape(copy(l.bias), ntuple(_ -> 1, length(k))..., out_chs, 1) - return Lux.CrossCor(k, in_chs => out_chs, l.σ; l.stride, pad, - l.dilation, init_weight=Returns(copy(l.weight)), - init_bias=Returns(_bias), use_bias=!(l.bias isa Bool)) + _bias = l.bias isa Bool ? 
nothing : vec(copy(l.bias)) + return Lux.Conv(k, in_chs => out_chs, l.σ; l.stride, pad, l.dilation, + init_weight=Returns(copy(l.weight)), init_bias=Returns(_bias), + use_bias=!(l.bias isa Bool), cross_correlation=true) else - return Lux.CrossCor(k, in_chs => out_chs, l.σ; l.stride, pad, - l.dilation, use_bias=!(l.bias isa Bool)) + return Lux.Conv(k, in_chs => out_chs, l.σ; l.stride, pad, l.dilation, + use_bias=!(l.bias isa Bool), cross_correlation=true) end end @@ -179,59 +174,7 @@ Lux.convert_flux_model(::typeof(Flux.flatten); kwargs...) = Lux.FlattenLayer() Lux.convert_flux_model(l::Flux.PixelShuffle; kwargs...) = Lux.PixelShuffle(l.r) function Lux.convert_flux_model(l::Flux.Upsample{mode}; kwargs...) where {mode} - return Lux.Upsample(mode; l.scale, l.size) -end - -function Lux.convert_flux_model( - l::Flux.RNNCell; preserve_ps_st::Bool=false, force_preserve::Bool=false) - out_dims, in_dims = size(l.Wi) - if preserve_ps_st - if force_preserve - throw(FluxModelConversionException("Recurrent Cell: $(typeof(l)) for Flux uses a `reset!` mechanism which hasn't been extensively tested with `FluxLayer`. Rewrite the model manually to use `RNNCell`.")) - end - @warn "Preserving Parameters: `Wh` & `Wi` for `Flux.RNNCell` is ambiguous in Lux and hence not supported. Ignoring these parameters." maxlog=1 - return Lux.RNNCell(in_dims => out_dims, l.σ; init_bias=Returns(copy(l.b)), - init_state=Returns(copy(l.state0))) - else - return Lux.RNNCell(in_dims => out_dims, l.σ) - end -end - -function Lux.convert_flux_model( - l::Flux.LSTMCell; preserve_ps_st::Bool=false, force_preserve::Bool=false) - _out_dims, in_dims = size(l.Wi) - out_dims = _out_dims ÷ 4 - if preserve_ps_st - if force_preserve - throw(FluxModelConversionException("Recurrent Cell: $(typeof(l)) for Flux uses a `reset!` mechanism which hasn't been extensively tested with `FluxLayer`. Rewrite the model manually to use `LSTMCell`.")) - end - @warn "Preserving Parameters: `Wh` & `Wi` for `Flux.LSTMCell` is ambiguous in Lux \ - and hence not supported. Ignoring these parameters." maxlog=1 - bs = LuxOps.multigate(l.b, Val(4)) - _s, _m = copy.(l.state0) - return Lux.LSTMCell(in_dims => out_dims; init_bias=Returns.(bs), - init_state=Returns(_s), init_memory=Returns(_m)) - else - return Lux.LSTMCell(in_dims => out_dims) - end -end - -function Lux.convert_flux_model( - l::Flux.GRUCell; preserve_ps_st::Bool=false, force_preserve::Bool=false) - _out_dims, in_dims = size(l.Wi) - out_dims = _out_dims ÷ 3 - if preserve_ps_st - if force_preserve - throw(FluxModelConversionException("Recurrent Cell: $(typeof(l)) for Flux uses a `reset!` mechanism which hasn't been extensively tested with `FluxLayer`. Rewrite the model manually to use `GRUCell`.")) - end - @warn "Preserving Parameters: `Wh` & `Wi` for `Flux.GRUCell` is ambiguous in Lux \ - and hence not supported. Ignoring these parameters." maxlog=1 - bs = LuxOps.multigate(l.b, Val(3)) - return Lux.GRUCell( - in_dims => out_dims; init_bias=Returns.(bs), init_state=Returns(copy(l.state0))) - else - return Lux.GRUCell(in_dims => out_dims) - end + return Lux.Upsample(mode; l.scale, l.size, align_corners=false) end function Lux.convert_flux_model( @@ -274,4 +217,15 @@ function Lux.convert_flux_model(l::T; kwargs...) 
where {T <: _INVALID_TRANSFORMA throw(FluxModelConversionException("Transformation of type $(T) is not supported.")) end +for cell in (:RNNCell, :LSTMCell, :GRUCell) + msg = "Recurrent Cell: $(cell) for Flux has semantic differences with Lux, \ + mostly in terms of how the bias term is handled. Lux aligns with the PyTorch \ + definition of these models and hence converting `Flux.$(cell)` to `Lux.$(cell)` \ + is not possible. Rewrite the model manually." + @eval function Lux.convert_flux_model( + ::Flux.$(cell); preserve_ps_st::Bool=false, force_preserve::Bool=false) + throw(FluxModelConversionException($msg)) + end +end + end diff --git a/ext/LuxReverseDiffExt/rules.jl b/ext/LuxReverseDiffExt/rules.jl index 08dd5fffb4..247bbd1200 100644 --- a/ext/LuxReverseDiffExt/rules.jl +++ b/ext/LuxReverseDiffExt/rules.jl @@ -4,15 +4,6 @@ @grad_from_chainrules Lux.apply_simple_chain( layer, x::TrackedArray, ps::TrackedArray, ::CPUDevice) -# DynamicExpressions.jl -@grad_from_chainrules Lux.apply_dynamic_expression( - de::Lux.DynamicExpressionsLayer, expr, operator_enum, x::TrackedArray, ps, ::CPUDevice) -@grad_from_chainrules Lux.apply_dynamic_expression( - de::Lux.DynamicExpressionsLayer, expr, operator_enum, x, ps::TrackedArray, ::CPUDevice) -@grad_from_chainrules Lux.apply_dynamic_expression( - de::Lux.DynamicExpressionsLayer, expr, operator_enum, - x::TrackedArray, ps::TrackedArray, ::CPUDevice) - # Nested AD @grad_from_chainrules Lux.AutoDiffInternalImpl.batched_jacobian_internal( f, backend::AbstractADType, x::TrackedArray) diff --git a/ext/LuxSimpleChainsExt.jl b/ext/LuxSimpleChainsExt.jl index c7a607b250..a311559ac7 100644 --- a/ext/LuxSimpleChainsExt.jl +++ b/ext/LuxSimpleChainsExt.jl @@ -17,8 +17,7 @@ end function Lux.fix_simplechain_input_dims(layers, input_dims) @warn "The model provided is not a `Chain`. Trying to wrap it into a `Chain` but this \ - might fail. Please consider using `Chain` directly (potentially with \ - `disable_optimizations = true`)." + might fail. Please consider using `Chain` directly."
return fix_simplechain_input_dims([layers], input_dims) end @@ -62,8 +61,9 @@ function Lux.make_simplechain_network(layer::FlattenLayer) end function Lux.make_simplechain_network(layer::MaxPool) - if layer.stride == layer.k && (!(layer.pad isa SamePad) && all(==(0), layer.pad)) - return SimpleChains.MaxPool(layer.k) + if layer.layer.mode.stride == layer.layer.mode.kernel_size && + all(==(0), layer.layer.mode.pad) + return SimpleChains.MaxPool(layer.layer.mode.kernel_size) end throw(SimpleChainsModelConversionException("MaxPool with non-standard parameters not \ supported.")) diff --git a/ext/LuxTrackerExt/LuxTrackerExt.jl b/ext/LuxTrackerExt/LuxTrackerExt.jl index 8ef071ee51..34dd0d5270 100644 --- a/ext/LuxTrackerExt/LuxTrackerExt.jl +++ b/ext/LuxTrackerExt/LuxTrackerExt.jl @@ -7,7 +7,6 @@ using Tracker: Tracker, TrackedArray, TrackedReal, @grad_from_chainrules using Lux: Lux, Utils using Lux.Training: TrainingBackendCache, TrainState -using MLDataDevices: CPUDevice const CRC = ChainRulesCore diff --git a/ext/LuxTrackerExt/rules.jl b/ext/LuxTrackerExt/rules.jl index 70883d3a55..5a6b5468dd 100644 --- a/ext/LuxTrackerExt/rules.jl +++ b/ext/LuxTrackerExt/rules.jl @@ -1,40 +1,3 @@ -# SimpleChains.jl: DON'T REPLACE THESE WITH @grad_from_chainrules -for T1 in (:TrackedArray, :AbstractArray), T2 in (:TrackedArray, :AbstractArray) - T1 === :AbstractArray && T2 === :AbstractArray && continue - - @eval function Lux.apply_simple_chain(layer, x::$(T1), ps::$(T2), dev::CPUDevice) - return Tracker.track(Lux.apply_simple_chain, layer, x, ps, dev) - end -end - -Tracker.@grad function Lux.apply_simple_chain(layer, x, ps, ::CPUDevice) - Base.depwarn("`Tracker.jl` often produces incorrect gradients for `SimpleChains.jl` \ - models. In future versions of Lux.jl you will need to load `Zygote.jl` \ - to use `Tracker.jl` for your model.", - :apply_simple_chain) - @warn "`Tracker.jl` often produces incorrect gradients for `SimpleChains.jl` models. \ - As such please test your model with `FiniteDiff.jl` or `Zygote.jl` before using \ - `Tracker.jl` for your model." 
maxlog=1 - y, pb_f = CRC.rrule(layer, Tracker.data(x), Tracker.data(ps)) - ∇apply_simple_chain = let pb_f = pb_f - Δ -> begin - _, ∂x, ∂ps = pb_f(convert(Array, Tracker.data(Δ))) - return Tracker.nobacksies(:apply_simple_chain, (nothing, ∂x, ∂ps, nothing)) - end - end - # Tracker is not great at handling arbitrary types, so we convert to Array - return Array(y), ∇apply_simple_chain -end - -# DynamicExpressions.jl -for T1 in (:TrackedArray, :AbstractArray), T2 in (:TrackedArray, :AbstractArray) - T1 === :AbstractArray && T2 === :AbstractArray && continue - - @eval @grad_from_chainrules Lux.apply_dynamic_expression( - de::Lux.DynamicExpressionsLayer, expr, - operator_enum, x::$(T1), ps::$(T2), dev::CPUDevice) -end - # Nested AD @grad_from_chainrules Lux.AutoDiffInternalImpl.batched_jacobian_internal( f, backend::AbstractADType, x::TrackedArray) diff --git a/src/Lux.jl b/src/Lux.jl index 712996c457..37506abcaa 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -13,22 +13,22 @@ using Functors: Functors, fmap using GPUArraysCore: @allowscalar using LossFunctions: LossFunctions using Markdown: @doc_str +using NNlib: NNlib using Optimisers: Optimisers using Random: Random, AbstractRNG using Static: StaticBool, StaticInt, StaticSymbol, True, False, static, known, dynamic -using Reexport: @reexport +using Reexport: Reexport, @reexport using Statistics: mean using UnrolledUtilities: unrolled_map, unrolled_mapreduce -# TODO: In v1 we remove the LuxDeviceUtils dependency and replace it with MLDataDevices -@reexport using LuxCore, LuxLib, LuxDeviceUtils, WeightInitializers -using MLDataDevices: MLDataDevices, AMDGPUDevice, CPUDevice -using NNlib: NNlib - -import LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer, initialparameters, - initialstates, parameterlength, statelength, inputsize, outputsize, +import LuxCore: AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer, + initialparameters, initialstates, parameterlength, statelength, outputsize, update_state, trainmode, testmode, setup, apply, replicate +@reexport using LuxCore, LuxLib, MLDataDevices, WeightInitializers +using NNlib: NNlib, DenseConvDims, PoolDims, logsigmoid, logsoftmax, maxpool, meanpool, + pixel_shuffle, sigmoid_fast, tanh_fast + const CRC = ChainRulesCore const NAME_TYPE = Union{Nothing, String, Symbol} @@ -58,6 +58,7 @@ include("layers/basic.jl") include("layers/containers.jl") include("layers/normalize.jl") include("layers/conv.jl") +include("layers/pooling.jl") include("layers/dropout.jl") include("layers/recurrent.jl") include("layers/extension.jl") @@ -84,16 +85,12 @@ include("transform/simplechains.jl") include("distributed/backend.jl") include("distributed/public_api.jl") -# Deprecations -include("deprecated.jl") - # Layers -export cpu, gpu # deprecated - export Chain, Parallel, SkipConnection, PairwiseFusion, BranchLayer, Maxout, RepeatedLayer -export Bilinear, Dense, Embedding, Scale, PeriodicEmbedding -export Conv, ConvTranspose, CrossCor, MaxPool, MeanPool, GlobalMaxPool, GlobalMeanPool, - AdaptiveMaxPool, AdaptiveMeanPool, Upsample, PixelShuffle +export Bilinear, Dense, Embedding, Scale +export Conv, ConvTranspose, Upsample, PixelShuffle +export MaxPool, MeanPool, LPPool, GlobalMaxPool, GlobalMeanPool, GlobalLPPool, + AdaptiveMaxPool, AdaptiveMeanPool, AdaptiveLPPool export AlphaDropout, Dropout, VariationalHiddenDropout export BatchNorm, GroupNorm, InstanceNorm, LayerNorm export WeightNorm @@ -118,10 +115,8 @@ export GenericLossFunction export f16, f32, f64 export match_eltype -export transform 
export FromFluxAdaptor, FluxLayer export ToSimpleChainsAdaptor, SimpleChainsLayer -export DynamicExpressionsLayer export MPIBackend, NCCLBackend, DistributedUtils @@ -129,7 +124,46 @@ export LuxOps # Unexported functions that are part of the public API @compat public Experimental -@compat public xlogx, xlogy # TODO: deprecated in v1.0 @compat public set_dispatch_doctor_preferences! +# NNlib.jl reexports +## Functional API for common layers. Recommended to use the LuxLib versions +using NNlib: ConvDims, DenseConvDims, PoolDims, batched_adjoint, batched_mul, batched_mul!, + batched_transpose, batched_vec, bias_act!, conv, conv!, conv_bias_act, + conv_bias_act!, dot_product_attention, dot_product_attention_scores, + make_causal_mask, lpnormpool, lpnormpool!, maxpool, maxpool!, meanpool, + meanpool!, pixel_shuffle, imrotate, ∇conv_data, ∇conv_data!, ∇conv_filter, + ∇conv_filter!, ∇lpnormpool, ∇lpnormpool!, ∇maxpool, ∇maxpool!, ∇meanpool, + ∇meanpool!, ∇imrotate +export ConvDims, DenseConvDims, PoolDims, batched_adjoint, batched_mul, batched_mul!, + batched_transpose, batched_vec, bias_act!, conv, conv!, conv_bias_act, + conv_bias_act!, dot_product_attention, dot_product_attention_scores, + make_causal_mask, lpnormpool, lpnormpool!, maxpool, maxpool!, meanpool, meanpool!, + pixel_shuffle, imrotate, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter!, + ∇lpnormpool, ∇lpnormpool!, ∇maxpool, ∇maxpool!, ∇meanpool, ∇meanpool!, ∇imrotate + +## Padding +using NNlib: pad_circular, pad_constant, pad_reflect, pad_repeat, pad_symmetric, pad_zeros +export pad_circular, pad_constant, pad_reflect, pad_repeat, pad_symmetric, pad_zeros + +## Upsample +using NNlib: upsample_linear, upsample_bilinear, upsample_trilinear, upsample_nearest, + ∇upsample_linear, ∇upsample_bilinear, ∇upsample_trilinear, ∇upsample_nearest +export upsample_linear, upsample_bilinear, upsample_trilinear, upsample_nearest, + ∇upsample_linear, ∇upsample_bilinear, ∇upsample_trilinear, ∇upsample_nearest + +## Activation Functions +using NNlib: σ, celu, elu, gelu, glu, hardsigmoid, hardswish, hardtanh, hardσ, leakyrelu, + lisht, logcosh, logsigmoid, logσ, mish, relu, relu6, rrelu, selu, sigmoid, + sigmoid_fast, softplus, softshrink, softsign, swish, tanhshrink, tanh_fast, + thresholdrelu, trelu +export σ, celu, elu, gelu, glu, hardsigmoid, hardswish, hardtanh, hardσ, leakyrelu, lisht, + logcosh, logsigmoid, logσ, mish, relu, relu6, rrelu, selu, sigmoid, sigmoid_fast, + softplus, softshrink, softsign, swish, tanhshrink, tanh_fast, thresholdrelu, trelu + +using NNlib: softmax, softmax!, logsoftmax, logsoftmax!, logsumexp, ∇logsoftmax, + ∇logsoftmax!, ∇softmax, ∇softmax! +export softmax, softmax!, logsoftmax, logsoftmax!, logsumexp, ∇logsoftmax, ∇logsoftmax!, + ∇softmax, ∇softmax! 
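Since these symbols are now reexported straight from Lux, downstream code that only needs the functional API can drop its explicit `using NNlib`. A minimal, hypothetical usage sketch (not taken from this patch or the package tests):

```julia
using Lux  # softmax, logsumexp, relu, etc. are available via the reexports above

x = randn(Float32, 4, 3)
p = softmax(x; dims=1)        # column-wise probabilities
l = logsumexp(x; dims=1)      # numerically stable log-sum-exp
y = relu.(x)                  # activation functions are reexported as well
```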
+ end diff --git a/src/contrib/contrib.jl b/src/contrib/contrib.jl index fdcf6b0c2a..3e62563f32 100644 --- a/src/contrib/contrib.jl +++ b/src/contrib/contrib.jl @@ -1,8 +1,5 @@ module Experimental -using ..Lux: Lux, Training, Utils, Optional -using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, apply - using ADTypes: ADTypes using ArgCheck: @argcheck using ChainRulesCore: ChainRulesCore @@ -16,37 +13,20 @@ using Random: AbstractRNG, Random using Setfield: Setfield using Static: StaticSymbol, StaticBool, True, known, static, dynamic +using ..Lux: Lux, Optional +using ..Utils: Utils, BoolType, SymbolType +using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer, apply + const CRC = ChainRulesCore include("map.jl") include("freeze.jl") include("share_parameters.jl") include("debug.jl") -include("deprecated.jl") -@compat public layer_map, @layer_map +@compat public layer_map @compat public FrozenLayer, freeze, unfreeze @compat public share_parameters @compat public DebugLayer, @debug_mode end - -# Deprecations for v1.0 -macro layer_map(f, l, ps, st) - Base.depwarn( - "`Lux.@layer_map` has been deprecated in favor of `Lux.Experimental.@layer_map`", - Symbol("@layer_map")) - quote - Experimental.layer_map($(esc(f)), $(esc(l)), $(esc(ps)), $(esc(st)), $(string(l))) - end -end - -for f in (:layer_map, :share_parameters, :FrozenLayer, :freeze, :unfreeze) - msg = "`Lux.$(f)` has been deprecated in favor of `Lux.Experimental.$(f)`" - @eval begin - $(f)(args...; kwargs...) = begin - Base.depwarn($(msg), Symbol($(f))) - return Experimental.$(f)(args...; kwargs...) - end - end -end diff --git a/src/contrib/debug.jl b/src/contrib/debug.jl index 3187208824..7d7b388618 100644 --- a/src/contrib/debug.jl +++ b/src/contrib/debug.jl @@ -1,8 +1,8 @@ """ - DebugLayer(layer::AbstractExplicitLayer; + DebugLayer(layer::AbstractLuxLayer; nan_check::Union{Symbol, StaticSymbol, Val}=static(:both), error_check::Union{StaticBool, Bool, Val{true}, Val{false}}=True(), - location::Union{KeyPath, String}=KeyPath()) + location::KeyPath=KeyPath()) A wrapper over Lux layers that adds checks for NaNs and errors. This is useful for debugging. @@ -43,26 +43,16 @@ track where the error originates. See [`Lux.Experimental.@debug_mode`](@ref) to construct this layer. """ -@concrete struct DebugLayer <: AbstractExplicitContainerLayer{(:layer,)} +@concrete struct DebugLayer <: AbstractLuxWrapperLayer{:layer} nan_check <: StaticSymbol error_check <: StaticBool - layer <: AbstractExplicitLayer + layer <: AbstractLuxLayer location::KeyPath end -function DebugLayer(layer::AbstractExplicitLayer; - nan_check::Union{Symbol, StaticSymbol, Val}=static(:both), - error_check::Union{StaticBool, Bool, Val{true}, Val{false}}=True(), - location::Union{KeyPath, String}=KeyPath()) +function DebugLayer(layer::AbstractLuxLayer; nan_check::SymbolType=static(:both), + error_check::BoolType=True(), location::KeyPath=KeyPath()) @argcheck dynamic(nan_check) in (:both, :forward, :backward, :none) - - if location isa String - Base.depwarn( - "Using a String for location in DebugLayer is deprecated. Use \ - `Functors.KeyPath` instead.", :DebugLayer) - location = KeyPath(Symbol.(split(location, "."))...) - end - return DebugLayer(static(nan_check), static(error_check), layer, location) end diff --git a/src/contrib/deprecated.jl b/src/contrib/deprecated.jl deleted file mode 100644 index c81ff89ee4..0000000000 --- a/src/contrib/deprecated.jl +++ /dev/null @@ -1,16 +0,0 @@ -macro compact(exs...) 
- Base.depwarn( - "Lux.Experimental.@compact` has been promoted out of `Lux.Experimental` and is now \ - available in `Lux`. In other words this has been deprecated and will be removed \ - in v1. Use `Lux.@compact` instead.", - Symbol("@compact")) - return Lux.CompactMacroImpl.compact_macro_impl(exs...) -end - -Base.@deprecate StatefulLuxLayer(args...; kwargs...) Lux.StatefulLuxLayer( - args...; kwargs...) false - -for f in (:TrainState, :TrainingBackendCache, :single_train_step, :single_train_step!, - :apply_gradients, :apply_gradients!, :compute_gradients) - @eval Base.@deprecate $f(args...; kwargs...) Training.$f(args...; kwargs...) false -end diff --git a/src/contrib/freeze.jl b/src/contrib/freeze.jl index 3b32d7c094..cc0cfac74a 100644 --- a/src/contrib/freeze.jl +++ b/src/contrib/freeze.jl @@ -1,5 +1,5 @@ """ - FrozenLayer(l::AbstractExplicitLayer, which_params::Optional{Tuple}) + FrozenLayer(l::AbstractLuxLayer, which_params::Optional{Tuple}) Freeze the parameters with name `which_params` of the layer `l`. @@ -16,7 +16,7 @@ Freeze the parameters with name `which_params` of the layer `l`. ## Arguments - - `l`: Lux AbstractExplicitLayer. + - `l`: Lux AbstractLuxLayer. - `which_params`: Parameter Names to be Frozen. Can be set to `nothing`, in which case all parameters are frozen. @@ -46,10 +46,10 @@ FrozenLayer(Dense(2 => 2), (:weight,)) # 2 parameters, plus 4 non-trainable See also [`Lux.Experimental.freeze`](@ref), [`Lux.Experimental.unfreeze`](@ref). """ -struct FrozenLayer{which_params, L <: AbstractExplicitLayer} <: AbstractExplicitLayer +struct FrozenLayer{which_params, L <: AbstractLuxLayer} <: AbstractLuxLayer layer::L - function FrozenLayer(l::AbstractExplicitLayer, which_params::Optional{Tuple}=nothing) + function FrozenLayer(l::AbstractLuxLayer, which_params::Optional{Tuple}=nothing) if which_params !== nothing && length(which_params) == 0 @warn "Layer `FrozenLayer($l, (,))` is same as `l`, returning `l`." return l @@ -92,24 +92,24 @@ function Base.show(io::IO, f::FrozenLayer{which_params}) where {which_params} end """ - freeze(l::AbstractExplicitLayer, which_params::Optional{Tuple} = nothing) + freeze(l::AbstractLuxLayer, which_params::Optional{Tuple} = nothing) Constructs a version of `l` with `which_params` frozen. If `which_params` is nothing, then all parameters are frozen. """ -function freeze(l::AbstractExplicitLayer, which_params::Optional{Tuple}=nothing) +function freeze(l::AbstractLuxLayer, which_params::Optional{Tuple}=nothing) return FrozenLayer(l, which_params) end """ - freeze(l::AbstractExplicitLayer, ps, st::NamedTuple, + freeze(l::AbstractLuxLayer, ps, st::NamedTuple, which_params::Optional{Tuple} = nothing) Construct a [`Lux.Experimental.FrozenLayer`](@ref) for `l` with the current parameters and states. If `which_params` is nothing, then all parameters are frozen. """ function freeze( - l::AbstractExplicitLayer, ps, st::NamedTuple, which_params::Optional{Tuple}=nothing) + l::AbstractLuxLayer, ps, st::NamedTuple, which_params::Optional{Tuple}=nothing) fl = freeze(l, which_params) ps_frozen = [] ps_trainable = [] @@ -137,6 +137,6 @@ unfreeze(l::FrozenLayer) = l.layer Unwraps a [`Lux.Experimental.FrozenLayer`](@ref) `l` with the current parameters and states. 
""" -function unfreeze(fl::AbstractExplicitLayer, ps, st::NamedTuple) +function unfreeze(fl::AbstractLuxLayer, ps, st::NamedTuple) return unfreeze(fl), merge(ps, st.frozen_params), st.states end diff --git a/src/contrib/map.jl b/src/contrib/map.jl index 17f1612c7a..f5142f0db9 100644 --- a/src/contrib/map.jl +++ b/src/contrib/map.jl @@ -1,74 +1,25 @@ @doc doc""" - @layer_map func layer ps st - -See the documentation of [`Lux.Experimental.layer_map`](@ref) for more details. This macro -eliminates the need to the set the layer name, and uses the variable name as the starting -point. - -## Example - -```jldoctest -julia> using Lux, Random - -julia> c = Parallel( - +; chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3), dense_2=Dense(3 => 5)), - dense_3=Dense(5 => 1)); - -julia> rng = Random.default_rng(); - -julia> ps, st = Lux.setup(rng, c); - -julia> # Makes parameters of Dense Layers inside Chain zero - function zero_dense_params(l, ps, st, name) - if l isa Dense - println("zeroing params of $name") - ps = merge(ps, (; weight=zero.(ps.weight), bias=zero.(ps.bias))) - end - return l, ps, st - end; - -julia> _, ps_new, _ = Lux.Experimental.@layer_map zero_dense_params c ps st; -zeroing params of c.layers.chain.layers.dense_1 -zeroing params of c.layers.chain.layers.dense_2 -zeroing params of c.layers.dense_3 - -julia> all(iszero, (ps_new.chain.dense_1.weight, ps_new.chain.dense_1.bias, - ps_new.chain.dense_2.weight, ps_new.chain.dense_2.bias, - ps_new.dense_3.weight, ps_new.dense_3.bias)) -true -``` -""" -macro layer_map(f, l, ps, st) - return quote - layer_map($(esc(f)), $(esc(l)), $(esc(ps)), $(esc(st)), $(string(l))) - end -end - -@doc doc""" - layer_map(f::Function, l::AbstractExplicitLayer, ps, st::NamedTuple, - name::String="model") + layer_map(f, l::AbstractLuxLayer, ps, st::NamedTuple) Map the function `f` over the model `l`, with the parameters `ps` and states `st`. This is different from `Functors.fmap` since it zips the layers, parameters, and states and invokes the function on all of them together. -## Call Signature for `f` - - - Must take 4 inputs -- `AbstractExplicitLayer`, Corresponding Parameters, Corresponding - States, and the name of the layer. - - Must return a tuple of 3 elements -- `AbstractExplicitLayer`, new parameters and the new - states. +!!! tip "KeyPath provided to the function" -!!! tip "Use `Lux.Experimental.@layer_map` instead" + The `KeyPath` depths on the structure of the parameters and states. This is of + consequence exclusively for [`AbstractLuxWrapperLayer`](@ref) where the structure of the + layer doesn't match the structure of the parameters and states. In the example, provided + below, the `KeyPath` is `(:chain, :dense_1)` for the first layer (following the + structure in `ps`) while accessing the same layer in the chain is done with `( + :chain, :layers, :dense_1)`. - We recommend using the macro `Lux.Experimental.@layer_map` instead of this function. It - automatically sets the `name` of the layer to be the variable name. - -!!! danger "Deprecation Notice" +## Call Signature for `f` - Starting `v1`, instead of the name of the layer, we will provide the [KeyPath to the - layer](https://fluxml.ai/Functors.jl/stable/api/#KeyPath). The current version of - providing a String has been deprecated. + - Must take 4 inputs -- `AbstractLuxLayer`, Corresponding Parameters, Corresponding + States, and the `Functors.KeyPath` to the layer. + - Must return a tuple of 3 elements -- `AbstractLuxLayer`, new parameters and the new + states. 
# Extended Help @@ -77,7 +28,6 @@ the function on all of them together. ```jldoctest julia> using Lux, Random - julia> c = Parallel( +; chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3), dense_2=Dense(3 => 5)), dense_3=Dense(5 => 1)); @@ -96,9 +46,9 @@ julia> # Makes parameters of Dense Layers inside Chain zero end; julia> _, ps_new, _ = Lux.Experimental.layer_map(zero_dense_params, c, ps, st); -zeroing params of model.layers.chain.layers.dense_1 -zeroing params of model.layers.chain.layers.dense_2 -zeroing params of model.layers.dense_3 +zeroing params of KeyPath(:chain, :dense_1) +zeroing params of KeyPath(:chain, :dense_2) +zeroing params of KeyPath(:dense_3,) julia> all(iszero, (ps_new.chain.dense_1.weight, ps_new.chain.dense_1.bias, ps_new.chain.dense_2.weight, ps_new.chain.dense_2.bias, @@ -106,41 +56,50 @@ julia> all(iszero, (ps_new.chain.dense_1.weight, ps_new.chain.dense_1.bias, true ``` """ -function layer_map(f::F, l, ps, st, name::String="model") where {F <: Function} - # TODO: In v1 deprecate passing the string - f_wrapper = @closure (kp, layer, ps_, st_) -> f( - layer, ps_, st_, __keypath_to_string(name, kp)) - return fmap_with_path(f_wrapper, l, ps, st; walk=LayerWalkWithPath()) +function layer_map(f, l, ps, st) + return fmap_with_path(l, ps, st; walk=LayerWalkWithPath()) do kp, layer, ps_, st_ + return f(layer, ps_, st_, kp) + end end -__keypath_to_string(kp::KeyPath) = join(kp.keys, ".") -__keypath_to_string(str::String, kp::KeyPath) = "$(str).$(__keypath_to_string(kp))" - struct LayerWalkWithPath <: Functors.AbstractWalk end -function (::LayerWalkWithPath)(recurse, kp::KeyPath, layer, ps, st) - _layer_children, layer_re = functor(layer) +function (::LayerWalkWithPath)( + recurse::R, kp::KeyPath, layer::AbstractLuxWrapperLayer{field}, + ps, st) where {R, field} + layer_children, layer_re = functor(getfield(layer, field)) ps_children, ps_re = functor(ps) st_children, st_re = functor(st) - _children = keys(ps_children) - needs_correction = _children != keys(_layer_children) - _key = needs_correction ? only(keys(_layer_children)) : nothing - layer_children = needs_correction ? getfield(layer, _key) : _layer_children - @assert keys(layer_children) == keys(ps_children) == keys(st_children) + layer_children_new, ps_children_new, st_children_new = perform_layer_map( + recurse, kp, ps_children, st_children, layer_children) - kps = NamedTuple{_children}(map( - x -> needs_correction ? 
KeyPath(kp, _key, x) : KeyPath(kp, x), _children)) + inner_layer = layer_re(layer_children_new) + return (Setfield.set(layer, Setfield.PropertyLens{field}(), inner_layer), + ps_re(ps_children_new), st_re(st_children_new)) +end + +function (::LayerWalkWithPath)( + recurse::R, kp::KeyPath, layer::AbstractLuxLayer, ps, st) where {R} + layer_children, layer_re = functor(layer) + ps_children, ps_re = functor(ps) + st_children, st_re = functor(st) + + layer_children_new, ps_children_new, st_children_new = perform_layer_map( + recurse, kp, ps_children, st_children, layer_children) + + return layer_re(layer_children_new), ps_re(ps_children_new), st_re(st_children_new) +end + +function perform_layer_map(recurse, kp, ps_children, st_children, layer_children) + @argcheck keys(layer_children) == keys(ps_children) == keys(st_children) + + kps = NamedTuple{keys(ps_children)}(map(Base.Fix1(KeyPath, kp), keys(ps_children))) ys = map(recurse, kps, layer_children, ps_children, st_children) layer_children_new = map(Base.Fix2(getindex, 1), ys) ps_children_new = map(Base.Fix2(getindex, 2), ys) st_children_new = map(Base.Fix2(getindex, 3), ys) - layer_new = needs_correction ? layer_re(NamedTuple{(_key,)}((layer_children_new,))) : - layer_re(layer_children_new) - ps_new = ps_re(ps_children_new) - st_new = st_re(st_children_new) - - return layer_new, ps_new, st_new + return layer_children_new, ps_children_new, st_children_new end diff --git a/src/contrib/share_parameters.jl b/src/contrib/share_parameters.jl index 59c7e95bdf..9855089800 100644 --- a/src/contrib/share_parameters.jl +++ b/src/contrib/share_parameters.jl @@ -26,7 +26,7 @@ Updated Parameters having the same structure as `ps`. julia> model = Chain(; d1=Dense(2 => 4, tanh), d3=Chain(; l1=Dense(4 => 2), l2=Dense(2 => 4)), d2=Dense(4 => 2)) Chain( - d1 = Dense(2 => 4, tanh_fast), # 12 parameters + d1 = Dense(2 => 4, tanh), # 12 parameters d3 = Chain( l1 = Dense(4 => 2), # 10 parameters l2 = Dense(2 => 4), # 12 parameters diff --git a/src/custom_errors.jl b/src/custom_errors.jl index bf267a70ae..7ef16a186b 100644 --- a/src/custom_errors.jl +++ b/src/custom_errors.jl @@ -16,7 +16,7 @@ struct SimpleChainsModelConversionException <: AbstractLuxException msg::String end -function SimpleChainsModelConversionException(layer::AbstractExplicitLayer) +function SimpleChainsModelConversionException(layer::AbstractLuxLayer) return SimpleChainsModelConversionException("Conversion to SimpleChains not supported \ for $(typeof(layer))") end diff --git a/src/deprecated.jl b/src/deprecated.jl deleted file mode 100644 index 2073be8a6a..0000000000 --- a/src/deprecated.jl +++ /dev/null @@ -1,57 +0,0 @@ -# Deprecations for v1 -""" - cpu(x) - -Transfer `x` to CPU. - -!!! danger "Deprecation Notice" - - This function has been deprecated. Use [`cpu_device`](@ref) instead. -""" -function cpu end - -@deprecate cpu(x) (MLDataDevices.cpu_device())(x) - -""" - gpu(x) - -Transfer `x` to GPU determined by the backend set using [`Lux.gpu_backend!`](@ref). - -!!! danger "Deprecation Notice" - - This function has been deprecated. Use [`gpu_device`](@ref) instead. Using this function - inside performance critical code will cause massive slowdowns due to type inference - failure. -""" -function gpu end - -@deprecate gpu(x) (MLDataDevices.gpu_device())(x) - -""" - disable_stacktrace_truncation!(; disable::Bool=true) - -An easy way to update `TruncatedStacktraces.VERBOSE` without having to load it manually. - -Effectively does `TruncatedStacktraces.VERBOSE[] = disable` - -!!! 
danger "Deprecation Notice" - - This function is now deprecated and will be removed in v1. -""" -function disable_stacktrace_truncation!(; disable::Bool=true) - Base.depwarn( - "`disable_stacktrace_truncation!` is not needed anymore, as stacktraces are \ - truncated by default. This function is now deprecated and will be removed in v1.", - :disable_stacktrace_truncation) - return -end - -# Other deprecated functions -@deprecate xlogx(x::Number) LuxOps.xlogx(x) -@deprecate xlogy(x::Number, y::Number) LuxOps.xlogy(x, y) -@deprecate foldl_init(args...) LuxOps.foldl_init(args...) -@deprecate istraining(args...) LuxOps.istraining(args...) - -# While the ones below aren't public, we ended up using them at quite a few places -@deprecate _getproperty(args...) LuxOps.getproperty(args...) -@deprecate _eachslice(args...) LuxOps.eachslice(args...) diff --git a/src/extended_ops.jl b/src/extended_ops.jl index 2bcd12555d..ce9f662153 100644 --- a/src/extended_ops.jl +++ b/src/extended_ops.jl @@ -11,9 +11,10 @@ using ChainRulesCore: ChainRulesCore, NoTangent, ZeroTangent, @thunk, @non_diffe using Compat: @compat using EnzymeCore: EnzymeCore using FastClosures: @closure -using MLDataDevices: get_device_type, AbstractGPUDevice, AbstractDevice using Static: StaticBool, StaticSymbol, known +using MLDataDevices: get_device_type, AbstractGPUDevice, AbstractDevice + using ..Utils: Utils const CRC = ChainRulesCore @@ -235,7 +236,7 @@ const private_foldl_init = LuxOps.foldl_init # These are defined here to avoid a circular dependency among modules for (op, field) in (:bias => :use_bias, :affine => :affine, :track_stats => :track_stats, :train_state => :train_state) - @eval function $(Symbol(:has_, op))(l::AbstractExplicitLayer) + @eval function $(Symbol(:has_, op))(l::AbstractLuxLayer) res = known(safe_getproperty(l, Val($(Meta.quot(field))))) return ifelse(res === nothing, false, res) end diff --git a/src/helpers/compact.jl b/src/helpers/compact.jl index bbbb8544c2..f569cae811 100644 --- a/src/helpers/compact.jl +++ b/src/helpers/compact.jl @@ -293,7 +293,7 @@ macro non_trainable(x) end struct CompactLuxLayer{dispatch, F, N, L, V, SK} <: - AbstractExplicitContainerLayer{(:layers, :value_storage)} + AbstractLuxContainerLayer{(:layers, :value_storage)} d::StaticSymbol{dispatch} f::F name::N @@ -323,15 +323,14 @@ function CompactLuxLayer(dispatch::StaticSymbol, f::F, name::NAME_TYPE, setup_strings = NamedTuple() for (name, val) in pairs(kws) is_lux_layer = false - if val isa AbstractExplicitLayer + if val isa AbstractLuxLayer is_lux_layer = true push!(layers, name => val) elseif LuxCore.contains_lux_layer(val) # FIXME: This might lead to incorrect constructions? If the function is a # closure over the provided keyword arguments? val = CompactMacroImpl.try_make_lux_layer(val) - if LuxCore.check_fmap_condition( - !Base.Fix2(isa, AbstractExplicitLayer), nothing, val) + if LuxCore.check_fmap_condition(!Base.Fix2(isa, AbstractLuxLayer), nothing, val) throw(LuxCompactModelParsingException("A container `$(name) = $(val)` is \ found which combines Lux layers \ with non-Lux layers. This is not \ @@ -422,7 +421,7 @@ using MacroTools: MacroTools, @capture, combinedef, splitdef using Random: AbstractRNG using Static: static -using LuxCore: LuxCore, AbstractExplicitLayer +using LuxCore: LuxCore, AbstractLuxLayer using ..Lux: Lux, CompactLuxLayer, LuxCompactModelParsingException, StatefulLuxLayer, safe_getproperty @@ -585,7 +584,7 @@ end (f::InitFn)(args...) = f.f(args...) 
-@concrete struct ValueStorage <: AbstractExplicitLayer +@concrete struct ValueStorage <: AbstractLuxLayer ps_init_fns st_init_fns end @@ -657,7 +656,7 @@ function try_make_lux_layer(x::Union{AbstractVector, Tuple}) end try_make_lux_layer(x) = x -function maybe_make_stateful(layer::AbstractExplicitLayer, ps, st) +function maybe_make_stateful(layer::AbstractLuxLayer, ps, st) return StatefulLuxLayer{true}(layer, ps, st) end maybe_make_stateful(::Nothing, ::Nothing, st) = st diff --git a/src/helpers/losses.jl b/src/helpers/losses.jl index 0496f5b640..597f35fea1 100644 --- a/src/helpers/losses.jl +++ b/src/helpers/losses.jl @@ -120,7 +120,7 @@ end abstract type AbstractLossFunction <: Function end -function (loss::AbstractLossFunction)(model::AbstractExplicitLayer, ps, st, (x, y)) +function (loss::AbstractLossFunction)(model::AbstractLuxLayer, ps, st, (x, y)) ŷ, stₙ = model(x, ps, st) return loss(ŷ, y), stₙ, (;) end diff --git a/src/helpers/size_propagator.jl b/src/helpers/size_propagator.jl index 33d301e70e..954065e53d 100644 --- a/src/helpers/size_propagator.jl +++ b/src/helpers/size_propagator.jl @@ -1,6 +1,4 @@ # Initial design is based off of https://github.com/FluxML/Flux.jl/blob/942c6e5051b7a8cb064432d1f0604319497d5f09/src/outputsize.jl -# Currently this is not being used anywhere. However, with 1.0 release we will define -# outputsize for all layers using this. module NilSizePropagation using ArrayInterface: ArrayInterface @@ -199,25 +197,8 @@ end end -# TODO: In v1 we change to this `outputsize` function, till then this is private API -function compute_output_size(layer::AbstractExplicitLayer, - input_size::NTuple{N, <:Integer}, rng::AbstractRNG) where {N} - x = NilSizePropagation.NilArray{N}(input_size) - return compute_output_size(layer, x, rng) -end - -function compute_output_size( - layer::AbstractExplicitLayer, input_size::NTuple{N, <:Integer}, ps, st) where {N} - x = NilSizePropagation.NilArray{N}(input_size) - return compute_output_size(layer, x, ps, st) -end - -function compute_output_size(layer::AbstractExplicitLayer, x, rng::AbstractRNG) +function LuxCore.outputsize(layer::AbstractLuxLayer, x, rng::AbstractRNG) ps, st = setup(rng, layer) - return compute_output_size(layer, x, ps, st) -end - -function compute_output_size(layer::AbstractExplicitLayer, x, ps, st) x_nil = NilSizePropagation.recursively_nillify(x) ps_nil = NilSizePropagation.recursively_nillify(ps) st_nil = NilSizePropagation.recursively_nillify(st) diff --git a/src/helpers/stateful.jl b/src/helpers/stateful.jl index 9d47a1c86b..0fdf475ee6 100644 --- a/src/helpers/stateful.jl +++ b/src/helpers/stateful.jl @@ -1,10 +1,9 @@ """ - StatefulLuxLayer(model, ps, st; st_fixed_type = Val(true)) # deprecated - StatefulLuxLayer{ST}(model, ps, st) + StatefulLuxLayer{FT}(model, ps, st) !!! warning - This is not a Lux.AbstractExplicitLayer + This is not a Lux.AbstractLuxLayer A convenience wrapper over Lux layers which stores the parameters and states internally. This is meant to be used in internal implementation of layers. @@ -18,6 +17,13 @@ This is meant to be used in internal implementation of layers. - Facilitates Nested AD support in Lux. For more details on this feature, see the [Nested AD Manual Page](@ref nested_autodiff). +## Static Parameters + + - If `FT = true` then the type of the `state` is fixed, i.e., + `typeof(last(model(x, ps, st))) == st`. + - If `FT = false` then type of the state might change. Note that while this works in all + cases, it will introduce type instability. 
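A brief sketch of how the fixed-state-type flag is typically used; the layer, sizes, and variable names below are illustrative only and not part of this patch.

```julia
using Lux, Random

model = Dense(2 => 3, tanh)
ps, st = Lux.setup(Random.default_rng(), model)

# FT = true: the state type is assumed to stay fixed across calls
smodel = Lux.StatefulLuxLayer{true}(model, ps, st)

y  = smodel(randn(Float32, 2, 4))       # uses the internally stored parameters
y2 = smodel(randn(Float32, 2, 4), ps)   # or pass the parameters explicitly
```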
+ ## Arguments - `model`: A Lux layer @@ -25,13 +31,6 @@ This is meant to be used in internal implementation of layers. the parameters on function call - `st`: The state of the layer -## Keyword Arguments - - - `st_fixed_type`: If `Val(true)`, then the type of the `state` is fixed, i.e., - `typeof(last(model(x, ps, st))) == st`. If this is not the case, then `st_fixed_type` - must be set to `Val(false)`. If `st_fixed_type` is set to `Val(false)`, then type - stability is not guaranteed. - ## Inputs - `x`: The input to the layer @@ -41,7 +40,7 @@ This is meant to be used in internal implementation of layers. - `y`: The output of the layer """ -mutable struct StatefulLuxLayer{ST, M <: AbstractExplicitLayer, psType, stType} +mutable struct StatefulLuxLayer{ST, M <: AbstractLuxLayer, psType, stType} const model::M ps::psType st::stType @@ -49,7 +48,7 @@ mutable struct StatefulLuxLayer{ST, M <: AbstractExplicitLayer, psType, stType} fixed_state_type::ST function StatefulLuxLayer( - model::AbstractExplicitLayer, ps, st, st_any, fixed_state_type::StaticBool) + model::AbstractLuxLayer, ps, st, st_any, fixed_state_type::StaticBool) return new{typeof(fixed_state_type), typeof(model), typeof(ps), typeof(st)}( model, ps, st, st_any, fixed_state_type) end @@ -59,19 +58,10 @@ function StatefulLuxLayer{ST}(model, ps, st, st_any) where {ST} return StatefulLuxLayer(model, ps, st, st_any, static(ST)) end -function StatefulLuxLayer(model::AbstractExplicitLayer, st::NamedTuple; kwargs...) - return StatefulLuxLayer(model, nothing, st; kwargs...) -end -function StatefulLuxLayer(model::AbstractExplicitLayer, ps, st::NamedTuple; - st_fixed_type::Val{ST}=Val(true)) where {ST} - Base.depwarn("`st_fixed_type` is deprecated. Use `StatefulLuxLayer{ST}` instead.", - :StatefulLuxLayer) - return StatefulLuxLayer{ST}(model, ps, st) -end -function StatefulLuxLayer{true}(model::AbstractExplicitLayer, ps, st::NamedTuple) +function StatefulLuxLayer{true}(model::AbstractLuxLayer, ps, st::NamedTuple) return StatefulLuxLayer{true}(model, ps, st, nothing) end -function StatefulLuxLayer{false}(model::AbstractExplicitLayer, ps, st::NamedTuple) +function StatefulLuxLayer{false}(model::AbstractLuxLayer, ps, st::NamedTuple) return StatefulLuxLayer{false}(model, ps, nothing, st) end @@ -131,10 +121,12 @@ function (s::StatefulLuxLayer)(x, p=s.ps) return y end -function CRC.rrule(::Type{<:StatefulLuxLayer{FT}}, - model::AbstractExplicitLayer, ps, st, st_any) where {FT} - slayer = StatefulLuxLayer{FT}(model, ps, st, st_any) - ∇StatefulLuxLayer(Δ) = NoTangent(), NoTangent(), Δ.ps, NoTangent(), NoTangent() +function CRC.rrule(::Type{<:StatefulLuxLayer}, model::AbstractLuxLayer, + ps, st, st_any, fixed_state_type) + slayer = StatefulLuxLayer(model, ps, st, st_any, fixed_state_type) + function ∇StatefulLuxLayer(Δ) + return NoTangent(), NoTangent(), Δ.ps, NoTangent(), NoTangent(), NoTangent() + end return slayer, ∇StatefulLuxLayer end diff --git a/src/helpers/training.jl b/src/helpers/training.jl index fa52f357a3..cd146169cf 100644 --- a/src/helpers/training.jl +++ b/src/helpers/training.jl @@ -4,12 +4,10 @@ using ADTypes: AbstractADType, AutoEnzyme, AutoReverseDiff, AutoTracker, AutoZyg using Compat: @compat using ConcreteStructs: @concrete using FastClosures: @closure -using Optimisers: Optimisers, AbstractRule -using Random: AbstractRNG +using Optimisers: Optimisers using ..Lux: Lux -using LuxCore: LuxCore, AbstractExplicitLayer -using MLDataDevices: MLDataDevices +using LuxCore: LuxCore, AbstractLuxLayer """ TrainState @@ -45,11 +43,7 @@ 
Internal fields: end """ - TrainState(rng::Random.AbstractRNG, model::LuxCore.AbstractExplicitLayer, - optimizer::Optimisers.AbstractRule; - transform_variables::Union{Function, AbstractDevice}=gpu_device()) - TrainState(model::LuxCore.AbstractExplicitLayer, ps, st, - optimizer::Optimisers.AbstractRule) + TrainState(model::Lux.AbstractLuxLayer, ps, st, optimizer::Optimisers.AbstractRule) Constructor for [`TrainState`](@ref). @@ -67,21 +61,7 @@ Constructor for [`TrainState`](@ref). [`TrainState`](@ref) object. """ -function TrainState(rng::AbstractRNG, model::AbstractExplicitLayer, optimizer::AbstractRule; - transform_variables=MLDataDevices.gpu_device()) - Base.depwarn( - "`TrainState(rng::AbstractRNG, model::AbstractExplicitLayer, \ - optimizer::Optimisers.AbstractRule; transform_variables::Union{Function, \ - AbstractLuxDevice}=gpu_device())` has been deprecated in favor of \ - `TrainState(model::AbstractExplicitLayer, ps, st, \ - optimizer::Optimisers.AbstractRule)`", - :TrainState) - ps, st = LuxCore.setup(rng, model) .|> transform_variables - return TrainState(model, ps, st, optimizer) -end - -function TrainState( - model::AbstractExplicitLayer, ps, st, optimizer::Optimisers.AbstractRule) +function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.AbstractRule) st_opt = Optimisers.setup(optimizer, ps) return TrainState(nothing, nothing, model, ps, st, optimizer, st_opt, 0) end diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 77cd98f482..d05ff8050d 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -1,3 +1,14 @@ +function init_linear_bias(rng::AbstractRNG, init_bias::F, fan_in::IntegerType, + bias_len::IntegerType) where {F} + if init_bias === nothing # Default from PyTorch + bound = inv(sqrt(fan_in)) + y = rand32(rng, bias_len) + @. y = (y - 0.5f0) * 2 * bound + return y + end + return init_bias(rng, bias_len) +end + """ ReshapeLayer(dims) @@ -32,11 +43,11 @@ julia> y, st_new = model(x, ps, st); (2, 2, 3) ``` """ -struct ReshapeLayer{N} <: AbstractExplicitLayer +struct ReshapeLayer{N} <: AbstractLuxLayer dims::NTuple{N, Int} end -outputsize(r::ReshapeLayer) = r.dims +outputsize(r::ReshapeLayer, _, ::AbstractRNG) = r.dims function (r::ReshapeLayer)(x::AbstractArray, _, st::NamedTuple) return reshape(x, r.dims..., size(x, ndims(x))), st @@ -81,7 +92,7 @@ julia> y, st_new = model(x, ps, st) ([3.0, 2.0, 1.0], NamedTuple()) ``` """ -@concrete struct ReverseSequence <: AbstractExplicitLayer +@concrete struct ReverseSequence <: AbstractLuxLayer dim <: Union{Nothing, StaticInt} end @@ -141,7 +152,7 @@ julia> y, st_new = model(x, ps, st); (8, 2) ``` """ -@concrete struct FlattenLayer <: AbstractExplicitLayer +@concrete struct FlattenLayer <: AbstractLuxLayer N <: Union{Nothing, StaticInt} end @@ -177,7 +188,7 @@ Return a view of all the data of the input `x` where the index for dimension `di - `view(x,:,:,...,i,:,:,...)` where `i` is in position `d` - Empty `NamedTuple()` """ -@concrete struct SelectDim <: AbstractExplicitLayer +@concrete struct SelectDim <: AbstractLuxLayer dim <: StaticInt index <: StaticInt end @@ -212,13 +223,12 @@ julia> y, st_new = model(x, ps, st) (1, NamedTuple()) ``` """ -struct NoOpLayer <: AbstractExplicitLayer end +struct NoOpLayer <: AbstractLuxLayer end (noop::NoOpLayer)(x, _, st::NamedTuple) = x, st """ - WrappedFunction{DC}(f) - WrappedFunction(f) -> WrappedFunction{:direct_call}(f) + WrappedFunction(f) Wraps a stateless and parameter less function. Might be used when a function is added to `Chain`. 
For example, `Chain(x -> relu.(x))` would not work and the right thing to do would @@ -227,10 +237,6 @@ be `Chain((x, ps, st) -> (relu.(x), st))`. An easier thing to do would be ## Arguments - - `DC`: If `:runtime_check`, then we check if the function can be called with the input - `x`, `ps`, and `st` using `hasmethod`. If `:direct_call`, we call `f(x)` directly. - For all other values, we call `f(x, ps, st)` which must return a tuple. **(In future - versions, we will default to `:runtime_check`)** - `f`: Some function. ## Inputs @@ -243,49 +249,17 @@ be `Chain((x, ps, st) -> (relu.(x), st))`. An easier thing to do would be - Output of `f(x)` - Empty `NamedTuple()` """ -struct WrappedFunction{DC, F} <: AbstractExplicitLayer - call_mode::StaticSymbol{DC} - func::F -end - -function WrappedFunction{call_mode}(f::F) where {call_mode, F} - return WrappedFunction(static(call_mode), f) -end - -function WrappedFunction(f::F) where {F} - # Not a depwarn but helpful to call this - Base.depwarn("The current default of `:direct_call` will be replaced with \ - `:runtime_check` from v1). Please make sure that the assumptions of \ - this function are correct or specify `WrappedFunction{:direct_call}(f)`", - :WrappedFunction) - return WrappedFunction{:direct_call}(f) -end - -function (wf::WrappedFunction{:direct_call})(x, ps, st::NamedTuple) - return wrapped_function_call(wf.func, x, ps, st, True()) +@concrete struct WrappedFunction <: AbstractLuxLayer + func <: Function end -function (wf::WrappedFunction)(x, ps, st::NamedTuple) - return wrapped_function_call(wf.func, x, ps, st, False()) -end - -function (wf::WrappedFunction{:runtime_check})(x, ps, st::NamedTuple) - return wrapped_function_call(wf.func, x, ps, st, - static(!hasmethod(wf.func, (typeof(x), typeof(ps), typeof(st))))) -end +(wf::WrappedFunction)(x, ps, st::NamedTuple{}) = wf.func(x), st -wrapped_function_call(f, x, ps, st, ::False) = f(x, ps, st) -wrapped_function_call(f, x, _, st, ::True) = f(x), st - -function Base.show(io::IO, w::WrappedFunction{T}) where {T} - print(io, "WrappedFunction(", static(w.call_mode), ", ") - show(io, w.func) - print(io, ")") -end +Base.show(io::IO, w::WrappedFunction) = print(io, "WrappedFunction(", w.func, ")") """ - Dense(in_dims => out_dims, activation=identity; init_weight=glorot_uniform, - init_bias=zeros32, use_bias=True(), allow_fast_activation=True()) + Dense(in_dims => out_dims, activation=identity; init_weight=nothing, + init_bias=nothing, use_bias=True()) Create a traditional fully connected layer, whose forward pass is given by: `y = activation.(weight * x .+ bias)` @@ -299,12 +273,14 @@ Create a traditional fully connected layer, whose forward pass is given by: ## Keyword Arguments - `init_weight`: initializer for the weight matrix - (`weight = init_weight(rng, out_dims, in_dims)`) - - `init_bias`: initializer for the bias vector (ignored if `use_bias=false`) + (`weight = init_weight(rng, out_dims, in_dims)`). If `nothing`, then we use + [`kaiming_uniform`](@ref) with gain computed on the basis of the activation + function (taken from Pytorch + [`nn.init.calculate_gain`](https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.calculate_gain)). + - `init_bias`: initializer for the bias vector (ignored if `use_bias=false`). If + `nothing`, then we use uniform distribution with bounds `-bound` and `bound` where + `bound = inv(sqrt(in_dims))`. 
- `use_bias`: Trainable bias can be disabled entirely by setting this to `false` - - `allow_fast_activation`: If `true`, then certain activations can be approximated with - a faster version. The new activation function will be given by - `NNlib.fast_act(activation)` ## Input @@ -320,7 +296,7 @@ Create a traditional fully connected layer, whose forward pass is given by: - `weight`: Weight Matrix of size `(out_dims, in_dims)` - `bias`: Bias of size `(out_dims, 1)` (present if `use_bias=true`) """ -@concrete struct Dense <: AbstractExplicitLayer +@concrete struct Dense <: AbstractLuxLayer activation in_dims <: IntegerType out_dims <: IntegerType @@ -341,38 +317,37 @@ function Dense(mapping::Pair{<:IntegerType, <:IntegerType}, activation=identity; end function Dense(in_dims::IntegerType, out_dims::IntegerType, activation=identity; - init_weight=glorot_uniform, init_bias=zeros32, - use_bias::BoolType=True(), allow_fast_activation::BoolType=True()) - activation = dynamic(allow_fast_activation) ? NNlib.fast_act(activation) : activation + init_weight=nothing, init_bias=nothing, use_bias::BoolType=True()) return Dense(activation, in_dims, out_dims, init_weight, init_bias, static(use_bias)) end function initialparameters(rng::AbstractRNG, d::Dense) - if has_bias(d) - return (weight=d.init_weight(rng, d.out_dims, d.in_dims), - bias=d.init_bias(rng, d.out_dims, 1)) #TODO: In v1 make it a vector + weight = if d.init_weight === nothing + kaiming_uniform(rng, Float32, d.out_dims, d.in_dims; + gain=Utils.calculate_gain(d.activation, √5.0f0)) else - return (weight=d.init_weight(rng, d.out_dims, d.in_dims),) + d.init_weight(rng, d.out_dims, d.in_dims) end + has_bias(d) || return (; weight) + return (; weight, bias=init_linear_bias(rng, d.init_bias, d.in_dims, d.out_dims)) end parameterlength(d::Dense) = d.out_dims * d.in_dims + has_bias(d) * d.out_dims statelength(d::Dense) = 0 -outputsize(d::Dense) = (d.out_dims,) +outputsize(d::Dense, _, ::AbstractRNG) = (d.out_dims,) function (d::Dense)(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(d, ps, st, x) - bias = safe_vec(safe_getproperty(ps, Val(:bias))) + bias = safe_getproperty(ps, Val(:bias)) + σ = NNlib.fast_act(d.activation, x) z = matrix_to_array( - fused_dense_bias_activation(d.activation, ps.weight, make_abstract_matrix(y), bias), - y) + fused_dense_bias_activation(σ, ps.weight, make_abstract_matrix(y), bias), y) return z, st end """ - Scale(dims, activation=identity; init_weight=ones32, init_bias=zeros32, use_bias=True(), - allow_fast_activation=True()) + Scale(dims, activation=identity; init_weight=ones32, init_bias=zeros32, use_bias=True()) Create a Sparsely Connected Layer with a very specific structure (only Diagonal Elements are non-zero). The forward pass is given by: `y = activation.(weight .* x .+ bias)` @@ -388,9 +363,6 @@ Elements are non-zero). The forward pass is given by: `y = activation.(weight .* (`weight = init_weight(rng, out_dims, in_dims)`) - `init_bias`: initializer for the bias vector (ignored if `use_bias=false`) - `use_bias`: Trainable bias can be disabled entirely by setting this to `false` - - `allow_fast_activation`: If `true`, then certain activations can be approximated with - a faster version. The new activation function will be given by - `NNlib.fast_act(activation)` ## Input @@ -407,7 +379,7 @@ Elements are non-zero). 
The forward pass is given by: `y = activation.(weight .* - `weight`: Weight Array of size `(dims...)` - `bias`: Bias of size `(dims...)` """ -@concrete struct Scale{UB <: StaticBool} <: AbstractExplicitLayer +@concrete struct Scale{UB <: StaticBool} <: AbstractLuxLayer activation dims <: Tuple{Vararg{IntegerType}} init_weight @@ -423,9 +395,7 @@ function Base.show(io::IO, d::Scale) end function Scale(dims::Tuple{Vararg{IntegerType}}, activation=identity; - init_weight=glorot_uniform, init_bias=zeros32, - use_bias::BoolType=True(), allow_fast_activation::BoolType=True()) - activation = dynamic(allow_fast_activation) ? NNlib.fast_act(activation) : activation + init_weight=glorot_uniform, init_bias=zeros32, use_bias::BoolType=True()) return Scale(activation, dims, init_weight, init_bias, static(use_bias)) end @@ -446,22 +416,24 @@ end parameterlength(d::Scale) = (1 + has_bias(d)) * prod(d.dims) statelength(d::Scale) = 0 -outputsize(d::Scale) = d.dims +outputsize(d::Scale, _, ::AbstractRNG) = d.dims function (d::Scale{False})(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(d, ps, st, x) - return @.(d.activation(y .* ps.weight)), st + σ = NNlib.fast_act(d.activation, y) + return @.(σ(y .* ps.weight)), st end function (d::Scale{True})(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(d, ps, st, x) - return @.(d.activation(y * ps.weight + ps.bias)), st + σ = NNlib.fast_act(d.activation, y) + return @.(σ(y * ps.weight + ps.bias)), st end """ - Bilinear((in1_dims, in2_dims) => out, activation=identity; init_weight=glorot_uniform, - init_bias=zeros32, use_bias=True(), allow_fast_activation=True()) - Bilinear(in12_dims => out, activation=identity; init_weight=glorot_uniform, - init_bias=zeros32, use_bias=True(), allow_fast_activation=True()) + Bilinear((in1_dims, in2_dims) => out, activation=identity; init_weight=nothing, + init_bias=nothing, use_bias=True()) + Bilinear(in12_dims => out, activation=identity; init_weight=nothing, + init_bias=nothing, use_bias=True()) Create a fully connected layer between two inputs and an output, and otherwise similar to [`Dense`](@ref). Its output, given vectors `x` & `y`, is another vector `z` with, for all @@ -483,12 +455,13 @@ with `B` the Bilinear layer. ## Keyword Arguments - `init_weight`: initializer for the weight matrix - (`weight = init_weight(rng, out_dims, in1_dims, in2_dims)`) - - `init_bias`: initializer for the bias vector (ignored if `use_bias=false`) + (`weight = init_weight(rng, out_dims, in1_dims, in2_dims)`). If `nothing`, then we + use uniform distribution with bounds `-bound` and `bound` where + `bound = inv(sqrt(in1_dims))`. + - `init_bias`: initializer for the bias vector (ignored if `use_bias=false`). If + `nothing`, then we use uniform distribution with bounds `-bound` and `bound` where + `bound = inv(sqrt(in1_dims))`. - `use_bias`: Trainable bias can be disabled entirely by setting this to `false` - - `allow_fast_activation`: If `true`, then certain activations can be approximated with - a faster version. The new activation function will be given by - `NNlib.fast_act(activation)` ## Input @@ -509,7 +482,7 @@ with `B` the Bilinear layer. 
- `weight`: Weight Matrix of size `(out_dims, in1_dims, in2_dims)` - `bias`: Bias of size `(out_dims, 1)` (present if `use_bias=true`) """ -@concrete struct Bilinear <: AbstractExplicitLayer +@concrete struct Bilinear <: AbstractLuxLayer activation in1_dims <: IntegerType in2_dims <: IntegerType @@ -531,21 +504,24 @@ function Bilinear((in12_dims, out)::Pair{<:IntegerType, <:IntegerType}, return Bilinear((in12_dims, in12_dims) => out, activation; kwargs...) end -function Bilinear(((in1_dims, in2_dims), out)::Pair{<:Tuple, <:IntegerType}, - activation=identity; init_weight=glorot_uniform, init_bias=zeros32, - use_bias::BoolType=True(), allow_fast_activation::BoolType=True()) - activation = dynamic(allow_fast_activation) ? NNlib.fast_act(activation) : activation +function Bilinear( + ((in1_dims, in2_dims), out)::Pair{<:Tuple, <:IntegerType}, activation=identity; + init_weight=nothing, init_bias=nothing, use_bias::BoolType=True()) return Bilinear( activation, in1_dims, in2_dims, out, init_weight, init_bias, static(use_bias)) end function initialparameters(rng::AbstractRNG, b::Bilinear) - if has_bias(b) - return (weight=b.init_weight(rng, b.out_dims, b.in1_dims, b.in2_dims), - bias=b.init_bias(rng, b.out_dims, 1)) # TODO: In v1.0 make it a vector + weight = if b.init_weight === nothing + bound = inv(sqrt(b.in1_dims)) + y = randn32(rng, b.out_dims, b.in1_dims, b.in2_dims) + @. y = (y - 0.5f0) * 2 * bound + y else - return (weight=b.init_weight(rng, b.out_dims, b.in1_dims, b.in2_dims),) + b.init_weight(rng, b.out_dims, b.in1_dims, b.in2_dims) end + has_bias(b) || return (; weight) + return (; weight, bias=init_linear_bias(rng, b.init_bias, b.in1_dims, b.out_dims)) end function parameterlength(b::Bilinear) @@ -553,7 +529,7 @@ function parameterlength(b::Bilinear) end statelength(b::Bilinear) = 0 -outputsize(b::Bilinear) = (b.out_dims,) +outputsize(b::Bilinear, _, ::AbstractRNG) = (b.out_dims,) function (b::Bilinear)( (x, y)::Tuple{<:AbstractVecOrMat, <:AbstractVecOrMat}, ps, st::NamedTuple) @@ -564,8 +540,8 @@ function (b::Bilinear)( Wy = reshape(reshape(ps.weight, (:, s₃)) * y, (s₁, s₂, :)) Wyx = reshape(batched_matmul(Wy, reshape(x, (s₂, 1, :))), (s₁, :)) - bias = safe_vec(safe_getproperty(ps, Val(:bias))) - return (bias_activation!!(b.activation, Wyx, bias), st) + σ = NNlib.fast_act(b.activation, Wyx) + return bias_activation!!(σ, Wyx, safe_getproperty(ps, Val(:bias))), st end function (b::Bilinear)((x, y)::Tuple{<:AbstractArray, <:AbstractArray}, ps, st::NamedTuple) @@ -583,7 +559,7 @@ end (b::Bilinear)(x::AbstractArray, ps, st::NamedTuple) = b((x, x), ps, st) """ - Embedding(in_dims => out_dims; init_weight=randn32) + Embedding(in_dims => out_dims; init_weight=rand32) A lookup table that stores embeddings of dimension `out_dims` for a vocabulary of size `in_dims`. When the vocabulary is multi-dimensional, the input is expected to be a tuple @@ -616,13 +592,13 @@ This layer is often used to store word embeddings and retrieve them using indice input, an N + 1 dimensional output is returned. 
- Empty `NamedTuple()` """ -@concrete struct Embedding <: AbstractExplicitLayer +@concrete struct Embedding <: AbstractLuxLayer in_dims <: Union{IntegerType, Tuple{Vararg{IntegerType}}} out_dims <: IntegerType init_weight end -function Embedding((in_dims, out_dims)::Pair; init_weight=randn32) +function Embedding((in_dims, out_dims)::Pair; init_weight=rand32) return Embedding(in_dims, out_dims, init_weight) end @@ -634,7 +610,7 @@ function Base.show(io::IO, e::Embedding) return print(io, "Embedding(", e.in_dims, " => ", e.out_dims, ")") end -outputsize(e::Embedding) = (e.out_dims,) +outputsize(e::Embedding, _, ::AbstractRNG) = (e.out_dims,) (e::Embedding)(x::Integer, ps, st::NamedTuple) = view(ps.weight, :, x), st function (e::Embedding)(x::AbstractVector{<:Integer}, ps, st::NamedTuple) @@ -659,72 +635,3 @@ end function (e::Embedding)(::Tuple{}, _, ::NamedTuple) throw(ArgumentError("Input tuple must contain at least one element")) end - -""" - PeriodicEmbedding(idxs, periods) - -Create an embedding periodic in some inputs with specified periods. Input indices not in -`idxs` are passed through unchanged, but inputs in `idxs` are moved to the end of the -output and replaced with their sines, followed by their cosines (scaled appropriately to -have the specified periods). This smooth embedding preserves phase information and enforces -periodicity. - -For example, `layer = PeriodicEmbedding([2, 3], [3.0, 1.0])` will create a layer periodic in -the second input with period 3.0 and periodic in the third input with period 1.0. In this -case, `layer([a, b, c, d], st) == ([a, d, sinpi(2 / 3.0 * b), sinpi(2 / 1.0 * c), cospi(2 / 3.0 * b), cospi(2 / 1.0 * c)], st)`. - -## Arguments - - - `idxs`: Indices of the periodic inputs - - `periods`: Periods of the periodic inputs, in the same order as in `idxs` - -!!! danger "Deprecation Notice" - - This layer is deprecated and will be removed in v1. Please use the version in - [`Boltz.jl`](https://github.com/LuxDL/Boltz.jl) instead. - -# Extended Help - -## Inputs - - - `x` must be an `AbstractArray` with `issubset(idxs, axes(x, 1))` - - `st` must be a `NamedTuple` where `st.k = 2 ./ periods`, but on the same device as `x` - -## Returns - - - `AbstractArray` of size `(size(x, 1) + length(idxs), ...)` where `...` are the other - dimensions of `x`. - - `st`, unchanged -""" -struct PeriodicEmbedding{I, P} <: AbstractExplicitLayer - idxs::I - periods::P - - function PeriodicEmbedding(idxs::I, periods::P) where {I, P} - Base.depwarn("`PeriodicEmbedding` is deprecated and will be removed in v1. 
Please \ - use the corresponding version in `Boltz.jl` instead.", - :PeriodicEmbedding) - return new{I, P}(idxs, periods) - end -end - -initialstates(::AbstractRNG, p::PeriodicEmbedding) = (k=2 ./ p.periods,) - -function (p::PeriodicEmbedding)(x::AbstractVector, ps, st::NamedTuple) - return vec(first(p(reshape(x, :, 1), ps, st))), st -end - -function (p::PeriodicEmbedding)(x::AbstractMatrix, ps, st::NamedTuple) - other_idxs = CRC.@ignore_derivatives setdiff(axes(x, 1), p.idxs) - return ( - vcat(x[other_idxs, :], sinpi.(st.k .* x[p.idxs, :]), cospi.(st.k .* x[p.idxs, :])), - st) -end - -function (p::PeriodicEmbedding)(x::AbstractArray, ps, st::NamedTuple) - return reshape(first(p(reshape(x, size(x, 1), :), ps, st)), :, size(x)[2:end]...), st -end - -function Base.show(io::IO, p::PeriodicEmbedding) - return print(io, "PeriodicEmbedding(", p.idxs, ", ", p.periods, ")") -end diff --git a/src/layers/containers.jl b/src/layers/containers.jl index 21c1506a8c..1b5832c892 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -16,7 +16,7 @@ The simplest "ResNet"-type connection is just `SkipConnection(layer, +)`. - `connection`: + A 2-argument function that takes `layer(input)` and the input OR - + An AbstractExplicitLayer that takes `(layer(input), input)` as input + + An AbstractLuxLayer that takes `(layer(input), input)` as input # Extended Help @@ -32,18 +32,18 @@ The simplest "ResNet"-type connection is just `SkipConnection(layer, +)`. ## Parameters - Parameters of `layer` OR - - If `connection` is an AbstractExplicitLayer, then NamedTuple with fields `:layers` and + - If `connection` is an AbstractLuxLayer, then NamedTuple with fields `:layers` and `:connection` ## States - States of `layer` OR - - If `connection` is an AbstractExplicitLayer, then NamedTuple with fields `:layers` and + - If `connection` is an AbstractLuxLayer, then NamedTuple with fields `:layers` and `:connection` See [`Parallel`](@ref) for a more general implementation. 
""" -@concrete struct SkipConnection <: AbstractExplicitContainerLayer{(:layers,)} +@concrete struct SkipConnection <: AbstractLuxWrapperLayer{:layers} layers connection name @@ -60,13 +60,12 @@ function SkipConnection(; layers, connection, name::NAME_TYPE=nothing) end function initialparameters( - rng::AbstractRNG, l::SkipConnection{T, <:AbstractExplicitLayer}) where {T} + rng::AbstractRNG, l::SkipConnection{T, <:AbstractLuxLayer}) where {T} return (layers=initialparameters(rng, l.layers), connection=initialparameters(rng, l.connection)) end -function initialstates( - rng::AbstractRNG, l::SkipConnection{T, <:AbstractExplicitLayer}) where {T} +function initialstates(rng::AbstractRNG, l::SkipConnection{T, <:AbstractLuxLayer}) where {T} return ( layers=initialstates(rng, l.layers), connection=initialstates(rng, l.connection)) end @@ -76,7 +75,7 @@ function (skip::SkipConnection)(x, ps, st::NamedTuple) return skip.connection(mx, x), st end -function (skip::SkipConnection{<:AbstractExplicitLayer, <:AbstractExplicitLayer})( +function (skip::SkipConnection{<:AbstractLuxLayer, <:AbstractLuxLayer})( x, ps, st::NamedTuple) mx, st1 = apply(skip.layers, x, ps.layers, st.layers) y, st2 = apply(skip.connection, (mx, x), ps.connection, st.connection) @@ -147,7 +146,7 @@ julia> size.(first(model((x1, x2), ps, st))) ((1,), (1,)) ``` """ -@concrete struct Parallel <: AbstractExplicitContainerLayer{(:layers,)} +@concrete struct Parallel <: AbstractLuxWrapperLayer{:layers} connection layers <: NamedTuple name @@ -194,8 +193,6 @@ end return Expr(:block, calls...) end -Base.keys(m::Parallel) = Base.keys(getfield(m, :layers)) - """ BranchLayer(layers...) BranchLayer(; name=nothing, layers...) @@ -256,7 +253,7 @@ BranchLayer( # plus 0 states. ``` """ -@concrete struct BranchLayer <: AbstractExplicitContainerLayer{(:layers,)} +@concrete struct BranchLayer <: AbstractLuxWrapperLayer{:layers} layers <: NamedTuple name end @@ -283,8 +280,6 @@ BranchLayer(; name::NAME_TYPE=nothing, kwargs...) = BranchLayer((; kwargs...), n return Expr(:block, calls...) end -Base.keys(m::BranchLayer) = Base.keys(getfield(m, :layers)) - """ PairwiseFusion(connection, layers...; name=nothing) PairwiseFusion(connection; name=nothing, layers...) @@ -301,7 +296,7 @@ x1 → layer1 → y1 ↘ - `connection`: Takes 2 inputs and combines them - - `layers`: `AbstractExplicitLayer`s. Layers can be specified in two formats: + - `layers`: `AbstractLuxLayer`s. Layers can be specified in two formats: + A list of `N` Lux layers + Specified as `N` keyword arguments. @@ -346,7 +341,7 @@ end - States of each `layer` wrapped in a NamedTuple with `fields = layer_1, layer_2, ..., layer_N` (naming changes if using the kwargs API) """ -@concrete struct PairwiseFusion <: AbstractExplicitContainerLayer{(:layers,)} +@concrete struct PairwiseFusion <: AbstractLuxWrapperLayer{:layers} connection layers <: NamedTuple name @@ -391,11 +386,9 @@ end return Expr(:block, calls...) end -Base.keys(m::PairwiseFusion) = Base.keys(getfield(m, :layers)) - """ - Chain(layers...; name=nothing, disable_optimizations::Bool = false) - Chain(; layers..., name=nothing, disable_optimizations::Bool = false) + Chain(layers...; name=nothing) + Chain(; layers..., name=nothing) Collects multiple layers / functions to be called in sequence on a given input. @@ -406,11 +399,6 @@ Collects multiple layers / functions to be called in sequence on a given input. + A list of `N` Lux layers + Specified as `N` keyword arguments. 
-## Keyword Arguments - - - `disable_optimizations`: Prevents any structural optimization - - `name`: Name of the layer (optional) - # Extended Help ## Inputs @@ -433,20 +421,6 @@ of the internal layers. - States of each `layer` wrapped in a NamedTuple with `fields = layer_1, layer_2, ..., layer_N` (naming changes if using the kwargs API) -## Optimizations - -Performs a few optimizations to generate reasonable architectures. Can be disabled using -keyword argument `disable_optimizations`. - - - All sublayers are recursively optimized. - - If a function `f` is passed as a layer and it doesn't take 3 inputs, it is converted to - a [`WrappedFunction`](@ref)(`f`) which takes only one input. - - If the layer is a Chain, it is flattened. - - [`NoOpLayer`](@ref)s are removed. - - If there is only 1 layer (left after optimizations), then it is returned without the - `Chain` wrapper. - - If there are no layers (left after optimizations), a [`NoOpLayer`](@ref) is returned. - ## Miscellaneous Properties - Allows indexing and field access syntax. We can access the `i`th layer by `m[i]` or @@ -462,58 +436,46 @@ Chain( layer_3 = Dense(3 => 2), # 8 parameters ) # Total: 23 parameters, # plus 7 states. + +julia> Chain(Dense(2, 3, relu), BatchNorm(3), Dense(3, 2); name="MyFancyChain") +MyFancyChain( + layer_1 = Dense(2 => 3, relu), # 9 parameters + layer_2 = BatchNorm(3, affine=true, track_stats=true), # 6 parameters, plus 7 + layer_3 = Dense(3 => 2), # 8 parameters +) # Total: 23 parameters, + # plus 7 states. ``` """ -@concrete struct Chain <: AbstractExplicitContainerLayer{(:layers,)} +@concrete struct Chain <: AbstractLuxWrapperLayer{:layers} layers <: NamedTuple name end -function Chain(xs...; name::NAME_TYPE=nothing, disable_optimizations::Bool=false) - xs = disable_optimizations ? xs : flatten_lux_chain(xs) - length(xs) == 0 && return NoOpLayer() - length(xs) == 1 && return first(xs) - return Chain(Utils.named_tuple_layers(xs...), name) +function Chain(xs...; name::NAME_TYPE=nothing) + return Chain(Utils.named_tuple_layers(wrap_functions_in_chain_call(xs)...), name) end - Chain(xs::AbstractVector; kwargs...) = Chain(xs...; kwargs...) +Chain(nt::NamedTuple; name::NAME_TYPE=nothing) = Chain(nt, name) +Chain(; name::NAME_TYPE=nothing, kwargs...) = Chain((; kwargs...); name) -function Chain(nt::NamedTuple; disable_optimizations::Bool=true, name::NAME_TYPE=nothing) - if !disable_optimizations - throw(ArgumentError("Chain(::NamedTuple) is not compatible with disable_optimizations=true")) - end - return Chain(nt, name) -end - -function Chain(; disable_optimizations::Bool=true, name::NAME_TYPE=nothing, kwargs...) - return Chain((; kwargs...); disable_optimizations, name) -end - -function flatten_lux_chain(layers::Union{AbstractVector, Tuple}) +function wrap_functions_in_chain_call(layers::Union{AbstractVector, Tuple}) new_layers = [] for l in layers - f = flatten_lux_chain(l) + f = wrap_functions_in_chain_call(l) if f isa Tuple || f isa AbstractVector append!(new_layers, f) elseif f isa Function - if !hasmethod(f, (Any, Any, NamedTuple)) - f === identity && continue - push!(new_layers, WrappedFunction{:direct_call}(f)) - else - push!(new_layers, WrappedFunction{:layer}(f)) - end - elseif f isa Chain - append!(new_layers, f.layers) - elseif f isa NoOpLayer - continue - else + push!(new_layers, WrappedFunction(f)) + elseif f isa AbstractLuxLayer push!(new_layers, f) + else + throw("Encountered a non-AbstractLuxLayer in Chain.") end end return layers isa AbstractVector ? 
new_layers : Tuple(new_layers) end -flatten_lux_chain(x) = x +wrap_functions_in_chain_call(x) = x (c::Chain)(x, ps, st::NamedTuple) = applychain(c.layers, x, ps, st) @@ -530,8 +492,6 @@ flatten_lux_chain(x) = x return Expr(:block, calls...) end -Base.keys(c::Chain) = Base.keys(getfield(c, :layers)) - Base.getindex(c::Chain, i::Int) = c.layers[i] Base.getindex(c::Chain, i::AbstractArray) = Chain(Utils.index_namedtuple(c.layers, i)) @@ -546,8 +506,6 @@ Base.length(c::Chain) = length(c.layers) Base.lastindex(c::Chain) = lastindex(c.layers) Base.firstindex(c::Chain) = firstindex(c.layers) -outputsize(c::Chain) = outputsize(c.layers[end]) - """ Maxout(layers...) Maxout(; layers...) @@ -595,7 +553,7 @@ See also [`Parallel`](@ref) to reduce with other operators. [1] Goodfellow, Warde-Farley, Mirza, Courville & Bengio "Maxout Networks" [https://arxiv.org/abs/1302.4389](https://arxiv.org/abs/1302.4389) """ -@concrete struct Maxout <: AbstractExplicitContainerLayer{(:layers,)} +@concrete struct Maxout <: AbstractLuxWrapperLayer{:layers} layers <: NamedTuple end @@ -620,8 +578,6 @@ Maxout(f::Function, n_alts::Int) = Maxout(ntuple(Returns(f()), n_alts)...) return Expr(:block, calls...) end -Base.keys(m::Maxout) = Base.keys(getfield(m, :layers)) - """ RepeatedLayer(model; repeats::Val = Val(10), input_injection::Val = Val(false)) @@ -652,7 +608,7 @@ times for gradients might be unreasonably high. ## Arguments - - `model` must be an `AbstractExplicitLayer` + - `model` must be an `AbstractLuxLayer` ## Keyword Arguments @@ -679,10 +635,10 @@ times for gradients might be unreasonably high. - State of `model` """ -@concrete struct RepeatedLayer <: AbstractExplicitContainerLayer{(:model,)} +@concrete struct RepeatedLayer <: AbstractLuxWrapperLayer{:model} nrepeats <: StaticInt input_injection <: StaticBool - model <: AbstractExplicitLayer + model <: AbstractLuxLayer end function LuxCore.display_name(r::RepeatedLayer) @@ -691,7 +647,7 @@ function LuxCore.display_name(r::RepeatedLayer) end function RepeatedLayer( - model::AbstractExplicitLayer; repeats::Union{StaticInt, Integer, Val}=Val(10), + model::AbstractLuxLayer; repeats::Union{StaticInt, Integer, Val}=Val(10), input_injection::Union{StaticBool, Bool, Val{true}, Val{false}}=Val(false)) return RepeatedLayer(static(repeats), static(input_injection), model) end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index ad40385a04..5a1d8a586f 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -17,14 +17,14 @@ end CRC.@non_differentiable calc_padding(::Any...) function conv_transpose_dims( - x::AbstractArray, weight::AbstractArray; padding, stride, dilation, groups) + x::AbstractArray, weight::AbstractArray; padding, stride, dilation, groups, outpad) # Calculate size of "input", from ∇conv_data()'s perspective... - function calc_dim(xsz, wsz, stride, dilation, pad) - return (xsz - 1) * stride + 1 + (wsz - 1) * dilation - pad + function calc_dim(xsz, wsz, stride, dilation, pad, outpad) + return (xsz - 1) * stride + 1 + (wsz - 1) * dilation - pad + outpad end combined_pad = ntuple(i -> padding[2i - 1] + padding[2i], length(padding) ÷ 2) I = map(calc_dim, size(x)[1:(end - 2)], size(weight)[1:(end - 2)], - stride, dilation, combined_pad) + stride, dilation, combined_pad, outpad) C_in = size(weight)[end - 1] * groups C_out = size(weight)[end] batch_size = size(x)[end] @@ -42,27 +42,35 @@ CRC.@non_differentiable conv_transpose_dims(::Any...) 
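A brief sketch (not part of the patch) of what the new `wrap_functions_in_chain_call` helper above means for users: plain functions inside a `Chain` are wrapped in `WrappedFunction`, while non-function, non-`AbstractLuxLayer` entries are rejected; the old flattening and `NoOpLayer` elision no longer happen.

```julia
using Lux, Random

rng = Random.default_rng()

# The anonymous function becomes a WrappedFunction layer internally
model = Chain(Dense(2 => 4, relu), x -> 2 .* x, Dense(4 => 1))
ps, st = Lux.setup(rng, model)

x = randn(rng, Float32, 2, 8)
y, _ = model(x, ps, st)           # size(y) == (1, 8)
```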
conv_transpose(x, weight, cdims) = LuxLib.Impl.∇conv_data(x, weight, cdims) -function compute_adaptive_pooling_dims(x::AbstractArray, outsize) - insize = size(x)[1:(end - 2)] - stride = insize .÷ outsize - k = insize .- (outsize .- 1) .* stride - return PoolDims(x, k; padding=0, stride=stride) +function init_conv_weight( + rng::AbstractRNG, init_weight::F, filter::NTuple{N, <:IntegerType}, + in_chs::IntegerType, out_chs::IntegerType, groups, σ::A) where {F, N, A} + if init_weight === nothing # Default from PyTorch + return kaiming_uniform(rng, Float32, filter..., in_chs ÷ groups, + out_chs; gain=Utils.calculate_gain(σ, √5.0f0)) + end + return init_weight(rng, filter..., in_chs ÷ groups, out_chs) end -CRC.@non_differentiable compute_adaptive_pooling_dims(::Any, ::Any) - -function init_conv_filter(rng::AbstractRNG, filter::NTuple{N, Integer}, - ch::Pair{<:Integer, <:Integer}; init=glorot_uniform, groups=1) where {N} - cin, cout = ch - @argcheck cin % groups==0 DimensionMismatch("Input channel dimension must be divisible by groups.") - @argcheck cout % groups==0 DimensionMismatch("Output channel dimension must be divisible by groups.") - return init(rng, filter..., cin ÷ groups, cout) +function init_conv_bias(rng::AbstractRNG, init_bias::F, filter::NTuple{N, <:IntegerType}, + in_chs::IntegerType, out_chs::IntegerType, groups) where {F, N} + if init_bias === nothing # Default from PyTorch + fan_in = prod(filter) * (in_chs ÷ groups) + bound = inv(sqrt(fan_in)) + y = rand32(rng, out_chs) + @. y = (y - 0.5f0) * 2 * bound + return y + end + return init_bias(rng, out_chs) end +construct_crosscor_convdims(::False, cdims::DenseConvDims) = cdims +construct_crosscor_convdims(::True, cdims::DenseConvDims) = DenseConvDims(cdims; F=true) + @doc doc""" Conv(k::NTuple{N,Integer}, (in_chs => out_chs)::Pair{<:Integer,<:Integer}, - activation=identity; init_weight=glorot_uniform, init_bias=zeros32, stride=1, - pad=0, dilation=1, groups=1, use_bias=True(), allow_fast_activation=True()) + activation=identity; init_weight=nothing, init_bias=nothing, stride=1, + pad=0, dilation=1, groups=1, use_bias=True(), cross_correlation=False()) Standard convolutional layer. @@ -79,7 +87,8 @@ Standard convolutional layer. !!! warning Frameworks like [`Pytorch`](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d) - perform cross-correlation in their convolution layers + perform cross-correlation in their convolution layers. Pass `cross_correlation=true` to + use cross-correlation instead. ## Arguments @@ -93,8 +102,13 @@ Standard convolutional layer. ## Keyword Arguments - - `init_weight`: Controls the initialization of the weight parameter - - `init_bias`: Controls the initialization of the bias parameter + - `init_weight`: Controls the initialization of the weight parameter. If `nothing`, then + we use [`kaiming_uniform`](@ref) with gain computed on the basis of the activation + function (taken from Pytorch + [`nn.init.calculate_gain`](https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.calculate_gain)). + - `init_bias`: Controls the initialization of the bias parameter. If `nothing`, then we + use uniform distribution with bounds `-bound` and `bound` where + `bound = inv(sqrt(fan_in))`. - `stride`: Should each be either single integer, or a tuple with `N` integers - `dilation`: Should each be either single integer, or a tuple with `N` integers @@ -115,9 +129,9 @@ Standard convolutional layer. convolution into (set `groups = in_chs` for Depthwise Convolutions). 
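To make the new defaults of `init_conv_weight`/`init_conv_bias` above concrete, here is a small sketch (not part of the patch): leaving `init_weight`/`init_bias` as `nothing` picks the PyTorch-style initialization, while passing explicit initializers keeps the previous behaviour. The bias-bound check reflects the `inv(sqrt(fan_in))` rule stated above.

```julia
using Lux, Random

rng = Random.default_rng()

conv_default = Conv((3, 3), 4 => 8, relu)
conv_custom  = Conv((3, 3), 4 => 8, relu; init_weight=glorot_uniform, init_bias=zeros32)

ps, _ = Lux.setup(rng, conv_default)
size(ps.weight)                          # (3, 3, 4, 8)
fan_in = 3 * 3 * 4
all(abs.(ps.bias) .<= 1 / sqrt(fan_in))  # bias drawn from U(-bound, bound)
```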
`in_chs` and `out_chs` must be divisible by `groups`. - `use_bias`: Trainable bias can be disabled entirely by setting this to `false`. - - `allow_fast_activation`: If `true`, then certain activations can be approximated with - a faster version. The new activation function will be given by - `NNlib.fast_act(activation)` + - `cross_correlation`: If `true`, perform cross-correlation instead of convolution. Prior + to `v1`, Lux used to have a `CrossCor` layer which performed cross-correlation. This + was removed in `v1` in favor of `Conv` with `cross_correlation=true`. ## Inputs @@ -139,7 +153,7 @@ O_i = \left\lfloor\frac{I_i + p_i + p_{(i + N) \% |p|} - d_i \times (k_i - 1)}{s - `weight`: Convolution kernel - `bias`: Bias (present if `use_bias=true`) """ -@concrete struct Conv <: AbstractExplicitLayer +@concrete struct Conv <: AbstractLuxLayer activation in_chs <: IntegerType out_chs <: IntegerType @@ -151,28 +165,30 @@ O_i = \left\lfloor\frac{I_i + p_i + p_{(i + N) \% |p|} - d_i \times (k_i - 1)}{s init_weight init_bias use_bias <: StaticBool + cross_correlation <: StaticBool end function Conv(k::Tuple{Vararg{IntegerType}}, ch::Pair{<:IntegerType, <:IntegerType}, - activation=identity; init_weight=glorot_uniform, - init_bias=zeros32, stride=1, pad=0, dilation=1, groups=1, - use_bias::BoolType=True(), allow_fast_activation::BoolType=True()) + activation=identity; init_weight=nothing, + init_bias=nothing, stride=1, pad=0, dilation=1, groups=1, + use_bias::BoolType=True(), cross_correlation::BoolType=False()) stride = Utils.expand(Val(length(k)), stride) dilation = Utils.expand(Val(length(k)), dilation) pad = calc_padding(pad, k, dilation, stride) + + @argcheck ch[1] % groups==0 DimensionMismatch("Input channel dimension must be divisible by groups.") + @argcheck ch[2] % groups==0 DimensionMismatch("Output channel dimension must be divisible by groups.") @argcheck allequal(length, (stride, dilation, k)) - activation = dynamic(allow_fast_activation) ? 
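A quick shape check (not part of the patch) for the `Conv` input/output description above; the commented size is what the standard convolution arithmetic predicts for these settings, not output captured from running this branch.

```julia
using Lux, Random

rng = Random.default_rng()
layer = Conv((3, 3), 3 => 8; stride=2, pad=1)
ps, st = Lux.setup(rng, layer)

x = randn(rng, Float32, 32, 32, 3, 4)    # (W, H, C_in, N), WHCN layout
y, _ = layer(x, ps, st)
size(y)                                  # expected (16, 16, 8, 4)
```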
NNlib.fast_act(activation) : activation - return Conv(activation, first(ch), last(ch), k, stride, pad, dilation, - groups, init_weight, init_bias, static(use_bias)) + return Conv(activation, first(ch), last(ch), k, stride, pad, dilation, groups, + init_weight, init_bias, static(use_bias), static(cross_correlation)) end function initialparameters(rng::AbstractRNG, c::Conv) - weight = init_conv_filter( - rng, c.kernel_size, c.in_chs => c.out_chs; init=c.init_weight, c.groups) + args = (c.kernel_size, c.in_chs, c.out_chs, c.groups) + weight = init_conv_weight(rng, c.init_weight, args..., c.activation) has_bias(c) || return (; weight) - return (; weight, - bias=c.init_bias(rng, ntuple(_ -> 1, length(c.kernel_size))..., c.out_chs, 1)) # TODO: flatten in v1 + return (; weight, bias=init_conv_bias(rng, c.init_bias, args...)) end function parameterlength(c::Conv) @@ -181,9 +197,11 @@ end function (c::Conv)(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(c, ps, st, x) - cdims = DenseConvDims(y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups) - bias = safe_vec(safe_getproperty(ps, Val(:bias))) - return fused_conv_bias_activation(c.activation, ps.weight, y, bias, cdims), st + cdims = construct_crosscor_convdims(c.cross_correlation, + DenseConvDims(y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups)) + bias = safe_getproperty(ps, Val(:bias)) + σ = NNlib.fast_act(c.activation, y) + return fused_conv_bias_activation(σ, ps.weight, y, bias, cdims), st end function Base.show(io::IO, l::Conv) @@ -196,24 +214,38 @@ function Base.show(io::IO, l::Conv) print(io, ", dilation=", PrettyPrinting.tuple_string(l.dilation)) (l.groups == 1) || print(io, ", groups=", l.groups) has_bias(l) || print(io, ", use_bias=false") + known(l.cross_correlation) && print(io, ", cross_correlation=true") print(io, ")") end @doc doc""" - MaxPool(window::NTuple; pad=0, stride=window) + ConvTranspose(k::NTuple{N,Integer}, (in_chs => out_chs)::Pair{<:Integer,<:Integer}, + activation=identity; init_weight=glorot_uniform, init_bias=zeros32, + stride=1, pad=0, outpad=0, dilation=1, groups=1, use_bias=True(), + cross_correlation=False()) -Max pooling layer, which replaces all pixels in a block of size `window` with the maximum -value. +Standard convolutional transpose layer. -# Arguments +## Arguments - - `window`: Tuple of integers specifying the size of the window. Eg, for 2D pooling - `length(window) == 2` + - `k`: Tuple of integers specifying the size of the convolutional kernel. Eg, for 2D + convolutions `length(k) == 2` + - `in_chs`: Number of input channels + - `out_chs`: Number of input and output channels + - `activation`: Activation Function ## Keyword Arguments - - `stride`: Should each be either single integer, or a tuple with `N` integers + - `init_weight`: Controls the initialization of the weight parameter. If `nothing`, then + we use [`kaiming_uniform`](@ref) with gain computed on the basis of the activation + function (taken from Pytorch + [`nn.init.calculate_gain`](https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.calculate_gain)). + - `init_bias`: Controls the initialization of the bias parameter. If `nothing`, then we + use uniform distribution with bounds `-bound` and `bound` where + `bound = inv(sqrt(fan_in))`. + - `stride`: Should each be either single integer, or a tuple with `N` integers + - `dilation`: Should each be either single integer, or a tuple with `N` integers - `pad`: Specifies the number of elements added to the borders of the data array. 
It can be @@ -222,125 +254,117 @@ value. dimension, + a tuple of `2*N` integers, for asymmetric padding, or + the singleton `SamePad()`, to calculate padding such that - `size(output,d) == size(x,d) / stride` (possibly rounded) for each spatial + `size(output,d) == size(x,d) * stride` (possibly rounded) for each spatial dimension. + - `groups`: Expected to be an `Int`. It specifies the number of groups to divide a + convolution into (set `groups = in_chs` for Depthwise Convolutions). `in_chs` + and `out_chs` must be divisible by `groups`. + - `use_bias`: Trainable bias can be disabled entirely by setting this to `false`. + - `cross_correlation`: If `true`, perform transposed cross-correlation instead of + transposed convolution. + - `outpad`: To converse [`Conv`](@ref) inversability when `stride > 1`, `outpad` can be + used to increase the size of the output in the desired dimensions. Whereas `pad` is used + to zero-pad the input, `outpad` only affects the output shape. + # Extended Help ## Inputs - - `x`: Data satisfying `ndims(x) == N + 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)` + - `x`: Data satisfying `ndims(x) == N + 2 && size(x, N - 1) == in_chs`, i.e. + `size(x) = (I_N, ..., I_1, C_in, N)` ## Returns - - Output of the pooling `y` of size `(O_N, ..., O_1, C, N)` where - -```math - O_i = \left\lfloor\frac{I_i + p_i + p_{(i + N) \% |p|} - d_i \times (k_i - 1)}{s_i} + 1\right\rfloor -``` - + - Output of the convolution transpose `y` of size `(O_N, ..., O_1, C_out, N)` where - Empty `NamedTuple()` -See also [`Conv`](@ref), [`MeanPool`](@ref), [`GlobalMaxPool`](@ref), -[`AdaptiveMaxPool`](@ref) +## Parameters + + - `weight`: Convolution Transpose kernel + - `bias`: Bias (present if `use_bias=true`) """ -@concrete struct MaxPool <: AbstractExplicitLayer - k <: Tuple{Vararg{IntegerType}} - pad <: Tuple{Vararg{IntegerType}} +@concrete struct ConvTranspose <: AbstractLuxLayer + activation + in_chs <: IntegerType + out_chs <: IntegerType + kernel_size <: Tuple{Vararg{IntegerType}} stride <: Tuple{Vararg{IntegerType}} + pad <: Tuple{Vararg{IntegerType}} + outpad <: Tuple{Vararg{IntegerType}} + dilation <: Tuple{Vararg{IntegerType}} + groups <: IntegerType + init_weight + init_bias + use_bias <: StaticBool + cross_correlation <: StaticBool end -function MaxPool(k::Tuple{Vararg{IntegerType}}; pad=0, stride=k) +function ConvTranspose( + k::Tuple{Vararg{IntegerType}}, ch::Pair{<:IntegerType, <:IntegerType}, + activation=identity; init_weight=glorot_uniform, + init_bias=zeros32, stride=1, pad=0, outpad=0, dilation=1, groups=1, + use_bias::BoolType=True(), cross_correlation::BoolType=False()) stride = Utils.expand(Val(length(k)), stride) - pad = calc_padding(pad, k, 1, stride) - @argcheck allequal(length, (stride, k)) - - return MaxPool(k, pad, stride) -end + dilation = Utils.expand(Val(length(k)), dilation) + pad = if pad isa SamePad + calc_padding(pad, k .- stride .+ 1, dilation, stride) + else + calc_padding(pad, k, dilation, stride) + end + outpad = Utils.expand(Val(length(k)), outpad) -function (m::MaxPool)(x, _, st::NamedTuple) - return maxpool(x, PoolDims(x, m.k; padding=m.pad, m.stride)), st -end + @argcheck ch[2] % groups==0 DimensionMismatch("Input channel dimension must be divisible by groups.") + @argcheck ch[1] % groups==0 DimensionMismatch("Output channel dimension must be divisible by groups.") + @argcheck allequal(length, (stride, dilation, k)) -function Base.show(io::IO, m::MaxPool) - print(io, "MaxPool(", m.k) - all(==(0), m.pad) || print(io, ", pad=", 
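To illustrate the new `outpad` keyword documented above, here is a sketch (not part of the patch) of recovering an input shape after a strided `Conv`; the commented sizes follow from the `calc_dim` rule in `conv_transpose_dims` shown earlier in this diff.

```julia
using Lux, Random

rng = Random.default_rng()
x = randn(rng, Float32, 8, 8, 3, 1)

down = Conv((3, 3), 3 => 8; stride=2, pad=1)
ps, st = Lux.setup(rng, down)
h, _ = down(x, ps, st)                    # size(h) == (4, 4, 8, 1)

up = ConvTranspose((3, 3), 8 => 3; stride=2, pad=1)
ps, st = Lux.setup(rng, up)
size(first(up(h, ps, st)))                # (7, 7, 3, 1): one short of the original

up_out = ConvTranspose((3, 3), 8 => 3; stride=2, pad=1, outpad=1)
ps, st = Lux.setup(rng, up_out)
size(first(up_out(h, ps, st)))            # (8, 8, 3, 1): matches `x` again
```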
PrettyPrinting.tuple_string(m.pad)) - m.stride == m.k || print(io, ", stride=", PrettyPrinting.tuple_string(m.stride)) - return print(io, ")") + return ConvTranspose( + activation, first(ch), last(ch), k, stride, pad, outpad, dilation, groups, + init_weight, init_bias, static(use_bias), static(cross_correlation)) end -@doc doc""" - MeanPool(window::NTuple; pad=0, stride=window) - -Mean pooling layer, which replaces all pixels in a block of size `window` with the mean -value. - -# Arguments - - - `window`: Tuple of integers specifying the size of the window. Eg, for 2D pooling - `length(window) == 2` - -## Keyword Arguments - - - `stride`: Should each be either single integer, or a tuple with `N` integers - - - `pad`: Specifies the number of elements added to the borders of the data array. It can - be - - + a single integer for equal padding all around, - + a tuple of `N` integers, to apply the same padding at begin/end of each spatial - dimension, - + a tuple of `2*N` integers, for asymmetric padding, or - + the singleton `SamePad()`, to calculate padding such that - `size(output,d) == size(x,d) / stride` (possibly rounded) for each spatial - dimension. - -# Extended Help - -## Inputs - - - `x`: Data satisfying `ndims(x) == N + 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)` - -## Returns - - - Output of the pooling `y` of size `(O_N, ..., O_1, C, N)` where - -```math - O_i = \left\lfloor\frac{I_i + p_i + p_{(i + N) \% |p|} - d_i \times (k_i - 1)}{s_i} + 1\right\rfloor -``` - - - Empty `NamedTuple()` - -See also [`Conv`](@ref), [`MaxPool`](@ref), [`GlobalMeanPool`](@ref), -[`AdaptiveMeanPool`](@ref) -""" -@concrete struct MeanPool <: AbstractExplicitLayer - k <: Tuple{Vararg{IntegerType}} - pad <: Tuple{Vararg{IntegerType}} - stride <: Tuple{Vararg{IntegerType}} +function initialparameters(rng::AbstractRNG, c::ConvTranspose) + weight = init_conv_weight( + rng, c.init_weight, c.kernel_size, c.out_chs, c.in_chs, c.groups, c.activation) + has_bias(c) || return (; weight) + # NOTE: The c.out_chs, c.out_chs is intentional, since it only affects the size of the + # bias vector + return (; weight, + bias=init_conv_bias( + rng, c.init_bias, c.kernel_size, c.out_chs, c.out_chs, c.groups)) end -function MeanPool(k::Tuple{Vararg{IntegerType}}; pad=0, stride=k) - stride = Utils.expand(Val(length(k)), stride) - pad = calc_padding(pad, k, 1, stride) - @argcheck allequal(length, (stride, k)) - - return MeanPool(k, pad, stride) +function parameterlength(c::ConvTranspose) + return prod(c.kernel_size) * c.in_chs * c.out_chs ÷ c.groups + has_bias(c) * c.out_chs end -function (m::MeanPool)(x, _, st::NamedTuple) - return meanpool(x, PoolDims(x, m.k; padding=m.pad, m.stride)), st +function (c::ConvTranspose)(x::AbstractArray, ps, st::NamedTuple) + y = match_eltype(c, ps, st, x) + cdims = construct_crosscor_convdims(c.cross_correlation, + conv_transpose_dims( + y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups, c.outpad)) + bias = safe_getproperty(ps, Val(:bias)) + σ = NNlib.fast_act(c.activation, y) + return bias_activation!!(σ, conv_transpose(y, ps.weight, cdims), bias), st end -function Base.show(io::IO, m::MeanPool) - print(io, "MeanPool(", m.k) - all(==(0), m.pad) || print(io, ", pad=", PrettyPrinting.tuple_string(m.pad)) - m.stride == m.k || print(io, ", stride=", PrettyPrinting.tuple_string(m.stride)) - return print(io, ")") +function Base.show(io::IO, l::ConvTranspose) + print(io, "ConvTranspose(", l.kernel_size) + print(io, ", ", l.in_chs, " => ", l.out_chs) + l.activation == identity || print(io, 
", ", l.activation) + all(==(0), l.pad) || print(io, ", pad=", PrettyPrinting.tuple_string(l.pad)) + all(==(1), l.stride) || print(io, ", stride=", PrettyPrinting.tuple_string(l.stride)) + all(==(1), l.dilation) || + print(io, ", dilation=", PrettyPrinting.tuple_string(l.dilation)) + (l.groups == 1) || print(io, ", groups=", l.groups) + all(==(0), l.outpad) || print(io, ", outpad=", PrettyPrinting.tuple_string(l.outpad)) + has_bias(l) || print(io, ", use_bias=false") + known(l.cross_correlation) && print(io, ", cross_correlation=true") + print(io, ")") end """ - Upsample(mode = :nearest; [scale, size]) + Upsample(mode = :nearest; [scale, size, align_corners=false]) Upsample(scale, mode = :nearest) Upsampling Layer. @@ -372,6 +396,12 @@ Currently supported upsampling `mode`s and corresponding NNlib's methods are: # Extended Help +## Other Keyword Arguments + + - `align_corners`: If `true`, the corner pixels of the input and output tensors are + aligned, and thus preserving the values at those pixels. This only has effect when mode + is one of `:bilinear` or `:trilinear`. + ## Inputs - `x`: For the input dimensions look into the documentation for the corresponding `NNlib` @@ -386,42 +416,53 @@ Currently supported upsampling `mode`s and corresponding NNlib's methods are: - Upsampled Input of size `size` or of size `(I_1 x scale[1], ..., I_N x scale[N], C, N)` - Empty `NamedTuple()` """ -@concrete struct Upsample <: AbstractExplicitLayer +@concrete struct Upsample <: AbstractLuxLayer scale size upsample_mode <: StaticSymbol + align_corners <: Bool end -function Upsample(mode::SymbolType=static(:nearest); scale=nothing, size=nothing) +function Upsample(mode::SymbolType=static(:nearest); scale=nothing, + size=nothing, align_corners::Bool=false) @argcheck dynamic(mode) in (:nearest, :bilinear, :trilinear) + if !xor(isnothing(scale), isnothing(size)) throw(ArgumentError("Either scale or size should be specified (but not both).")) end - return Upsample(scale, size, static(mode)) + return Upsample(scale, size, static(mode), align_corners) end Upsample(scale, mode::SymbolType=static(:nearest)) = Upsample(mode; scale) function (m::Upsample)(x::AbstractArray, _, st::NamedTuple) - return lux_upsample_scale_dispatch(m.upsample_mode, x, m.scale), st + return lux_upsample_scale_dispatch(m.upsample_mode, x, m.scale, m.align_corners), st end function (m::Upsample{Nothing})(x::AbstractArray, _, st::NamedTuple) - return lux_upsample_size_dispatch(m.upsample_mode, x, m.size), st + return lux_upsample_size_dispatch(m.upsample_mode, x, m.size, m.align_corners), st end -for interp in (:nearest, :bilinear, :trilinear) +for interp in (:bilinear, :trilinear) nnlib_interp_func = Symbol(:upsample_, interp) @eval begin - function lux_upsample_scale_dispatch(::StaticSymbol{$(Meta.quot(interp))}, x, scale) + function lux_upsample_scale_dispatch( + ::StaticSymbol{$(Meta.quot(interp))}, x, scale, align_corners) return $(nnlib_interp_func)(x, scale) end - function lux_upsample_size_dispatch(::StaticSymbol{$(Meta.quot(interp))}, x, size) + function lux_upsample_size_dispatch( + ::StaticSymbol{$(Meta.quot(interp))}, x, size, align_corners) return $(nnlib_interp_func)(x; size) end end end -function lux_upsample_scale_dispatch(::StaticSymbol{:nearest}, x, scale::Integer) +function lux_upsample_size_dispatch(::StaticSymbol{:nearest}, x, size, _) + return NNlib.upsample_nearest(x; size) +end +function lux_upsample_scale_dispatch(::StaticSymbol{:nearest}, x, scale, _) + return NNlib.upsample_nearest(x, scale) +end +function 
lux_upsample_scale_dispatch(::StaticSymbol{:nearest}, x, scale::Integer, _) return NNlib.upsample_nearest(x, ntuple(i -> scale, ndims(x) - 2)) end @@ -429,121 +470,10 @@ function Base.show(io::IO, u::Upsample) print(io, "Upsample(", u.upsample_mode) u.scale !== nothing && print(io, ", scale = $(u.scale)") u.size !== nothing && print(io, ", size = $(u.size)") + u.align_corners && print(io, ", align_corners = $(u.align_corners)") print(io, ")") end -""" - GlobalMaxPool() - -Global Max Pooling layer. Transforms (w,h,c,b)-shaped input into (1,1,c,b)-shaped output, -by performing max pooling on the complete (w,h)-shaped feature maps. - -## Inputs - - - `x`: Data satisfying `ndims(x) > 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)` - -## Returns - - - Output of the pooling `y` of size `(1, ..., 1, C, N)` - - Empty `NamedTuple()` - -See also [`MaxPool`](@ref), [`AdaptiveMaxPool`](@ref), [`GlobalMeanPool`](@ref) -""" -struct GlobalMaxPool <: AbstractExplicitLayer end - -function (g::GlobalMaxPool)(x, _, st::NamedTuple) - return maxpool(x, PoolDims(x, size(x)[1:(end - 2)])), st -end - -""" - GlobalMeanPool() - -Global Mean Pooling layer. Transforms (w,h,c,b)-shaped input into (1,1,c,b)-shaped output, -by performing mean pooling on the complete (w,h)-shaped feature maps. - -## Inputs - - - `x`: Data satisfying `ndims(x) > 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)` - -## Returns - - - Output of the pooling `y` of size `(1, ..., 1, C, N)` - - Empty `NamedTuple()` - -See also [`MeanPool`](@ref), [`AdaptiveMeanPool`](@ref), [`GlobalMaxPool`](@ref) -""" -struct GlobalMeanPool <: AbstractExplicitLayer end - -function (g::GlobalMeanPool)(x, _, st::NamedTuple) - return meanpool(x, PoolDims(x, size(x)[1:(end - 2)])), st -end - -""" - AdaptiveMaxPool(out::NTuple) - -Adaptive Max Pooling layer. Calculates the necessary window size such that its output has -`size(y)[1:N] == out`. - -## Arguments - - - `out`: Size of the first `N` dimensions for the output - -## Inputs - - - `x`: Expects as input an array with `ndims(x) == N+2`, i.e. channel and batch - dimensions, after the `N` feature dimensions, where `N = length(out)`. - -## Returns - - - Output of size `(out..., C, N)` - - Empty `NamedTuple()` - -See also [`MaxPool`](@ref), [`AdaptiveMeanPool`](@ref). -""" -struct AdaptiveMaxPool{S, O <: Tuple{Vararg{IntegerType}}} <: AbstractExplicitLayer - out::O - AdaptiveMaxPool(out) = new{length(out) + 2, typeof(out)}(out) -end - -function (a::AdaptiveMaxPool{S})(x::AbstractArray{T, S}, _, st::NamedTuple) where {S, T} - return maxpool(x, compute_adaptive_pooling_dims(x, a.out)), st -end - -Base.show(io::IO, a::AdaptiveMaxPool) = print(io, "AdaptiveMaxPool(", a.out, ")") - -""" - AdaptiveMeanPool(out::NTuple) - -Adaptive Mean Pooling layer. Calculates the necessary window size such that its output has -`size(y)[1:N] == out`. - -## Arguments - - - `out`: Size of the first `N` dimensions for the output - -## Inputs - - - `x`: Expects as input an array with `ndims(x) == N+2`, i.e. channel and batch - dimensions, after the `N` feature dimensions, where `N = length(out)`. - -## Returns - - - Output of size `(out..., C, N)` - - Empty `NamedTuple()` - -See also [`MeanPool`](@ref), [`AdaptiveMaxPool`](@ref). 
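Picking up the `Upsample` changes a little above, a short usage sketch (not part of the patch) of the `scale` and `size` forms; the commented shapes assume the usual NNlib upsampling behaviour for these modes.

```julia
using Lux, Random

rng = Random.default_rng()
x = randn(rng, Float32, 4, 4, 3, 1)

up_scale = Upsample(2)                       # nearest-neighbour with an integer scale
ps, st = Lux.setup(rng, up_scale)
size(first(up_scale(x, ps, st)))             # (8, 8, 3, 1)

up_size = Upsample(:bilinear; size=(10, 10)) # target an explicit output size instead
ps, st = Lux.setup(rng, up_size)
size(first(up_size(x, ps, st)))              # (10, 10, 3, 1)
```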
-""" -struct AdaptiveMeanPool{S, O <: Tuple{Vararg{IntegerType}}} <: AbstractExplicitLayer - out::O - AdaptiveMeanPool(out) = new{length(out) + 2, typeof(out)}(out) -end - -function (a::AdaptiveMeanPool{S})(x::AbstractArray{T, S}, _, st::NamedTuple) where {S, T} - return meanpool(x, compute_adaptive_pooling_dims(x, a.out)), st -end - -Base.show(io::IO, a::AdaptiveMeanPool) = print(io, "AdaptiveMeanPool(", a.out, ")") - """ PixelShuffle(r::Int) @@ -571,260 +501,10 @@ function set to `Base.Fix2(pixel_shuffle, r)` - Output of size `(r x W, r x H, C, N)` for 4D-arrays, and `(r x W, r x H, ..., C, N)` for D-dimensional data, where `D = ndims(x) - 2` """ -PixelShuffle(r::IntegerType) = WrappedFunction{:direct_call}(Base.Fix2(pixel_shuffle, r)) - -@doc doc""" - CrossCor(k::NTuple{N,Integer}, (in_chs => out_chs)::Pair{<:Integer,<:Integer}, - activation=identity; init_weight=glorot_uniform, init_bias=zeros32, stride=1, - pad=0, dilation=1, groups=1, use_bias=True(), allow_fast_activation=True()) - -Cross Correlation layer. - -Image data should be stored in WHCN order (width, height, channels, batch). In other words, -a `100 x 100` RGB image would be a `100 x 100 x 3 x 1` array, and a batch of 50 would be a -`100 x 100 x 3 x 50` array. This has `N = 2` spatial dimensions, and needs a kernel size -like `(5, 5)`, a 2-tuple of integers. To take convolutions along `N` feature dimensions, -this layer expects as input an array with `ndims(x) == N + 2`, where -`size(x, N + 1) == in_chs` is the number of input channels, and `size(x, ndims(x))` is the -number of observations in a batch. - -## Arguments - - - `k`: Tuple of integers specifying the size of the convolutional kernel. Eg, for 2D - convolutions `length(k) == 2` - - `in_chs`: Number of input channels - - `out_chs`: Number of input and output channels - - `activation`: Activation Function - -## Keyword Arguments - - - `init_weight`: Controls the initialization of the weight parameter - - `init_bias`: Controls the initialization of the bias parameter - - - `stride`: Should each be either single integer, or a tuple with `N` integers - - `dilation`: Should each be either single integer, or a tuple with `N` integers - - `groups`: Expected to be an `Int`. It specifies the number of groups to divide a - convolution into (set `groups = in_chs` for Depthwise Convolutions). `in_chs` - and `out_chs` must be divisible by `groups`. - - `pad`: Specifies the number of elements added to the borders of the data array. It can - be - - + a single integer for equal padding all around, - + a tuple of `N` integers, to apply the same padding at begin/end of each spatial - dimension, - + a tuple of `2*N` integers, for asymmetric padding, or - + the singleton `SamePad()`, to calculate padding such that - `size(output,d) == size(x,d) / stride` (possibly rounded) for each spatial - dimension. - - - `use_bias`: Trainable bias can be disabled entirely by setting this to `false`. - - `allow_fast_activation`: If `true`, then certain activations can be approximated with - a faster version. The new activation function will be given by - `NNlib.fast_act(activation)` - -# Extended Help - -## Inputs - - - `x`: Data satisfying `ndims(x) == N + 2 && size(x, N - 1) == in_chs`, i.e. 
- `size(x) = (I_N, ..., I_1, C_in, N)` - -## Returns - - - Output of the convolution `y` of size `(O_N, ..., O_1, C_out, N)` where - -```math -O_i = \left\lfloor\frac{I_i + p_i + p_{(i + N) \% |p|} - d_i \times (k_i - 1)}{s_i} + 1\right\rfloor -``` - - - Empty `NamedTuple()` - -## Parameters - - - `weight`: Convolution kernel - - `bias`: Bias (present if `use_bias=true`) -""" -@concrete struct CrossCor <: AbstractExplicitLayer - activation - in_chs <: IntegerType - out_chs <: IntegerType - kernel_size <: Tuple{Vararg{IntegerType}} - stride <: Tuple{Vararg{IntegerType}} - pad <: Tuple{Vararg{IntegerType}} - dilation <: Tuple{Vararg{IntegerType}} - groups <: IntegerType - init_weight - init_bias - use_bias <: StaticBool -end - -function CrossCor(k::Tuple{Vararg{IntegerType}}, ch::Pair{<:IntegerType, <:IntegerType}, - activation=identity; init_weight=glorot_uniform, - init_bias=zeros32, stride=1, pad=0, dilation=1, groups=1, - use_bias::BoolType=True(), allow_fast_activation::BoolType=True()) - stride = Utils.expand(Val(length(k)), stride) - dilation = Utils.expand(Val(length(k)), dilation) - pad = calc_padding(pad, k, dilation, stride) - @argcheck allequal(length, (stride, dilation, k)) - - activation = dynamic(allow_fast_activation) ? NNlib.fast_act(activation) : activation - return CrossCor(activation, first(ch), last(ch), k, stride, pad, dilation, - groups, init_weight, init_bias, static(use_bias)) -end - -function initialparameters(rng::AbstractRNG, c::CrossCor) - weight = init_conv_filter( - rng, c.kernel_size, c.in_chs => c.out_chs; init=c.init_weight, c.groups) - has_bias(c) || return (; weight) - return (; weight, - bias=c.init_bias(rng, ntuple(_ -> 1, length(c.kernel_size))..., c.out_chs, 1)) # TODO: flatten in v1 -end - -function parameterlength(c::CrossCor) - return prod(c.kernel_size) * c.in_chs * c.out_chs ÷ c.groups + has_bias(c) * c.out_chs -end - -function (c::CrossCor)(x::AbstractArray, ps, st::NamedTuple) - y = match_eltype(c, ps, st, x) - cdims = DenseConvDims( - DenseConvDims(y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups); F=true) - bias = safe_vec(safe_getproperty(ps, Val(:bias))) - return fused_conv_bias_activation(c.activation, ps.weight, y, bias, cdims), st -end - -function Base.show(io::IO, l::CrossCor) - print(io, "CrossCor(", l.kernel_size) - print(io, ", ", l.in_chs, " => ", l.out_chs) - l.activation == identity || print(io, ", ", l.activation) - all(==(0), l.pad) || print(io, ", pad=", PrettyPrinting.tuple_string(l.pad)) - all(==(1), l.stride) || print(io, ", stride=", PrettyPrinting.tuple_string(l.stride)) - all(==(1), l.dilation) || - print(io, ", dilation=", PrettyPrinting.tuple_string(l.dilation)) - (l.groups == 1) || print(io, ", groups=", l.groups) - has_bias(l) || print(io, ", use_bias=false") - return print(io, ")") -end - -@doc doc""" - ConvTranspose(k::NTuple{N,Integer}, (in_chs => out_chs)::Pair{<:Integer,<:Integer}, - activation=identity; init_weight=glorot_uniform, init_bias=zeros32, - stride=1, pad=0, dilation=1, groups=1, use_bias=True(), - allow_fast_activation=True()) - -Standard convolutional transpose layer. - -## Arguments - - - `k`: Tuple of integers specifying the size of the convolutional kernel. 
Eg, for 2D - convolutions `length(k) == 2` - - `in_chs`: Number of input channels - - `out_chs`: Number of input and output channels - - `activation`: Activation Function - -## Keyword Arguments - - - `init_weight`: Controls the initialization of the weight parameter - - `init_bias`: Controls the initialization of the bias parameter - - - `stride`: Should each be either single integer, or a tuple with `N` integers - - `dilation`: Should each be either single integer, or a tuple with `N` integers - - `pad`: Specifies the number of elements added to the borders of the data array. It can - be - - + a single integer for equal padding all around, - + a tuple of `N` integers, to apply the same padding at begin/end of each spatial - dimension, - + a tuple of `2*N` integers, for asymmetric padding, or - + the singleton `SamePad()`, to calculate padding such that - `size(output,d) == size(x,d) * stride` (possibly rounded) for each spatial - dimension. - - - `groups`: Expected to be an `Int`. It specifies the number of groups to divide a - convolution into (set `groups = in_chs` for Depthwise Convolutions). `in_chs` - and `out_chs` must be divisible by `groups`. - - `use_bias`: Trainable bias can be disabled entirely by setting this to `false`. - - `allow_fast_activation`: If `true`, then certain activations can be approximated with - a faster version. The new activation function will be given by - `NNlib.fast_act(activation)` - -# Extended Help - -## Inputs - - - `x`: Data satisfying `ndims(x) == N + 2 && size(x, N - 1) == in_chs`, i.e. - `size(x) = (I_N, ..., I_1, C_in, N)` - -## Returns - - - Output of the convolution transpose `y` of size `(O_N, ..., O_1, C_out, N)` where - - Empty `NamedTuple()` - -## Parameters - - - `weight`: Convolution Transpose kernel - - `bias`: Bias (present if `use_bias=true`) -""" -@concrete struct ConvTranspose <: AbstractExplicitLayer - activation - in_chs <: IntegerType - out_chs <: IntegerType - kernel_size <: Tuple{Vararg{IntegerType}} - stride <: Tuple{Vararg{IntegerType}} - pad <: Tuple{Vararg{IntegerType}} - dilation <: Tuple{Vararg{IntegerType}} - groups <: IntegerType - init_weight - init_bias - use_bias <: StaticBool -end - -function ConvTranspose( - k::Tuple{Vararg{IntegerType}}, ch::Pair{<:IntegerType, <:IntegerType}, - activation=identity; init_weight=glorot_uniform, - init_bias=zeros32, stride=1, pad=0, dilation=1, groups=1, - use_bias::BoolType=True(), allow_fast_activation::BoolType=True()) - stride = Utils.expand(Val(length(k)), stride) - dilation = Utils.expand(Val(length(k)), dilation) - pad = if pad isa SamePad - calc_padding(pad, k .- stride .+ 1, dilation, stride) - else - calc_padding(pad, k, dilation, stride) - end - @argcheck allequal(length, (stride, dilation, k)) - - activation = dynamic(allow_fast_activation) ? 
NNlib.fast_act(activation) : activation - return ConvTranspose(activation, first(ch), last(ch), k, stride, pad, dilation, - groups, init_weight, init_bias, static(use_bias)) -end - -function initialparameters(rng::AbstractRNG, c::ConvTranspose) - weight = init_conv_filter( - rng, c.kernel_size, c.out_chs => c.in_chs; init=c.init_weight, c.groups) - has_bias(c) || return (; weight) - return (; weight, - bias=c.init_bias(rng, ntuple(_ -> 1, length(c.kernel_size))..., c.out_chs, 1)) # TODO: flatten in v1 -end - -function parameterlength(c::ConvTranspose) - return prod(c.kernel_size) * c.in_chs * c.out_chs ÷ c.groups + has_bias(c) * c.out_chs -end - -function (c::ConvTranspose)(x::AbstractArray, ps, st::NamedTuple) - y = match_eltype(c, ps, st, x) - cdims = conv_transpose_dims(y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups) - bias = safe_vec(safe_getproperty(ps, Val(:bias))) - return bias_activation!!(c.activation, conv_transpose(y, ps.weight, cdims), bias), st +@concrete struct PixelShuffle <: AbstractLuxWrapperLayer{:layer} + layer <: AbstractLuxLayer end -function Base.show(io::IO, l::ConvTranspose) - print(io, "ConvTranspose(", l.kernel_size) - print(io, ", ", l.in_chs, " => ", l.out_chs) - l.activation == identity || print(io, ", ", l.activation) - all(==(0), l.pad) || print(io, ", pad=", PrettyPrinting.tuple_string(l.pad)) - all(==(1), l.stride) || print(io, ", stride=", PrettyPrinting.tuple_string(l.stride)) - all(==(1), l.dilation) || - print(io, ", dilation=", PrettyPrinting.tuple_string(l.dilation)) - (l.groups == 1) || print(io, ", groups=", l.groups) - has_bias(l) || print(io, ", use_bias=false") - return print(io, ")") +function PixelShuffle(r::IntegerType) + return PixelShuffle(WrappedFunction(Base.Fix2(pixel_shuffle, r))) end diff --git a/src/layers/display.jl b/src/layers/display.jl index 5c2efc1769..6f5f52c644 100644 --- a/src/layers/display.jl +++ b/src/layers/display.jl @@ -1,13 +1,12 @@ module PrettyPrinting using Functors: Functors -using LuxCore: LuxCore, AbstractExplicitContainerLayer, AbstractExplicitLayer, display_name + +using LuxCore: LuxCore, AbstractLuxWrapperLayer, AbstractLuxLayer, display_name printable_children(x) = Functors.children(x) -function printable_children(m::AbstractExplicitContainerLayer{layers}) where {layers} +function printable_children(m::AbstractLuxWrapperLayer{field}) where {field} children = Functors.children(m) - length(layers) ≥ 2 && return children - field = first(layers) hasfield(typeof(children), field) || return children nt = getfield(children, field) nt isa NamedTuple || (nt = NamedTuple{(field,)}((nt,))) @@ -15,7 +14,7 @@ function printable_children(m::AbstractExplicitContainerLayer{layers}) where {la end show_leaflike(x) = Functors.isleaf(x) # mostly follow Functors, except for: -show_leaflike(x::AbstractExplicitLayer) = false +show_leaflike(x::AbstractLuxLayer) = false function underscorise(n::Integer) return join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_') @@ -75,7 +74,7 @@ function show_parameters_count(io::IO, layer, indent, str::String) return end -function print_wrapper_model(io::IO, desc::String, model::AbstractExplicitLayer) +function print_wrapper_model(io::IO, desc::String, model::AbstractLuxLayer) if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL print(io, desc, "(\n") big_show(io, model, 4) @@ -96,7 +95,8 @@ tuple_string(pad::Tuple) = all(==(pad[1]), pad) ? 
string(pad[1]) : string(pad) end -function Base.show(io::IO, ::MIME"text/plain", x::AbstractExplicitContainerLayer) +function Base.show(io::IO, ::MIME"text/plain", + x::Union{AbstractLuxContainerLayer, AbstractLuxWrapperLayer}) if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL PrettyPrinting.big_show(io, x) elseif !get(io, :compact, false) # e.g. printed inside a Vector, but not a Matrix @@ -106,7 +106,7 @@ function Base.show(io::IO, ::MIME"text/plain", x::AbstractExplicitContainerLayer end end -function Base.show(io::IO, ::MIME"text/plain", x::AbstractExplicitLayer) +function Base.show(io::IO, ::MIME"text/plain", x::AbstractLuxLayer) !get(io, :compact, false) && return PrettyPrinting.layer_show(io, x) show(io, x) end diff --git a/src/layers/dropout.jl b/src/layers/dropout.jl index e1d1ffa441..a7ef56400f 100644 --- a/src/layers/dropout.jl +++ b/src/layers/dropout.jl @@ -28,7 +28,7 @@ Call [`Lux.testmode`](@ref) to switch to test mode. See also [`Dropout`](@ref), [`VariationalHiddenDropout`](@ref) """ -struct AlphaDropout{T <: Real} <: AbstractExplicitLayer +struct AlphaDropout{T <: Real} <: AbstractLuxLayer p::T alpha::T scale::T @@ -90,7 +90,7 @@ Call [`Lux.testmode`](@ref) to switch to test mode. See also [`AlphaDropout`](@ref), [`VariationalHiddenDropout`](@ref) """ -@concrete struct Dropout{T} <: AbstractExplicitLayer +@concrete struct Dropout{T} <: AbstractLuxLayer p::T q::T dims @@ -154,7 +154,7 @@ Call [`Lux.testmode`](@ref) to switch to test mode. See also [`AlphaDropout`](@ref), [`Dropout`](@ref) """ -@concrete struct VariationalHiddenDropout{T} <: AbstractExplicitLayer +@concrete struct VariationalHiddenDropout{T} <: AbstractLuxLayer p::T q::T dims diff --git a/src/layers/extension.jl b/src/layers/extension.jl index 6997f4538c..8242790a86 100644 --- a/src/layers/extension.jl +++ b/src/layers/extension.jl @@ -1,124 +1,5 @@ # Layers here serve as a compatibility layer between different frameworks. The core # implementation is present in extensions - -## DynamicExpressions.jl -## We could constrain the type of `operator_enum` to be `OperatorEnum` but defining -## custom types in extensions tends to be a PITA -""" - DynamicExpressionsLayer(operator_enum::OperatorEnum, expressions::Node...; - name::NAME_TYPE=nothing, eval_options::EvalOptions=EvalOptions()) - DynamicExpressionsLayer(operator_enum::OperatorEnum, - expressions::AbstractVector{<:Node}; kwargs...) - -Wraps a `DynamicExpressions.jl` `Node` into a Lux layer and allows the constant nodes to -be updated using any of the AD Backends. - -For details about these expressions, refer to the -[`DynamicExpressions.jl` documentation](https://symbolicml.org/DynamicExpressions.jl/dev/types/). - -## Arguments - - - `operator_enum`: `OperatorEnum` from `DynamicExpressions.jl` - - `expressions`: `Node` from `DynamicExpressions.jl` or `AbstractVector{<:Node}` - -## Keyword Arguments - - - `name`: Name of the layer - - `turbo`: Use LoopVectorization.jl for faster evaluation **(Deprecated)** - - `bumper`: Use Bumper.jl for faster evaluation **(Deprecated)** - - `eval_options`: EvalOptions from `DynamicExpressions.jl` - -These options are simply forwarded to `DynamicExpressions.jl`'s `eval_tree_array` -and `eval_grad_tree_array` function. - -!!! danger "Deprecation Notice" - - These options are deprecated and will be removed in v1. Please use the version in - [`Boltz.jl`](https://github.com/LuxDL/Boltz.jl) instead. 
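Since the dropout docstrings above repeatedly point at `Lux.testmode`, here is a minimal sketch (not part of the patch) of switching a `Dropout` layer between training and test behaviour.

```julia
using Lux, Random

rng = Random.default_rng()
drop = Dropout(0.5f0)
ps, st = Lux.setup(rng, drop)

x = ones(Float32, 4, 4)
y_train, _ = drop(x, ps, st)               # stochastic: entries zeroed, rest rescaled by 1/(1 - p)
y_test, _  = drop(x, ps, Lux.testmode(st)) # test mode: the input passes through unchanged
y_test == x                                # true
```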
-""" -struct DynamicExpressionsLayer{OE, E, N, EO} <: AbstractExplicitLayer - operator_enum::OE - expression::E - name::N - eval_options::EO - - function DynamicExpressionsLayer(operator_enum::OE, expression::E, name::N, - eval_options::EO) where {OE, E, N, EO} - Base.depwarn( - "`DynamicExpressionsLayer` is deprecated and will be removed in v1. Please \ - use the corresponding version in `Boltz.jl` instead.", - :DynamicExpressionsLayer) - return new{OE, E, N, EO}(operator_enum, expression, name, eval_options) - end -end - -function Base.show(io::IO, l::DynamicExpressionsLayer) - print(io, - "DynamicExpressionsLayer($(l.operator_enum), $(l.expression); eval_options=$(l.eval_options))") -end - -function initialparameters(::AbstractRNG, layer::DynamicExpressionsLayer) - params = map(Base.Fix2(getproperty, :val), - filter(node -> node.degree == 0 && node.constant, layer.expression)) - return (; params) -end - -function update_de_expression_constants!(expression, ps) - # Don't use `set_constant_refs!` here, since it requires the types to match. In our - # case we just warn the user - params = filter(node -> node.degree == 0 && node.constant, expression) - foreach(enumerate(params)) do (i, node) - (node.val isa typeof(ps[i])) || - @warn lazy"node.val::$(typeof(node.val)) != ps[$i]::$(typeof(ps[i])). Type of node.val takes precedence. Fix the input expression if this is unintended." maxlog=1 - return node.val = ps[i] - end - return -end - -function (de::DynamicExpressionsLayer)(x::AbstractVector, ps, st) - y, stₙ = de(reshape(x, :, 1), ps, st) - return vec(y), stₙ -end - -# NOTE: Unfortunately we can't use `get_device_type` since it causes problems with -# ReverseDiff -function (de::DynamicExpressionsLayer)(x::AbstractMatrix, ps, st) - y = match_eltype(de, ps, st, x) - return ( - apply_dynamic_expression( - de, de.expression, de.operator_enum, y, ps.params, MLDataDevices.get_device(x)), - st) -end - -function apply_dynamic_expression_internal end - -function apply_dynamic_expression( - de::DynamicExpressionsLayer, expr, operator_enum, x, ps, ::CPUDevice) - if !is_extension_loaded(Val(:DynamicExpressions)) - error("`DynamicExpressions.jl` is not loaded. Please load it before using \ - `DynamicExpressionsLayer`.") - end - return apply_dynamic_expression_internal(de, expr, operator_enum, x, ps) -end - -function ∇apply_dynamic_expression end - -function CRC.rrule(::typeof(apply_dynamic_expression), de::DynamicExpressionsLayer, - expr, operator_enum, x, ps, ::CPUDevice) - if !is_extension_loaded(Val(:DynamicExpressions)) - error("`DynamicExpressions.jl` is not loaded. Please load it before using \ - `DynamicExpressionsLayer`.") - end - return ∇apply_dynamic_expression(de, expr, operator_enum, x, ps) -end - -function apply_dynamic_expression(de, expr, operator_enum, x, ps, dev) - throw(ArgumentError("`DynamicExpressions.jl` only supports CPU operations. Current \ - device detected as $(dev). CUDA.jl will be supported after \ - https://github.com/SymbolicML/DynamicExpressions.jl/pull/65 is \ - merged upstream.")) -end - ## Flux.jl """ FluxLayer(layer) @@ -144,7 +25,7 @@ API internally. 
- `p`: Flattened parameters of the `layer` """ -@concrete struct FluxLayer <: AbstractExplicitLayer +@concrete struct FluxLayer <: AbstractLuxLayer layer re <: Optimisers.Restructure init_parameters @@ -167,11 +48,11 @@ Base.show(io::IO, ::MIME"text/plain", l::FluxLayer) = print(io, "FluxLayer($(l.l ## SimpleChains.jl """ - SimpleChainsLayer{ToArray}(layer, lux_layer=nothing) - SimpleChainsLayer(layer, ToArray::Union{Bool, Val}=Val(false)) + SimpleChainsLayer(layer, to_array::Union{Bool, Val}=Val(false)) + SimpleChainsLayer(layer, lux_layer, to_array) Wraps a `SimpleChains` layer into a `Lux` layer. All operations are performed using -`SimpleChains` but the layer satisfies the `AbstractExplicitLayer` interface. +`SimpleChains` but the layer satisfies the `AbstractLuxLayer` interface. `ToArray` is a boolean flag that determines whether the output should be converted to a regular `Array` or not. Default is `false`. @@ -179,50 +60,32 @@ regular `Array` or not. Default is `false`. ## Arguments - `layer`: SimpleChains layer - -!!! note - - If using `Tracker.jl`, the output will always be a regular `Array`. - -!!! danger - - `Tracker.jl` sometimes produces incorrect gradients for `SimpleChains.jl` models. As - such please test your model with `FiniteDiff.jl` or `Zygote.jl` before using - `Tracker.jl` for your model. + - `lux_layer`: Potentially equivalent Lux layer that is used for printing """ -struct SimpleChainsLayer{ToArray, SL, LL <: Union{Nothing, AbstractExplicitLayer}} <: - AbstractExplicitLayer - to_array::ToArray - layer::SL - lux_layer::LL - - function SimpleChainsLayer{ToArray}(layer, lux_layer=nothing) where {ToArray} - to_array = static(ToArray) - return new{typeof(to_array), typeof(layer), typeof(lux_layer)}( - to_array, layer, lux_layer) - end - function SimpleChainsLayer(layer, ToArray::BoolType=False()) - to_array = static(ToArray) - return new{typeof(to_array), typeof(layer), Nothing}(to_array, layer, nothing) - end +@concrete struct SimpleChainsLayer <: AbstractLuxLayer + layer + lux_layer <: Union{Nothing, AbstractLuxLayer} + to_array <: StaticBool +end + +function SimpleChainsLayer(layer, to_array::BoolType=False()) + return SimpleChainsLayer(layer, nothing, static(to_array)) end -function Base.show( - io::IO, ::MIME"text/plain", s::SimpleChainsLayer{ToArray}) where {ToArray} - PrettyPrinting.print_wrapper_model( - io, "SimpleChainsLayer{to_array=$ToArray}", s.lux_layer) +function Base.show(io::IO, ::MIME"text/plain", s::SimpleChainsLayer) + PrettyPrinting.print_wrapper_model(io, "SimpleChainsLayer", s.lux_layer) end function (sc::SimpleChainsLayer)(x, ps, st) y = match_eltype(sc, ps, st, x) return ( - simple_chain_output( - sc, apply_simple_chain(sc.layer, y, ps.params, MLDataDevices.get_device(x))), + to_array(sc.to_array, + apply_simple_chain(sc.layer, y, ps.params, MLDataDevices.get_device(x))), st) end -simple_chain_output(::SimpleChainsLayer{False}, y) = y -simple_chain_output(::SimpleChainsLayer{True}, y) = convert(Array, y) +to_array(::False, y) = y +to_array(::True, y) = convert(Array, y) apply_simple_chain(layer, x, ps, ::CPUDevice) = layer(x, ps) diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl index dc7de1252c..a355b03a1c 100644 --- a/src/layers/normalize.jl +++ b/src/layers/normalize.jl @@ -1,7 +1,6 @@ @doc doc""" BatchNorm(chs::Integer, activation=identity; init_bias=zeros32, init_scale=ones32, - affine=True(), track_stats=True(), epsilon=1f-5, momentum=0.1f0, - allow_fast_activation::Bool=true) + affine=True(), track_stats=True(), epsilon=1f-5, 
momentum=0.1f0) [Batch Normalization](https://arxiv.org/abs/1502.03167) layer. @@ -22,9 +21,6 @@ slice and normalises the input accordingly. - `epsilon`: a value added to the denominator for numerical stability - `momentum`: the value used for the `running_mean` and `running_var` computation - - `allow_fast_activation`: If `true`, then certain activations can be approximated with - a faster version. The new activation function will be given by - `NNlib.fast_act(activation)` - If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias and scale parameters. @@ -35,7 +31,7 @@ slice and normalises the input accordingly. ## Inputs - - `x`: Array where `size(x, N - 1) = chs` and `ndims(x) > 2` + - `x`: Array where `size(x, N - 1) = chs` ## Returns @@ -81,15 +77,15 @@ Chain( !!! warning - Passing a batch size of 1, during training will result in NaNs. + Passing a batch size of 1, during training will result in an error. See also [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`LayerNorm`](@ref), [`WeightNorm`](@ref) """ -@concrete struct BatchNorm{N} <: AbstractExplicitLayer +@concrete struct BatchNorm <: AbstractLuxLayer activation - epsilon::N - momentum::N + epsilon <: Real + momentum <: Real chs <: IntegerType init_bias init_scale @@ -98,9 +94,8 @@ See also [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`LayerNorm`](@ref), end function BatchNorm(chs::IntegerType, activation=identity; init_bias=zeros32, - init_scale=ones32, affine::BoolType=True(), track_stats::BoolType=True(), - epsilon=1.0f-5, momentum=0.1f0, allow_fast_activation::BoolType=True()) - activation = dynamic(allow_fast_activation) ? NNlib.fast_act(activation) : activation + init_scale=ones32, affine::BoolType=True(), + track_stats::BoolType=True(), epsilon=1.0f-5, momentum=0.1f0) return BatchNorm(activation, epsilon, momentum, chs, init_bias, init_scale, static(affine), static(track_stats)) end @@ -129,10 +124,12 @@ function (BN::BatchNorm)(x::AbstractArray, ps, st::NamedTuple) end x′ = match_eltype(BN, ps, st, x) + σ = NNlib.fast_act(BN.activation, x′) y, stats = batchnorm( x′, safe_getproperty(ps, Val(:scale)), safe_getproperty(ps, Val(:bias)), safe_getproperty(st, Val(:running_mean)), safe_getproperty(st, Val(:running_var)), - st.training, BN.activation, BN.momentum, BN.epsilon) + st.training, σ, convert(unwrapped_eltype(x′), BN.momentum), + convert(unwrapped_eltype(x′), BN.epsilon)) return y, update_batchnorm_state(BN, st, stats) end @@ -153,8 +150,7 @@ end """ GroupNorm(chs::Integer, groups::Integer, activation=identity; init_bias=zeros32, - init_scale=ones32, affine=true, epsilon=1f-5, - allow_fast_activation::Bool=true) + init_scale=ones32, affine=true, epsilon=1f-5) [Group Normalization](https://arxiv.org/abs/1803.08494) layer. @@ -171,9 +167,6 @@ end - `epsilon`: a value added to the denominator for numerical stability - - `allow_fast_activation`: If `true`, then certain activations can be approximated with - a faster version. The new activation function will be given by - `NNlib.fast_act(activation)` - If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias and scale parameters. 
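A small usage sketch (not part of the patch) for the `GroupNorm` description above; the channel/group counts are illustrative, and the divisibility requirement `chs % groups == 0` is the one enforced by the constructor.

```julia
using Lux, Random

rng = Random.default_rng()
gn = GroupNorm(16, 4, relu)                 # 16 channels split into 4 groups
ps, st = Lux.setup(rng, gn)

x = randn(rng, Float32, 8, 8, 16, 2)
y, _ = gn(x, ps, st)
size(y)                                     # (8, 8, 16, 2)
size(ps.scale), size(ps.bias)               # ((16,), (16,)) with the default affine=true
```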
@@ -222,7 +215,7 @@ Chain( See also [`GroupNorm`](@ref), [`InstanceNorm`](@ref), [`LayerNorm`](@ref), [`WeightNorm`](@ref) """ -@concrete struct GroupNorm <: AbstractExplicitLayer +@concrete struct GroupNorm <: AbstractLuxLayer activation epsilon chs <: IntegerType @@ -232,11 +225,10 @@ See also [`GroupNorm`](@ref), [`InstanceNorm`](@ref), [`LayerNorm`](@ref), affine <: StaticBool end -function GroupNorm(chs::IntegerType, groups::IntegerType, activation=identity; - init_bias=zeros32, init_scale=ones32, affine::BoolType=True(), - epsilon=1.0f-5, allow_fast_activation::BoolType=True()) +function GroupNorm( + chs::IntegerType, groups::IntegerType, activation=identity; init_bias=zeros32, + init_scale=ones32, affine::BoolType=True(), epsilon=1.0f-5) @argcheck chs % groups==0 "The number of groups ($(groups)) must divide the number of channels ($chs)" - activation = dynamic(allow_fast_activation) ? NNlib.fast_act(activation) : activation return GroupNorm( activation, epsilon, chs, init_bias, init_scale, groups, static(affine)) end @@ -250,8 +242,9 @@ parameterlength(l::GroupNorm) = has_affine(l) ? (l.chs * 2) : 0 function (GN::GroupNorm)(x::AbstractArray, ps, st::NamedTuple) x′ = match_eltype(GN, ps, st, x) - y = groupnorm(x′, safe_getproperty(ps, Val(:scale)), - safe_getproperty(ps, Val(:bias)), GN.groups, GN.activation, GN.epsilon) + σ = NNlib.fast_act(GN.activation, x′) + y = groupnorm(x′, safe_getproperty(ps, Val(:scale)), safe_getproperty(ps, Val(:bias)), + GN.groups, σ, convert(unwrapped_eltype(x′), GN.epsilon)) return y, st end @@ -264,7 +257,7 @@ end @doc doc""" InstanceNorm(chs::Integer, activation=identity; init_bias=zeros32, init_scale=ones32, - affine=true, epsilon=1f-5, allow_fast_activation::Bool=true) + affine=False(), track_stats=False(), epsilon=1f-5, momentum=0.1f0) Instance Normalization. For details see [1]. @@ -281,16 +274,19 @@ accordingly. ## Keyword Arguments + - If `track_stats=true`, accumulates mean and variance statistics in training phase that + will be used to renormalize the input in test phase. + - `epsilon`: a value added to the denominator for numerical stability - - `allow_fast_activation`: If `true`, then certain activations can be approximated with - a faster version. The new activation function will be given by - `NNlib.fast_act(activation)` + - `momentum`: the value used for the `running_mean` and `running_var` computation - If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias and scale parameters. + `init_bias`: Controls how the `bias` is initialized + `init_scale`: Controls how the `scale` is initialized +# Extended Help + ## Inputs - `x`: Array where `size(x, N - 1) = chs` and `ndims(x) > 2` @@ -311,6 +307,15 @@ accordingly. ## States + - Statistics if `track_stats=true` + + + `running_mean`: Running mean of shape `(chs,)` + + `running_var`: Running variance of shape `(chs,)` + + - Statistics if `track_stats=false` + + + `running_mean`: nothing + + `running_var`: nothing - `training`: Used to check if training/inference mode Use `Lux.testmode` during inference. @@ -318,13 +323,13 @@ Use `Lux.testmode` during inference. 
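Before the docstring's own example below, a sketch (not part of the patch) of how the new `InstanceNorm` defaults and `track_stats` option surface in the parameters and states, assuming the `initialstates` definition shown later in this hunk.

```julia
using Lux, Random

rng = Random.default_rng()

plain = InstanceNorm(3, relu)               # defaults are now affine=false, track_stats=false
_, st_plain = Lux.setup(rng, plain)
keys(st_plain)                              # (:training,)

tracked = InstanceNorm(3, relu; affine=true, track_stats=true)
ps, st = Lux.setup(rng, tracked)
keys(st)                                    # (:running_mean, :running_var, :training)

x = randn(rng, Float32, 8, 8, 3, 2)
y, _ = tracked(x, ps, st)                   # size(y) == (8, 8, 3, 2)
```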
## Example ```jldoctest -julia> Chain(Dense(784 => 64), InstanceNorm(64, relu), Dense(64 => 10), - InstanceNorm(10, relu)) +julia> Chain(Dense(784 => 64), InstanceNorm(64, relu; affine=true), Dense(64 => 10), + InstanceNorm(10, relu; affine=true)) Chain( layer_1 = Dense(784 => 64), # 50_240 parameters - layer_2 = InstanceNorm(64, relu, affine=true), # 128 parameters, plus 1 + layer_2 = InstanceNorm(64, relu, affine=true, track_stats=false), # 128 parameters, plus 1 layer_3 = Dense(64 => 10), # 650 parameters - layer_4 = InstanceNorm(10, relu, affine=true), # 20 parameters, plus 1 + layer_4 = InstanceNorm(10, relu, affine=true, track_stats=false), # 20 parameters, plus 1 ) # Total: 51_038 parameters, # plus 2 states. ``` @@ -336,20 +341,22 @@ Chain( See also [`BatchNorm`](@ref), [`GroupNorm`](@ref), [`LayerNorm`](@ref), [`WeightNorm`](@ref) """ -@concrete struct InstanceNorm <: AbstractExplicitLayer +@concrete struct InstanceNorm <: AbstractLuxLayer activation - epsilon + epsilon <: Real + momentum <: Real chs <: IntegerType init_bias init_scale affine <: StaticBool + track_stats <: StaticBool end -function InstanceNorm( - chs::IntegerType, activation=identity; init_bias=zeros32, init_scale=ones32, - affine::BoolType=True(), epsilon=1.0f-5, allow_fast_activation::BoolType=True()) - activation = dynamic(allow_fast_activation) ? NNlib.fast_act(activation) : activation - return InstanceNorm(activation, epsilon, chs, init_bias, init_scale, static(affine)) +function InstanceNorm(chs::IntegerType, activation=identity; init_bias=zeros32, + init_scale=ones32, affine::BoolType=False(), + track_stats::BoolType=False(), epsilon=1.0f-5, momentum=0.1f0) + return InstanceNorm(activation, epsilon, momentum, chs, init_bias, + init_scale, static(affine), static(track_stats)) end function initialparameters(rng::AbstractRNG, l::InstanceNorm) @@ -357,14 +364,25 @@ function initialparameters(rng::AbstractRNG, l::InstanceNorm) return (;) end -initialstates(::AbstractRNG, ::InstanceNorm) = (; training=Val(true)) +function initialstates(rng::AbstractRNG, l::InstanceNorm) + if has_track_stats(l) + return (running_mean=zeros32(rng, l.chs), + running_var=ones32(rng, l.chs), training=Val(true)) + end + return (; training=Val(true)) +end + parameterlength(l::InstanceNorm) = ifelse(has_affine(l), l.chs * 2, 0) +statelength(l::InstanceNorm) = ifelse(has_track_stats(l), l.chs * 2, 0) + 1 function (IN::InstanceNorm)(x::AbstractArray, ps, st::NamedTuple) x′ = match_eltype(IN, ps, st, x) + σ = NNlib.fast_act(IN.activation, x′) y, _ = instancenorm( x′, safe_getproperty(ps, Val(:scale)), safe_getproperty(ps, Val(:bias)), - st.training, IN.activation, IN.epsilon) + safe_getproperty(st, Val(:running_mean)), safe_getproperty(st, Val(:running_var)), + st.training, σ, convert(unwrapped_eltype(x′), IN.momentum), + convert(unwrapped_eltype(x′), IN.epsilon)) return y, st end @@ -372,11 +390,107 @@ function Base.show(io::IO, l::InstanceNorm) print(io, "InstanceNorm($(l.chs)") (l.activation == identity) || print(io, ", $(l.activation)") print(io, ", affine=$(has_affine(l))") + print(io, ", track_stats=$(has_track_stats(l))") + return print(io, ")") +end + +@doc doc""" + LayerNorm(shape::NTuple{N, Int}, activation=identity; epsilon=1f-5, dims=Colon(), + affine=true, init_bias=zeros32, init_scale=ones32) + +Computes mean and standard deviation over the whole input array, and uses these to +normalize the whole array. Optionally applies an elementwise affine transformation +afterwards. 
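A hand-rolled sketch of the computation this paragraph and the formula below describe, taking the description literally (statistics over the whole array) and assuming no affine transform:

```julia
using Statistics

x = randn(Float32, 4, 4, 3, 2)
ϵ = 1.0f-5
μ = mean(x)                       # mean over the whole array
σ² = var(x; corrected=false)      # biased variance over the whole array
y = @. (x - μ) / sqrt(σ² + ϵ)     # γ = 1, β = 0, i.e. affine=false
mean(y), var(y)                   # ≈ (0, 1) after normalization
```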
+ +Given an input array ``x``, this layer computes + +```math +y = \frac{x - \mathbb{E}[x]}{\sqrt{Var[x] + \epsilon}} * \gamma + \beta +``` + +where ``\gamma`` & ``\beta`` are trainable parameters if `affine=true`. + +!!! warning "Inconsistent Defaults till v0.5.0" + + As of v0.5.0, the doc used to say `affine::Bool=false`, but the code actually had + `affine::Bool=true` as the default. Now the doc reflects the code, so please check + whether your assumptions about the default (if made) were invalid. + +## Arguments + + - `shape`: Broadcastable shape of input array excluding the batch dimension. + - `activation`: After normalization, elementwise activation `activation` is applied. + +## Keyword Arguments + + - `epsilon`: a value added to the denominator for numerical stability. + - `dims`: Dimensions to normalize the array over. + - If `affine=true`, it also applies a shift and a rescale to the input through to + learnable per-element bias and scale parameters. + + + `init_bias`: Controls how the `bias` is initialized + + `init_scale`: Controls how the `scale` is initialized + +# Extended Help + +## Inputs + + - `x`: AbstractArray + +## Returns + + - `y`: Normalized Array + - Empty NamedTuple() + +## Parameters + + - `affine=false`: Empty `NamedTuple()` + - `affine=true` + + + `bias`: Bias of shape `(shape..., 1)` + + `scale`: Scale of shape `(shape..., 1)` +""" +@concrete struct LayerNorm <: AbstractLuxLayer + shape + activation + epsilon + init_bias + init_scale + dims + affine <: StaticBool +end + +function LayerNorm(shape, activation=identity; epsilon=1.0f-5, dims=Colon(), + affine::BoolType=True(), init_bias=zeros32, init_scale=ones32) + return LayerNorm( + shape, activation, epsilon, init_bias, init_scale, dims, static(affine)) +end + +function initialparameters(rng::AbstractRNG, ln::LayerNorm) + if has_affine(ln) + dims = (ln.shape..., 1) + return (; bias=ln.init_bias(rng, dims...), scale=ln.init_scale(rng, dims...)) + end + return (;) +end + +function (l::LayerNorm)(x::AbstractArray, ps, st::NamedTuple) + x′ = match_eltype(l, ps, st, x) + σ = NNlib.fast_act(l.activation, x′) + y = layernorm(x′, safe_getproperty(ps, Val(:scale)), safe_getproperty(ps, Val(:bias)), + σ, l.dims, convert(unwrapped_eltype(x′), l.epsilon)) + return y, st +end + +function Base.show(io::IO, l::LayerNorm) + print(io, "LayerNorm($(l.shape)") + (l.activation == identity) || print(io, ", $(l.activation)") + print(io, ", affine=$(has_affine(l)), dims=$(l.dims)") return print(io, ")") end @doc doc""" - WeightNorm(layer::AbstractExplicitLayer, which_params::NTuple{N, Symbol}, + WeightNorm(layer::AbstractLuxLayer, which_params::NTuple{N, Symbol}, dims::Union{Tuple, Nothing}=nothing) Applies [weight normalization](https://arxiv.org/abs/1602.07868) to a parameter in the given @@ -416,13 +530,13 @@ parameters: one specifying the magnitude (e.g. 
`weight_g`) and one specifying th - Same as that of `layer` """ -@concrete struct WeightNorm <: AbstractExplicitLayer - layer <: AbstractExplicitLayer +@concrete struct WeightNorm <: AbstractLuxLayer + layer <: AbstractLuxLayer which_params dims function WeightNorm( - layer::AbstractExplicitLayer, which_params, dims::Union{Tuple, Nothing}=nothing) + layer::AbstractLuxLayer, which_params, dims::Union{Tuple, Nothing}=nothing) which_params = static(which_params) dims = static(dims) return new{typeof(layer), typeof(which_params), typeof(dims)}( @@ -497,102 +611,3 @@ function Base.show(io::IO, ::MIME"text/plain", w::WeightNorm) return print(io, "WeightNorm(", w.layer, ", dims = ", known(w.dims), ", normalized_parameters = ", known(w.which_params), ")") end - -@doc doc""" - LayerNorm(shape::NTuple{N, Int}, activation=identity; epsilon=1f-5, dims=Colon(), - affine=true, init_bias=zeros32, init_scale=ones32) - -Computes mean and standard deviation over the whole input array, and uses these to -normalize the whole array. Optionally applies an elementwise affine transformation -afterwards. - -Given an input array ``x``, this layer computes - -```math -y = \frac{x - \mathbb{E}[x]}{\sqrt{Var[x] + \epsilon}} * \gamma + \beta -``` - -where ``\gamma`` & ``\beta`` are trainable parameters if `affine=true`. - -!!! warning "Inconsistent Defaults till v0.5.0" - - As of v0.5.0, the doc used to say `affine::Bool=false`, but the code actually had - `affine::Bool=true` as the default. Now the doc reflects the code, so please check - whether your assumptions about the default (if made) were invalid. - -## Arguments - - - `shape`: Broadcastable shape of input array excluding the batch dimension. - - `activation`: After normalization, elementwise activation `activation` is applied. - -## Keyword Arguments - - - `allow_fast_activation`: If `true`, then certain activations can be approximated with - a faster version. The new activation function will be given by - `NNlib.fast_act(activation)` - - `epsilon`: a value added to the denominator for numerical stability. - - `dims`: Dimensions to normalize the array over. - - If `affine=true`, it also applies a shift and a rescale to the input through to - learnable per-element bias and scale parameters. - - + `init_bias`: Controls how the `bias` is initialized - + `init_scale`: Controls how the `scale` is initialized - -# Extended Help - -## Inputs - - - `x`: AbstractArray - -## Returns - - - `y`: Normalized Array - - Empty NamedTuple() - -## Parameters - - - `affine=false`: Empty `NamedTuple()` - - `affine=true` - - + `bias`: Bias of shape `(shape..., 1)` - + `scale`: Scale of shape `(shape..., 1)` -""" -@concrete struct LayerNorm <: AbstractExplicitLayer - shape - activation - epsilon - init_bias - init_scale - dims - affine <: StaticBool -end - -function LayerNorm( - shape, activation=identity; epsilon=1.0f-5, dims=Colon(), affine::BoolType=True(), - init_bias=zeros32, init_scale=ones32, allow_fast_activation::BoolType=True()) - activation = dynamic(allow_fast_activation) ? 
NNlib.fast_act(activation) : activation - return LayerNorm( - shape, activation, epsilon, init_bias, init_scale, dims, static(affine)) -end - -function initialparameters(rng::AbstractRNG, ln::LayerNorm) - if has_affine(ln) - dims = (ln.shape..., 1) - return (; bias=ln.init_bias(rng, dims...), scale=ln.init_scale(rng, dims...)) - end - return (;) -end - -function (l::LayerNorm)(x::AbstractArray, ps, st::NamedTuple) - x′ = match_eltype(l, ps, st, x) - y = layernorm(x′, safe_getproperty(ps, Val(:scale)), - safe_getproperty(ps, Val(:bias)), l.activation, l.dims, l.epsilon) - return y, st -end - -function Base.show(io::IO, l::LayerNorm) - print(io, "LayerNorm($(l.shape)") - (l.activation == identity) || print(io, ", $(l.activation)") - print(io, ", affine=$(has_affine(l)), dims=$(l.dims)") - return print(io, ")") -end diff --git a/src/layers/pooling.jl b/src/layers/pooling.jl new file mode 100644 index 0000000000..bc4da7b089 --- /dev/null +++ b/src/layers/pooling.jl @@ -0,0 +1,253 @@ +abstract type AbstractPoolMode end + +CRC.@non_differentiable (::AbstractPoolMode)(::Any...) + +@concrete struct GenericPoolMode <: AbstractPoolMode + kernel_size <: Tuple{Vararg{IntegerType}} + stride <: Tuple{Vararg{IntegerType}} + pad <: Tuple{Vararg{IntegerType}} + dilation <: Tuple{Vararg{IntegerType}} +end + +(m::GenericPoolMode)(x) = PoolDims(x, m.kernel_size; padding=m.pad, m.stride, m.dilation) + +struct GlobalPoolMode <: AbstractPoolMode end + +(::GlobalPoolMode)(x) = PoolDims(x, size(x)[1:(end - 2)]) + +@concrete struct AdaptivePoolMode <: AbstractPoolMode + out_size <: Tuple{Vararg{IntegerType}} +end + +function (m::AdaptivePoolMode)(x) + in_size = size(x)[1:(end - 2)] + stride = in_size .÷ m.out_size + kernel_size = in_size .- (m.out_size .- 1) .* stride + return PoolDims(x, kernel_size; padding=0, stride, dilation=1) +end + +symbol_to_pool_mode(::StaticSymbol{:generic}) = GenericPoolMode +symbol_to_pool_mode(::StaticSymbol{:global}) = GlobalPoolMode +symbol_to_pool_mode(::StaticSymbol{:adaptive}) = AdaptivePoolMode + +abstract type AbstractPoolOp end + +struct MaxPoolOp <: AbstractPoolOp end +(m::MaxPoolOp)(x, pdims) = maxpool(x, pdims) + +struct MeanPoolOp <: AbstractPoolOp end +(m::MeanPoolOp)(x, pdims) = meanpool(x, pdims) + +@concrete struct LpPoolOp <: AbstractPoolOp + p +end +(m::LpPoolOp)(x, pdims) = lpnormpool(x, pdims; m.p) + +symbol_to_pool_op(::StaticSymbol{:max}, _) = MaxPoolOp() +symbol_to_pool_op(::StaticSymbol{:mean}, _) = MeanPoolOp() +symbol_to_pool_op(::StaticSymbol{:lp}, p) = LpPoolOp(p) + +@concrete struct PoolingLayer <: AbstractLuxLayer + mode <: AbstractPoolMode + op <: AbstractPoolOp +end + +function PoolingLayer(mode::SymbolType, op::SymbolType, + arg::Union{Nothing, Tuple{Vararg{IntegerType}}}=nothing; + stride=arg, pad=0, dilation=1, p=2) + return PoolingLayer(symbol_to_pool_mode(static(mode)), + symbol_to_pool_op(static(op), p), arg; stride, pad, dilation) +end + +function PoolingLayer(::Type{GenericPoolMode}, op::AbstractPoolOp, + kernel_size::Tuple{Vararg{IntegerType}}; stride=kernel_size, pad=0, dilation=1) + stride = Utils.expand(Val(length(kernel_size)), stride) + pad = calc_padding(pad, kernel_size, dilation, stride) + dilation = Utils.expand(Val(length(kernel_size)), dilation) + @argcheck allequal(length, (stride, kernel_size, dilation)) + + return PoolingLayer(GenericPoolMode(kernel_size, stride, pad, dilation), op) +end + +function PoolingLayer(::Type{AdaptivePoolMode}, op::AbstractPoolOp, + out_size::Tuple{Vararg{IntegerType}}; kwargs...) 
+ return PoolingLayer(AdaptivePoolMode(out_size), op) +end + +function PoolingLayer(::Type{GlobalPoolMode}, op::AbstractPoolOp, ::Nothing; kwargs...) + return PoolingLayer(GlobalPoolMode(), op) +end + +(m::PoolingLayer)(x, _, st::NamedTuple) = m.op(x, m.mode(x)), st + +for layer_op in (:Max, :Mean, :LP) + op = Symbol(lowercase(string(layer_op))) + + no_gpu_danger = layer_op == :LP ? """ + + !!! danger "GPU Support" + + This layer is currently only supported on CPU. + """ : "" + + layer_name = Symbol(layer_op, :Pool) + extra_kwargs = layer_op == :LP ? ", p=2" : "" + layer_docstring = """ + $(layer_name)(window; stride=window, pad=0, dilation=1$(extra_kwargs)) + + $(layer_op) Pooling layer, which replaces all pixels in a block of size `window` with + the reduction operation: $(op). + + ## Arguments + + - `window`: Tuple of integers specifying the size of the window. Eg, for 2D pooling + `length(window) == 2` + + ## Keyword Arguments + + - `stride`: Should each be either single integer, or a tuple with `N` integers + - `dilation`: Should each be either single integer, or a tuple with `N` integers + + - `pad`: Specifies the number of elements added to the borders of the data array. It can + be + + + a single integer for equal padding all around, + + a tuple of `N` integers, to apply the same padding at begin/end of each spatial + dimension, + + a tuple of `2*N` integers, for asymmetric padding, or + + the singleton `SamePad()`, to calculate padding such that + `size(output,d) == size(x,d) / stride` (possibly rounded) for each spatial + dimension. + + $(no_gpu_danger) + + # Extended Help + + ## Inputs + + - `x`: Data satisfying `ndims(x) == N + 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)` + + ## Returns + + - Output of the pooling `y` of size `(O_N, ..., O_1, C, N)` where + + ```math + O_i = \\left\\lfloor\\frac{I_i + p_i + p_{(i + N) \\% |p|} - d_i \\times (k_i - 1)}{s_i} + 1\\right\\rfloor + ``` + + - Empty `NamedTuple()` + """ + + global_layer_name = Symbol(:Global, layer_name) + extra_kwargs = layer_op == :LP ? "; p=2" : "" + global_pooling_docstring = """ + $(global_layer_name)($(extra_kwargs)) + + Global $(layer_op) Pooling layer. Transforms `(w, h, c, b)`-shaped input into + `(1, 1, c, b)`-shaped output, by performing mean pooling on the complete `(w, h)`-shaped + feature maps. + + $(no_gpu_danger) + + ## Inputs + + - `x`: Data satisfying `ndims(x) > 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)` + + ## Returns + + - Output of the pooling `y` of size `(1, ..., 1, C, N)` + - Empty `NamedTuple()` + """ + + adaptive_layer_name = Symbol(:Adaptive, layer_name) + adaptive_pooling_docstring = """ + $(adaptive_layer_name)(output_size$(extra_kwargs)) + + Adaptive $(layer_op) Pooling layer. Calculates the necessary window size such that + its output has `size(y)[1:N] == output_size`. + + ## Arguments + + - `output_size`: Size of the first `N` dimensions for the output + + $(no_gpu_danger) + + ## Inputs + + - `x`: Expects as input an array with `ndims(x) == N + 2`, i.e. channel and batch + dimensions, after the `N` feature dimensions, where `N = length(output_size)`. 
+ + ## Returns + + - Output of size `(out..., C, N)` + - Empty `NamedTuple()` + """ + + @eval begin + # Generic Pooling Layer + @doc $(layer_docstring) @concrete struct $(layer_name) <: + AbstractLuxWrapperLayer{:layer} + layer <: PoolingLayer + end + + function $(layer_name)( + window::Tuple{Vararg{IntegerType}}; stride=window, pad=0, dilation=1, p=2) + return $(layer_name)(PoolingLayer(static(:generic), static($(Meta.quot(op))), + window; stride, pad, dilation, p)) + end + + function Base.show(io::IO, ::MIME"text/plain", m::$(layer_name)) + kernel_size = m.layer.mode.kernel_size + print(io, string($(Meta.quot(layer_name))), "($(kernel_size)") + pad = m.layer.mode.pad + all(==(0), pad) || print(io, ", pad=", PrettyPrinting.tuple_string(pad)) + stride = m.layer.mode.stride + stride == kernel_size || + print(io, ", stride=", PrettyPrinting.tuple_string(stride)) + dilation = m.layer.mode.dilation + all(==(1), dilation) || + print(io, ", dilation=", PrettyPrinting.tuple_string(dilation)) + if $(Meta.quot(op)) == :lp + m.layer.op.p == 2 || print(io, ", p=", m.layer.op.p) + end + print(io, ")") + end + + # Global Pooling Layer + @doc $(global_pooling_docstring) @concrete struct $(global_layer_name) <: + AbstractLuxWrapperLayer{:layer} + layer <: PoolingLayer + end + + function $(global_layer_name)(; p=2) + return $(global_layer_name)(PoolingLayer(static(:global), $(Meta.quot(op)); p)) + end + + function Base.show(io::IO, ::MIME"text/plain", g::$(global_layer_name)) + print(io, string($(Meta.quot(global_layer_name))), "(") + if $(Meta.quot(op)) == :lp + g.layer.op.p == 2 || print(io, ", p=", g.layer.op.p) + end + print(io, ")") + end + + # Adaptive Pooling Layer + @doc $(adaptive_pooling_docstring) @concrete struct $(adaptive_layer_name) <: + AbstractLuxWrapperLayer{:layer} + layer <: PoolingLayer + end + + function $(adaptive_layer_name)(out_size::Tuple{Vararg{IntegerType}}; p=2) + return $(adaptive_layer_name)(PoolingLayer( + static(:adaptive), $(Meta.quot(op)), out_size; p)) + end + + function Base.show(io::IO, ::MIME"text/plain", a::$(adaptive_layer_name)) + print(io, string($(Meta.quot(adaptive_layer_name))), "(", a.layer.mode.out_size) + if $(Meta.quot(op)) == :lp + a.layer.op.p == 2 || print(io, ", p=", a.layer.op.p) + end + print(io, ")") + end + end +end diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index fd207bcad0..878c713beb 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -1,4 +1,4 @@ -abstract type AbstractRecurrentCell <: AbstractExplicitLayer end +abstract type AbstractRecurrentCell <: AbstractLuxLayer end const AbstractDebugRecurrentCell = Experimental.DebugLayer{ <:Any, <:Any, <:AbstractRecurrentCell} @@ -27,6 +27,20 @@ function LuxOps.eachslice(x::AbstractArray, ::BatchLastIndex) end LuxOps.eachslice(x::AbstractMatrix, ::BatchLastIndex) = LuxOps.eachslice(x, Val(ndims(x))) +function init_rnn_weight(rng::AbstractRNG, init_weight, hidden_dims, dims) + if init_weight === nothing + bound = inv(sqrt(hidden_dims)) + y = randn32(rng, dims...) + @. y = (y - 0.5f0) * 2 * bound + return y + end + return init_weight(rng, dims...) +end + +function init_rnn_bias(rng::AbstractRNG, init_bias, hidden_dims, bias_len) + return init_rnn_weight(rng, init_bias, hidden_dims, (bias_len,)) +end + """ Recurrence(cell; ordering::AbstractTimeSeriesDataBatchOrdering=BatchLastIndex(), @@ -85,7 +99,7 @@ automatically operate over a sequence of inputs. For some discussion on this topic, see https://github.com/LuxDL/Lux.jl/issues/472. 
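A usage sketch (not part of the diff) of `Recurrence` wrapping one of the cells defined below; the input layout for the default `BatchLastIndex` ordering is assumed to be `(in_dims, timesteps, batch_size)`:

```julia
using Lux, Random

rng = Random.default_rng()
rnn = Recurrence(RNNCell(4 => 8))   # apply the cell over every timestep of the input
ps, st = Lux.setup(rng, rnn)

x = randn(rng, Float32, 4, 16, 32)  # (in_dims, timesteps, batch_size)
y, st_new = rnn(x, ps, st)
size(y)                             # (8, 32): hidden state after the final timestep
```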
""" -@concrete struct Recurrence{R <: StaticBool} <: AbstractExplicitContainerLayer{(:cell,)} +@concrete struct Recurrence{R <: StaticBool} <: AbstractLuxWrapperLayer{:cell} cell <: Union{<:AbstractRecurrentCell, <:AbstractDebugRecurrentCell} ordering <: AbstractTimeSeriesDataBatchOrdering return_sequence::R @@ -151,7 +165,7 @@ update the state with `Lux.update_state(st, :carry, nothing)`. + `cell`: Same as `cell`. + `carry`: The carry state of the `cell`. """ -@concrete struct StatefulRecurrentCell <: AbstractExplicitContainerLayer{(:cell,)} +@concrete struct StatefulRecurrentCell <: AbstractLuxWrapperLayer{:cell} cell <: Union{<:AbstractRecurrentCell, <:AbstractDebugRecurrentCell} end @@ -171,11 +185,11 @@ applyrecurrentcell(l::AbstractRecurrentCell, x, ps, st, ::Nothing) = apply(l, x, @doc doc""" RNNCell(in_dims => out_dims, activation=tanh; use_bias=True(), train_state=False(), - init_bias=zeros32, init_weight=glorot_uniform, init_state=ones32) + init_bias=nothing, init_weight=nothing, init_state=zeros32) An Elman RNNCell cell with `activation` (typically set to `tanh` or `relu`). -``h_{new} = activation(weight_{ih} \times x + weight_{hh} \times h_{prev} + bias)`` +``h_{new} = activation(weight_{ih} \times x + bias_{ih} + weight_{hh} \times h_{prev} + bias_{hh})`` ## Arguments @@ -184,8 +198,10 @@ An Elman RNNCell cell with `activation` (typically set to `tanh` or `relu`). - `activation`: Activation function - `use_bias`: Set to false to deactivate bias - `train_state`: Trainable initial hidden state can be activated by setting this to `true` - - `init_bias`: Initializer for bias - - `init_weight`: Initializer for weight + - `init_bias`: Initializer for bias. If `nothing`, then we use uniform distribution with + bounds `-bound` and `bound` where `bound = inv(sqrt(out_dims))`. + - `init_weight`: Initializer for weight. If `nothing`, then we use uniform distribution + with bounds `-bound` and `bound` where `bound = inv(sqrt(out_dims))`. - `init_state`: Initializer for hidden state ## Inputs @@ -199,6 +215,7 @@ An Elman RNNCell cell with `activation` (typically set to `tanh` or `relu`). updated hidden state is returned. ## Returns + - Tuple containing + Output ``h_{new}`` of shape `(out_dims, batch_size)` @@ -210,7 +227,8 @@ An Elman RNNCell cell with `activation` (typically set to `tanh` or `relu`). - `weight_ih`: Maps the input to the hidden state. - `weight_hh`: Maps the hidden state to the hidden state. 
- - `bias`: Bias vector (not present if `use_bias=false`) + - `bias_ih`: Bias vector for the input-hidden connection (not present if `use_bias=false`) + - `bias_hh`: Bias vector for the hidden-hidden connection (not present if `use_bias=false`) - `hidden_state`: Initial hidden state vector (not present if `train_state=false`) ## States @@ -230,15 +248,22 @@ end function RNNCell((in_dims, out_dims)::Pair{<:IntegerType, <:IntegerType}, activation=tanh; use_bias::BoolType=True(), train_state::BoolType=False(), - init_bias=zeros32, init_weight=glorot_uniform, init_state=ones32) + init_bias=nothing, init_weight=nothing, init_state=zeros32) return RNNCell(static(train_state), activation, in_dims, out_dims, init_bias, init_weight, init_state, static(use_bias)) end function initialparameters(rng::AbstractRNG, rnn::RNNCell) - ps = (weight_ih=rnn.init_weight(rng, rnn.out_dims, rnn.in_dims), - weight_hh=rnn.init_weight(rng, rnn.out_dims, rnn.out_dims)) - has_bias(rnn) && (ps = merge(ps, (bias=rnn.init_bias(rng, rnn.out_dims),))) + weight_ih = init_rnn_weight( + rng, rnn.init_weight, rnn.out_dims, (rnn.out_dims, rnn.in_dims)) + weight_hh = init_rnn_weight( + rng, rnn.init_weight, rnn.out_dims, (rnn.out_dims, rnn.out_dims)) + ps = (; weight_ih, weight_hh) + if has_bias(rnn) + bias_ih = init_rnn_bias(rng, rnn.init_bias, rnn.out_dims, rnn.out_dims) + bias_hh = init_rnn_bias(rng, rnn.init_bias, rnn.out_dims, rnn.out_dims) + ps = merge(ps, (; bias_ih, bias_hh)) + end has_train_state(rnn) && (ps = merge(ps, (hidden_state=rnn.init_state(rng, rnn.out_dims),))) return ps @@ -248,12 +273,12 @@ initialstates(rng::AbstractRNG, ::RNNCell) = (rng=Utils.sample_replicate(rng),) function (rnn::RNNCell{False})(x::AbstractMatrix, ps, st::NamedTuple) rng = replicate(st.rng) - hidden_state = Utils.init_hidden_state(rng, rnn, x) + hidden_state = init_rnn_hidden_state(rng, rnn, x) return rnn((x, (hidden_state,)), ps, merge(st, (; rng))) end function (rnn::RNNCell{True})(x::AbstractMatrix, ps, st::NamedTuple) - hidden_state = Utils.init_trainable_hidden_state(ps.hidden_state, x) + hidden_state = init_trainable_rnn_hidden_state(ps.hidden_state, x) return rnn((x, (hidden_state,)), ps, st) end @@ -261,9 +286,15 @@ function (rnn::RNNCell)( (x, (hidden_state,))::Tuple{<:AbstractMatrix, Tuple{<:AbstractMatrix}}, ps, st::NamedTuple) y, hidden_stateₙ = match_eltype(rnn, ps, st, x, hidden_state) - bias = safe_getproperty(ps, Val(:bias)) - z = fused_dense_bias_activation(identity, ps.weight_hh, hidden_stateₙ, bias) - hₙ = fast_activation!!(rnn.activation, LuxLib.Impl.matmul(ps.weight_ih, y) .+ z) + + bias_hh = safe_getproperty(ps, Val(:bias_hh)) + z₁ = fused_dense_bias_activation(identity, ps.weight_hh, hidden_stateₙ, bias_hh) + + bias_ih = safe_getproperty(ps, Val(:bias_ih)) + z₂ = fused_dense_bias_activation(identity, ps.weight_ih, y, bias_ih) + + # TODO: This operation can be fused instead of doing add then activation + hₙ = fast_activation!!(rnn.activation, z₁ .+ z₂) return (hₙ, (hₙ,)), st end @@ -277,10 +308,8 @@ end @doc doc""" LSTMCell(in_dims => out_dims; use_bias::Bool=true, train_state::Bool=false, - train_memory::Bool=false, - init_weight=(glorot_uniform, glorot_uniform, glorot_uniform, glorot_uniform), - init_bias=(zeros32, zeros32, ones32, zeros32), init_state=zeros32, - init_memory=zeros32) + train_memory::Bool=false, init_weight=nothing, init_bias=nothing, + init_state=zeros32, init_memory=zeros32) Long Short-Term (LSTM) Cell @@ -302,8 +331,14 @@ Long Short-Term (LSTM) Cell - `use_bias`: Set to false to deactivate bias 
- `train_state`: Trainable initial hidden state can be activated by setting this to `true` - `train_memory`: Trainable initial memory can be activated by setting this to `true` - - `init_bias`: Initializer for bias. Must be a tuple containing 4 functions - - `init_weight`: Initializer for weight. Must be a tuple containing 4 functions + - `init_bias`: Initializer for bias. Must be a tuple containing 4 functions. If a single + value is passed, it is copied into a 4 element tuple. If `nothing`, then we use + uniform distribution with bounds `-bound` and `bound` where + `bound = inv(sqrt(out_dims))`. + - `init_weight`: Initializer for weight. Must be a tuple containing 4 functions. If a + single value is passed, it is copied into a 4 element tuple. If `nothing`, then we use + uniform distribution with bounds `-bound` and `bound` where + `bound = inv(sqrt(out_dims))`. - `init_state`: Initializer for hidden state - `init_memory`: Initializer for memory @@ -338,11 +373,13 @@ Long Short-Term (LSTM) Cell ## Parameters - - `weight_i`: Concatenated Weights to map from input space - ``\{ W_{ii}, W_{if}, W_{ig}, W_{io} \}``. - - `weight_h`: Concatenated Weights to map from hidden space - ``\{ W_{hi}, W_{hf}, W_{hg}, W_{ho} \}`` - - `bias`: Bias vector (not present if `use_bias=false`) + - `weight_ih`: Concatenated Weights to map from input space + ``\{ W_{ii}, W_{if}, W_{ig}, W_{io} \}``. + - `weight_hh`: Concatenated Weights to map from hidden space + ``\{ W_{hi}, W_{hf}, W_{hg}, W_{ho} \}`` + - `bias_ih`: Bias vector for the input-hidden connection (not present if `use_bias=false`) + - `bias_hh`: Concatenated Bias vector for the hidden-hidden connection (not present if + `use_bias=false`) - `hidden_state`: Initial hidden state vector (not present if `train_state=false`) - `memory`: Initial memory vector (not present if `train_memory=false`) @@ -362,10 +399,10 @@ Long Short-Term (LSTM) Cell use_bias <: StaticBool end -function LSTMCell((in_dims, out_dims)::Pair{<:IntegerType, <:IntegerType}; - use_bias::BoolType=True(), train_state::BoolType=False(), - train_memory::BoolType=False(), init_weight=glorot_uniform, - init_bias=zeros32, init_state=zeros32, init_memory=zeros32) +function LSTMCell( + (in_dims, out_dims)::Pair{<:IntegerType, <:IntegerType}; use_bias::BoolType=True(), + train_state::BoolType=False(), train_memory::BoolType=False(), + init_weight=nothing, init_bias=nothing, init_state=zeros32, init_memory=zeros32) init_weight isa NTuple{4} || (init_weight = ntuple(Returns(init_weight), 4)) init_bias isa NTuple{4} || (init_bias = ntuple(Returns(init_bias), 4)) return LSTMCell(static(train_state), static(train_memory), in_dims, out_dims, @@ -373,15 +410,19 @@ function LSTMCell((in_dims, out_dims)::Pair{<:IntegerType, <:IntegerType}; end function initialparameters(rng::AbstractRNG, lstm::LSTMCell) - weight_i = vcat([init_weight(rng, lstm.out_dims, lstm.in_dims) - for init_weight in lstm.init_weight]...) - weight_h = vcat([init_weight(rng, lstm.out_dims, lstm.out_dims) - for init_weight in lstm.init_weight]...) - ps = (; weight_i, weight_h) + weight_ih = vcat([init_rnn_weight( + rng, init_weight, lstm.out_dims, (lstm.out_dims, lstm.in_dims)) + for init_weight in lstm.init_weight]...) + weight_hh = vcat([init_rnn_weight( + rng, init_weight, lstm.out_dims, (lstm.out_dims, lstm.out_dims)) + for init_weight in lstm.init_weight]...) 
+ ps = (; weight_ih, weight_hh) if has_bias(lstm) - # TODO: in v1 we make this a flat vector - bias = vcat([init_bias(rng, lstm.out_dims, 1) for init_bias in lstm.init_bias]...) - ps = merge(ps, (bias=bias,)) + bias_ih = vcat([init_rnn_bias(rng, init_bias, lstm.out_dims, lstm.out_dims) + for init_bias in lstm.init_bias]...) + bias_hh = vcat([init_rnn_bias(rng, init_bias, lstm.out_dims, lstm.out_dims) + for init_bias in lstm.init_bias]...) + ps = merge(ps, (; bias_ih, bias_hh)) end has_train_state(lstm) && (ps = merge(ps, (hidden_state=lstm.init_state(rng, lstm.out_dims),))) @@ -394,28 +435,28 @@ initialstates(rng::AbstractRNG, ::LSTMCell) = (rng=Utils.sample_replicate(rng),) function (lstm::LSTMCell{False, False})(x::AbstractMatrix, ps, st::NamedTuple) rng = replicate(st.rng) - hidden_state = Utils.init_hidden_state(rng, lstm, x) - memory = Utils.init_hidden_state(rng, lstm, x) + hidden_state = init_rnn_hidden_state(rng, lstm, x) + memory = init_rnn_hidden_state(rng, lstm, x) return lstm((x, (hidden_state, memory)), ps, merge(st, (; rng))) end function (lstm::LSTMCell{True, False})(x::AbstractMatrix, ps, st::NamedTuple) rng = replicate(st.rng) - hidden_state = Utils.init_trainable_hidden_state(ps.hidden_state, x) - memory = Utils.init_hidden_state(rng, lstm, x) + hidden_state = init_trainable_rnn_hidden_state(ps.hidden_state, x) + memory = init_rnn_hidden_state(rng, lstm, x) return lstm((x, (hidden_state, memory)), ps, merge(st, (; rng))) end function (lstm::LSTMCell{False, True})(x::AbstractMatrix, ps, st::NamedTuple) rng = replicate(st.rng) - hidden_state = Utils.init_hidden_state(rng, lstm, x) - memory = Utils.init_trainable_hidden_state(ps.memory, x) + hidden_state = init_rnn_hidden_state(rng, lstm, x) + memory = init_trainable_rnn_hidden_state(ps.memory, x) return lstm((x, (hidden_state, memory)), ps, merge(st, (; rng))) end function (lstm::LSTMCell{True, True})(x::AbstractMatrix, ps, st::NamedTuple) - hidden_state = Utils.init_trainable_hidden_state(ps.hidden_state, x) - memory = Utils.init_trainable_hidden_state(ps.memory, x) + hidden_state = init_trainable_rnn_hidden_state(ps.hidden_state, x) + memory = init_trainable_rnn_hidden_state(ps.memory, x) return lstm((x, (hidden_state, memory)), ps, st) end @@ -425,10 +466,11 @@ const _LSTMCellInputType = Tuple{ function (lstm::LSTMCell)( (x, (hidden_state, memory))::_LSTMCellInputType, ps, st::NamedTuple) y, hidden_stateₙ, memoryₙ = match_eltype(lstm, ps, st, x, hidden_state, memory) - bias = safe_vec(safe_getproperty(ps, Val(:bias))) - z = fused_dense_bias_activation(identity, ps.weight_h, hidden_stateₙ, bias) - g = LuxLib.Impl.matmul(ps.weight_i, y) .+ z - + bias_hh = safe_getproperty(ps, Val(:bias_hh)) + z₁ = fused_dense_bias_activation(identity, ps.weight_hh, hidden_stateₙ, bias_hh) + bias_ih = safe_getproperty(ps, Val(:bias_ih)) + z₂ = fused_dense_bias_activation(identity, ps.weight_ih, y, bias_ih) + g = z₁ .+ z₂ input, forget, cell, output = multigate(g, Val(4)) memory₂ = @. sigmoid_fast(forget) * memoryₙ + sigmoid_fast(input) * tanh_fast(cell) hidden_state₂ = @. 
sigmoid_fast(output) * tanh_fast(memory₂) @@ -445,17 +487,14 @@ end @doc doc""" GRUCell((in_dims, out_dims)::Pair{<:Int,<:Int}; use_bias=true, train_state::Bool=false, - init_weight::Tuple{Function,Function,Function}=(glorot_uniform, glorot_uniform, - glorot_uniform), - init_bias::Tuple{Function,Function,Function}=(zeros32, zeros32, zeros32), - init_state::Function=zeros32) + init_weight=nothing, init_bias=nothing, init_state=zeros32) Gated Recurrent Unit (GRU) Cell ```math \begin{align} - r &= \sigma(W_{ir} \times x + W_{hr} \times h_{prev} + b_{hr})\\ - z &= \sigma(W_{iz} \times x + W_{hz} \times h_{prev} + b_{hz})\\ + r &= \sigma(W_{ir} \times x + b_{ir} + W_{hr} \times h_{prev} + b_{hr})\\ + z &= \sigma(W_{iz} \times x + b_{iz} + W_{hz} \times h_{prev} + b_{hz})\\ n &= \tanh(W_{in} \times x + b_{in} + r \cdot (W_{hn} \times h_{prev} + b_{hn}))\\ h_{new} &= (1 - z) \cdot n + z \cdot h_{prev} \end{align} @@ -467,8 +506,14 @@ Gated Recurrent Unit (GRU) Cell - `out_dims`: Output (Hidden State) Dimension - `use_bias`: Set to false to deactivate bias - `train_state`: Trainable initial hidden state can be activated by setting this to `true` - - `init_bias`: Initializer for bias. Must be a tuple containing 3 functions - - `init_weight`: Initializer for weight. Must be a tuple containing 3 functions + - `init_bias`: Initializer for bias. Must be a tuple containing 3 functions. If a single + value is passed, it is copied into a 3 element tuple. If `nothing`, then we use + uniform distribution with bounds `-bound` and `bound` where + `bound = inv(sqrt(out_dims))`. + - `init_weight`: Initializer for weight. Must be a tuple containing 3 functions. If a + single value is passed, it is copied into a 3 element tuple. If `nothing`, then we use + uniform distribution with bounds `-bound` and `bound` where + `bound = inv(sqrt(out_dims))`. - `init_state`: Initializer for hidden state ## Inputs @@ -492,13 +537,14 @@ Gated Recurrent Unit (GRU) Cell ## Parameters - - `weight_i`: Concatenated Weights to map from input space - ``\{ W_{ir}, W_{iz}, W_{in} \}``. - - `weight_h`: Concatenated Weights to map from hidden space - ``\{ W_{hr}, W_{hz}, W_{hn} \}``. - - `bias_i`: Bias vector (``b_{in}``; not present if `use_bias=false`). - - `bias_h`: Concatenated Bias vector for the hidden space - ``\{ b_{hr}, b_{hz}, b_{hn} \}`` (not present if `use_bias=false`). + - `weight_ih`: Concatenated Weights to map from input space + ``\{ W_{ir}, W_{iz}, W_{in} \}``. + - `weight_hh`: Concatenated Weights to map from hidden space + ``\{ W_{hr}, W_{hz}, W_{hn} \}``. + - `bias_ih`: Concatenated Bias vector for the input space + ``\{ b_{ir}, b_{iz}, b_{in} \}`` (not present if `use_bias=false`). + - `bias_hh`: Concatenated Bias vector for the hidden space + ``\{ b_{hr}, b_{hz}, b_{hn} \}`` (not present if `use_bias=false`). - `hidden_state`: Initial hidden state vector (not present if `train_state=false`) ``\{ b_{hr}, b_{hz}, b_{hn} \}``. @@ -526,16 +572,19 @@ function GRUCell((in_dims, out_dims)::Pair{<:IntegerType, <:IntegerType}; end function initialparameters(rng::AbstractRNG, gru::GRUCell) - weight_i = vcat([init_weight(rng, gru.out_dims, gru.in_dims) - for init_weight in gru.init_weight]...) - weight_h = vcat([init_weight(rng, gru.out_dims, gru.out_dims) - for init_weight in gru.init_weight]...) - ps = (; weight_i, weight_h) + weight_ih = vcat([init_rnn_weight( + rng, init_weight, gru.out_dims, (gru.out_dims, gru.in_dims)) + for init_weight in gru.init_weight]...) 
+ weight_hh = vcat([init_rnn_weight( + rng, init_weight, gru.out_dims, (gru.out_dims, gru.out_dims)) + for init_weight in gru.init_weight]...) + ps = (; weight_ih, weight_hh) if has_bias(gru) - bias_i = gru.init_bias[1](rng, gru.out_dims, 1) - # TODO: in v1 we make this a flat vector - bias_h = vcat([init_bias(rng, gru.out_dims, 1) for init_bias in gru.init_bias]...) - ps = merge(ps, (bias_i=bias_i, bias_h=bias_h)) + bias_ih = vcat([init_rnn_bias(rng, init_bias, gru.out_dims, gru.out_dims) + for init_bias in gru.init_bias]...) + bias_hh = vcat([init_rnn_bias(rng, init_bias, gru.out_dims, gru.out_dims) + for init_bias in gru.init_bias]...) + ps = merge(ps, (; bias_ih, bias_hh)) end has_train_state(gru) && (ps = merge(ps, (hidden_state=gru.init_state(rng, gru.out_dims),))) @@ -545,14 +594,14 @@ end initialstates(rng::AbstractRNG, ::GRUCell) = (rng=Utils.sample_replicate(rng),) function (gru::GRUCell{True})(x::AbstractMatrix, ps, st::NamedTuple) - hidden_state = Utils.init_trainable_hidden_state(ps.hidden_state, x) + hidden_state = init_trainable_rnn_hidden_state(ps.hidden_state, x) return gru((x, (hidden_state,)), ps, st) end function (gru::GRUCell{False})(x::AbstractMatrix, ps, st::NamedTuple) rng = replicate(st.rng) st = merge(st, (; rng)) - hidden_state = Utils.init_hidden_state(rng, gru, x) + hidden_state = init_rnn_hidden_state(rng, gru, x) return gru((x, (hidden_state,)), ps, st) end @@ -560,21 +609,22 @@ const _GRUCellInputType = Tuple{<:AbstractMatrix, Tuple{<:AbstractMatrix}} function (gru::GRUCell)((x, (hidden_state,))::_GRUCellInputType, ps, st::NamedTuple) y, hidden_stateₙ = match_eltype(gru, ps, st, x, hidden_state) - gxs = multigate(ps.weight_i * y, Val(3)) - bias_h = safe_vec(safe_getproperty(ps, Val(:bias_h))) - ghbs = multigate( - fused_dense_bias_activation(identity, ps.weight_h, hidden_stateₙ, bias_h), Val(3)) - r = @. sigmoid_fast(gxs[1] + ghbs[1]) - z = @. sigmoid_fast(gxs[2] + ghbs[2]) - n = gru_cell_compute(gxs[3], r, ghbs[3], safe_getproperty(ps, Val(:bias_i))) - hidden_state₂ = @. (1 - z) * n + z * hidden_stateₙ + z₁ = fused_dense_bias_activation( + identity, ps.weight_ih, y, safe_getproperty(ps, Val(:bias_ih))) + z₂ = fused_dense_bias_activation( + identity, ps.weight_hh, hidden_stateₙ, safe_getproperty(ps, Val(:bias_hh))) - return (hidden_state₂, (hidden_state₂,)), st -end + gxs₁, gxs₂, gxs₃ = multigate(z₁, Val(3)) + ghbs₁, ghbs₂, ghbs₃ = multigate(z₂, Val(3)) + + r = @. sigmoid_fast(gxs₁ + ghbs₁) + z = @. sigmoid_fast(gxs₂ + ghbs₂) + n = @. tanh_fast(gxs₃ + r * ghbs₃) + h′ = @. (1 - z) * n + z * hidden_stateₙ -gru_cell_compute(x, r, y, ::Nothing) = @. tanh_fast(x + r * y) -gru_cell_compute(x, r, y, bias) = @. tanh_fast(x + r * y + bias) + return (h′, (h′,)), st +end function Base.show(io::IO, g::GRUCell) print(io, "GRUCell($(g.in_dims) => $(g.out_dims)") @@ -631,7 +681,7 @@ Bidirectional RNN wrapper. - Same as `cell` and `backward_cell`. """ -@concrete struct BidirectionalRNN <: AbstractExplicitContainerLayer{(:model,)} +@concrete struct BidirectionalRNN <: AbstractLuxWrapperLayer{:model} model <: Parallel end diff --git a/src/preferences.jl b/src/preferences.jl index 920f749ad3..a3eaff5445 100644 --- a/src/preferences.jl +++ b/src/preferences.jl @@ -1,26 +1,10 @@ module LuxPreferences using ArgCheck: @argcheck -using Preferences: load_preference, has_preference, set_preferences! 
+using Preferences: load_preference, has_preference, set_preferences!, @load_preference using ..Lux: Lux -macro deprecate_preference(old_pref, new_pref, default) - msg1 = "Preference `$(old_pref)` is deprecated and will be removed in a future \ - release. Use `$(new_pref)` instead." - msg2 = "Both `$(old_pref)` and `$(new_pref)` preferences are set. Please remove \ - `$(old_pref)`." - return esc(quote - if has_preference($(Lux), $(old_pref)) - Base.depwarn($msg1, $(Meta.quot(Symbol(Lux)))) - has_preference($(Lux), $(new_pref)) && error($msg2) - load_preference($(Lux), $(old_pref), $(default)) - else - load_preference($(Lux), $(new_pref), $(default)) - end - end) -end - macro load_preference_with_choices(pref, default, choices) msg1 = "Invalid value for `$(pref)` preference: " msg2 = ". Valid choices are: $(choices)" @@ -32,14 +16,12 @@ macro load_preference_with_choices(pref, default, choices) end # Nested AD -const AUTOMATIC_NESTED_AD_SWITCHING = @deprecate_preference("DisableAutomaticNestedADSwitching", - "automatic_nested_ad_switching", true) +const AUTOMATIC_NESTED_AD_SWITCHING = @load_preference("automatic_nested_ad_switching", + true) # GPU-Aware MPI -const MPI_CUDA_AWARE = @deprecate_preference("LuxDistributedMPICUDAAware", "cuda_aware_mpi", - false) -const MPI_ROCM_AWARE = @deprecate_preference("LuxDistributedMPIROCMAware", "rocm_aware_mpi", - false) +const MPI_CUDA_AWARE = @load_preference("cuda_aware_mpi", false) +const MPI_ROCM_AWARE = @load_preference("rocm_aware_mpi", false) # Eltype Auto Conversion const ELTYPE_MISMATCH_HANDLING = @load_preference_with_choices("eltype_mismatch_handling", diff --git a/src/transform/simplechains.jl b/src/transform/simplechains.jl index d224fcd64b..18c840d590 100644 --- a/src/transform/simplechains.jl +++ b/src/transform/simplechains.jl @@ -2,7 +2,7 @@ ToSimpleChainsAdaptor(input_dims, convert_to_array::Bool=false) Adaptor for converting a Lux Model to SimpleChains. The returned model is still a Lux model, -and satisfies the `AbstractExplicitLayer` interfacem but all internal calculations are +and satisfies the `AbstractLuxLayer` interfacem but all internal calculations are performed using SimpleChains. !!! warning @@ -59,17 +59,17 @@ struct ToSimpleChainsAdaptor{ID, AT} <: AbstractFromLuxAdaptor end """ - Adapt.adapt(from::ToSimpleChainsAdaptor, L::AbstractExplicitLayer) + Adapt.adapt(from::ToSimpleChainsAdaptor, L::AbstractLuxLayer) Adapt a Simple Chains model to Lux model. See [`ToSimpleChainsAdaptor`](@ref) for more details. 
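A sketch of how the adaptor is used through the `Adapt.adapt` method defined here; the layer choice and the statically sized `input_dims` tuple are assumptions for illustration, not taken from this diff:

```julia
using Adapt, Lux, Random
using SimpleChains: SimpleChains   # the adaptor errors if SimpleChains is not loaded
using Static: static

lux_model = Chain(Dense(2 => 4, relu), Dense(4 => 2))
adaptor = ToSimpleChainsAdaptor((static(2),))   # input_dims of a single sample
sc_model = Adapt.adapt(adaptor, lux_model)      # still a Lux layer; internals run via SimpleChains

ps, st = Lux.setup(Random.default_rng(), sc_model)
y, _ = sc_model(randn(Float32, 2, 8), ps, st)
```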
""" -function Adapt.adapt(to::ToSimpleChainsAdaptor, L::AbstractExplicitLayer) +function Adapt.adapt(to::ToSimpleChainsAdaptor, L::AbstractLuxLayer) if Base.get_extension(@__MODULE__, :LuxSimpleChainsExt) === nothing error("`ToSimpleChainsAdaptor` requires `SimpleChains.jl` to be loaded.") end sc_layer = fix_simplechain_input_dims(make_simplechain_network(L), to.input_dims) - return SimpleChainsLayer{to.convert_to_array}(sc_layer, L) + return SimpleChainsLayer(sc_layer, L, static(to.convert_to_array)) end function make_simplechain_network end diff --git a/src/utils.jl b/src/utils.jl index d47cea6139..13e442d087 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -10,14 +10,15 @@ using Functors: fmapstructure using Random: AbstractRNG using Static: Static, StaticBool, StaticInteger, StaticSymbol -using LuxCore: LuxCore, AbstractExplicitLayer +using LuxCore: LuxCore, AbstractLuxLayer using MLDataDevices: get_device +using NNlib: NNlib const CRC = ChainRulesCore const BoolType = Union{StaticBool, Bool, Val{true}, Val{false}} const IntegerType = Union{Integer, StaticInteger} -const SymbolType = Union{Symbol, StaticSymbol} +const SymbolType = Union{Symbol, StaticSymbol, Val} # Aliased `size` from Base size(x::AbstractArray) = Base.size(x) @@ -126,6 +127,8 @@ eltype(x) = eltype(Base.eltype(x)) eltype(::Type{T}) where {T} = T eltype(::Type{<:Dual{T, V}}) where {T, V} = V +@non_differentiable eltype(::Any) + ofeltype_array(::Type{T}, x::AbstractArray) where {T} = broadcast(T, x) function ofeltype_array(::Type{T}, x::AbstractArray{<:Dual{Tag, V, N}}) where {Tag, T, V, N} return Dual{Tag, T, N}.(x) @@ -157,11 +160,13 @@ end add!!(x::Number, y::Number) = x + y add!!(::Nothing, ::Nothing) = nothing -function init_hidden_state(rng::AbstractRNG, rnn, x::AbstractMatrix) +function init_rnn_hidden_state(rng::AbstractRNG, rnn, x::AbstractMatrix) + # TODO: Once we support moving `rng` to the device, we can directly initialize on the + # device return rnn.init_state(rng, rnn.out_dims, Base.size(x, 2)) |> get_device(x) end -function init_trainable_hidden_state(hidden_state::AbstractVector, x::AbstractMatrix) +function init_trainable_rnn_hidden_state(hidden_state::AbstractVector, x::AbstractMatrix) return repeat(hidden_state, 1, Base.size(x, 2)) end @@ -189,7 +194,7 @@ set_refval!(x, y) = (x[] = y) @non_differentiable set_refval!(::Any...) EnzymeRules.inactive(::typeof(set_refval!), ::Any...) = nothing -function named_tuple_layers(layers::Vararg{AbstractExplicitLayer, N}) where {N} +function named_tuple_layers(layers::Vararg{AbstractLuxLayer, N}) where {N} return NamedTuple{ntuple(i -> Symbol(:layer_, i), N)}(layers) end @@ -201,10 +206,25 @@ matrix_to_array(x::AbstractMatrix, ::AbstractVector) = vec(x) matrix_to_array(x::AbstractMatrix, ::AbstractMatrix) = x matrix_to_array(x::AbstractMatrix, y::AbstractArray) = reshape(x, :, size(y)[2:end]...) 
+# This should probably be in WeightInitializers.jl +calculate_gain(_, __) = 1.0f0 +calculate_gain(::typeof(identity), _) = 1.0f0 +calculate_gain(::typeof(NNlib.sigmoid), _) = 1.0f0 +calculate_gain(::typeof(NNlib.sigmoid_fast), _) = 1.0f0 +calculate_gain(::typeof(NNlib.relu), _) = 2.0f0 +calculate_gain(::typeof(tanh), _) = 5.0f0 / 3.0f0 +calculate_gain(::typeof(NNlib.tanh_fast), _) = 5.0f0 / 3.0f0 +function calculate_gain(::typeof(NNlib.leakyrelu), ::Nothing) + return calculate_gain(NNlib.leakyrelu, 0.1f0) +end +calculate_gain(::typeof(NNlib.leakyrelu), x::Real) = typeof(x)(√(2 / (1 + x^2))) +calculate_gain(::typeof(NNlib.selu), _) = 3.0f0 / 4 + end using .Utils: Utils, BoolType, IntegerType, SymbolType, make_abstract_matrix, - matrix_to_array + matrix_to_array, init_trainable_rnn_hidden_state, init_rnn_hidden_state const safe_reverse = Utils.reverse const safe_vec = Utils.vec +const unwrapped_eltype = Utils.eltype diff --git a/test/Project.toml b/test/Project.toml index b82a1f6640..5b9504c671 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,12 +2,10 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -17,11 +15,11 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" -LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -43,12 +41,10 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" ADTypes = "1.5" Adapt = "4" Aqua = "0.8.4" -Bumper = "0.6, 0.7" ChainRulesCore = "1.24" ComponentArrays = "0.15.16" DispatchDoctor = "0.4.12" Documenter = "1.4" -DynamicExpressions = "0.16, 0.17, 0.18, 0.19" Enzyme = "0.12.26" ExplicitImports = "1.9.0" ForwardDiff = "0.10.36" @@ -57,12 +53,12 @@ Hwloc = "3.2.0" InteractiveUtils = "<0.0.1, 1" LinearAlgebra = "1.10" Logging = "1.10" -LuxCore = "0.1.16" -LuxDeviceUtils = "0.1.26" -LuxLib = "0.3.42" +LuxCore = "1.0" +LuxLib = "1.0" LuxTestUtils = "1.1.4" MLDataDevices = "1.1" MLUtils = "0.4.3" +NNlib = "0.9.21" OneHotArrays = "0.2.5" Optimisers = "0.3.3" Pkg = "1.10" diff --git a/test/autodiff/nested_autodiff_tests.jl b/test/autodiff/nested_autodiff_tests.jl index 1131446430..850c313878 100644 --- a/test/autodiff/nested_autodiff_tests.jl +++ b/test/autodiff/nested_autodiff_tests.jl @@ -29,13 +29,13 @@ function test_nested_ad_input_gradient_jacobian(aType, dev, ongpu, loss_fn, X, m allow_unstable() do test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; - atol=1.0f-3, rtol=1.0f-1, soft_fail=[AutoForwardDiff()], 
+ atol=1.0f-3, rtol=1.0f-1, soft_fail=[AutoFiniteDiff()], skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()]) end end -const Xs = (randn(rng, Float32, 3, 3, 2, 4), randn(rng, Float32, 2, 4), - randn(rng, Float32, 2, 4), randn(rng, Float32, 3, 3, 2, 4)) +const Xs = (randn(rng, Float32, 3, 3, 2, 2), randn(rng, Float32, 2, 2), + randn(rng, Float32, 2, 2), randn(rng, Float32, 3, 3, 2, 2)) const models = ( Chain(Conv((3, 3), 2 => 4, gelu; pad=SamePad()), BatchNorm(4), @@ -50,25 +50,25 @@ const models = ( # smodel | ForwardDiff.jacobian function loss_function1(model, x, ps, st) smodel = StatefulLuxLayer{true}(model, ps, st) - return sum(abs2, ForwardDiff.jacobian(smodel, x)) + return sum(abs2, ForwardDiff.jacobian(smodel, x) .* 0.01f0) end # smodel | Zygote.jacobian function loss_function2(model, x, ps, st) smodel = StatefulLuxLayer{true}(model, ps, st) - return sum(abs2, only(Zygote.jacobian(smodel, x))) + return sum(abs2, only(Zygote.jacobian(smodel, x)) .* 0.01f0) end # sum(abs2) ∘ smodel | ForwardDiff.gradient function loss_function3(model, x, ps, st) smodel = StatefulLuxLayer{true}(model, ps, st) - return sum(abs2, ForwardDiff.gradient(Base.Fix1(sum, abs2) ∘ smodel, x)) + return sum(abs2, ForwardDiff.gradient(Base.Fix1(sum, abs2) ∘ smodel, x) .* 0.01f0) end # sum(abs2) ∘ smodel | Zygote.gradient function loss_function4(model, x, ps, st) smodel = StatefulLuxLayer{true}(model, ps, st) - return sum(abs2, only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ smodel, x))) + return sum(abs2, only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ smodel, x)) .* 0.01f0) end const ALL_TEST_CONFIGS = Iterators.product( @@ -154,13 +154,13 @@ function test_nested_ad_parameter_gradient_jacobian(aType, dev, ongpu, loss_fn, allow_unstable() do test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; - atol=1.0f-3, rtol=1.0f-1, soft_fail=[AutoForwardDiff()], + atol=1.0f-3, rtol=1.0f-1, soft_fail=[AutoFiniteDiff()], skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()]) end end -const Xs = (randn(rng, Float32, 3, 3, 2, 4), randn(rng, Float32, 2, 4), - randn(rng, Float32, 2, 4), randn(rng, Float32, 3, 3, 2, 4)) +const Xs = (randn(rng, Float32, 3, 3, 2, 2), randn(rng, Float32, 2, 2), + randn(rng, Float32, 2, 2), randn(rng, Float32, 3, 3, 2, 2)) const models = ( Chain(Conv((3, 3), 2 => 4, gelu; pad=SamePad()), BatchNorm(4), @@ -175,25 +175,27 @@ const models = ( # smodel | ForwardDiff.jacobian function loss_function1(model, x, ps, st) smodel = StatefulLuxLayer{true}(model, ps, st) - return sum(abs2, ForwardDiff.jacobian(Base.Fix1(smodel, x), ps)) + return sum(abs2, ForwardDiff.jacobian(Base.Fix1(smodel, x), ps) .* 0.01f0) end # smodel | Zygote.jacobian function loss_function2(model, x, ps, st) smodel = StatefulLuxLayer{true}(model, ps, st) - return sum(abs2, only(Zygote.jacobian(Base.Fix1(smodel, x), ps))) + return sum(abs2, only(Zygote.jacobian(Base.Fix1(smodel, x), ps)) .* 0.01f0) end # sum(abs2) ∘ smodel | ForwardDiff.gradient function loss_function3(model, x, ps, st) smodel = StatefulLuxLayer{true}(model, ps, st) - return sum(abs2, ForwardDiff.gradient(Base.Fix1(sum, abs2) ∘ Base.Fix1(smodel, x), ps)) + return sum(abs2, + ForwardDiff.gradient(Base.Fix1(sum, abs2) ∘ Base.Fix1(smodel, x), ps) .* 0.01f0) end # sum(abs2) ∘ smodel | Zygote.gradient function loss_function4(model, x, ps, st) smodel = StatefulLuxLayer{true}(model, ps, st) - return sum(abs2, only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ Base.Fix1(smodel, x), ps))) + return sum(abs2, + only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ Base.Fix1(smodel, 
x), ps)) .* 0.01f0) end const ALL_TEST_CONFIGS = Iterators.product( diff --git a/test/contrib/debug_tests.jl b/test/contrib/debug_tests.jl index 469f7787a6..2aff0dd4e2 100644 --- a/test/contrib/debug_tests.jl +++ b/test/contrib/debug_tests.jl @@ -4,8 +4,8 @@ rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES - model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 3), Dense(1 => 1)), - BatchNorm(1); disable_optimizations=true) + model = Chain( + Dense(1 => 16, relu), Chain(Dense(16 => 3), Dense(1 => 1)), BatchNorm(1)) ps, st = Lux.setup(rng, model) |> dev x = randn(rng, Float32, 1, 5) |> aType @@ -29,8 +29,8 @@ catch end - model_fixed = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), - BatchNorm(1); disable_optimizations=true) + model_fixed = Chain( + Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), BatchNorm(1)) ps, st = Lux.setup(rng, model_fixed) |> dev @@ -61,8 +61,8 @@ end end @testset "$mode: NaN Debugging" for (mode, aType, dev, ongpu) in MODES - model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), - BatchNorm(1); disable_optimizations=true) + model = Chain( + Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), BatchNorm(1)) x = randn(rng, Float32, 1, 5) |> aType ps, st = Lux.setup(rng, model) |> dev @@ -85,8 +85,8 @@ end model_debug4 = Lux.Experimental.@debug_mode model nan_check=:none @test any(isnan, first(model_debug4(x, ps, st)) |> Array) - model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), offending_layer), - BatchNorm(1); disable_optimizations=true) + model = Chain( + Dense(1 => 16, relu), Chain(Dense(16 => 1), offending_layer), BatchNorm(1)) ps, st = Lux.setup(rng, model) |> dev diff --git a/test/contrib/freeze_tests.jl b/test/contrib/freeze_tests.jl index ad9a2416a5..22fef5bbb1 100644 --- a/test/contrib/freeze_tests.jl +++ b/test/contrib/freeze_tests.jl @@ -1,5 +1,4 @@ @testitem "All Parameter Freezing" setup=[SharedTestSetup] tags=[:contrib] begin - Enzyme.API.runtimeActivity!(true) rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES @@ -7,8 +6,6 @@ d = Dense(5 => 5) psd, std = Lux.setup(rng, d) .|> dev - @test_deprecated Lux.freeze(d, psd, std, nothing) - fd, ps, st = Lux.Experimental.freeze(d, psd, std, nothing) @test length(keys(ps)) == 0 @test length(keys(st)) == 2 @@ -68,7 +65,6 @@ end @testitem "Partial Freezing" setup=[SharedTestSetup] tags=[:contrib] begin using Lux.Experimental: FrozenLayer - Enzyme.API.runtimeActivity!(true) rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES diff --git a/test/contrib/map_tests.jl b/test/contrib/map_tests.jl index cd3cdfb323..8badcf358c 100644 --- a/test/contrib/map_tests.jl +++ b/test/contrib/map_tests.jl @@ -1,26 +1,23 @@ @testitem "Layer Map" setup=[SharedTestSetup] tags=[:contrib] begin - using Setfield + using Setfield, Functors - function zero_dense_params_1(l, ps, st, name) - if l isa Dense && occursin("model.layers.chain", name) - @set! ps.weight = zero.(ps.weight) - @set! ps.bias = zero.(ps.bias) - end - return l, ps, st + function occurs_in(kp::KeyPath, x::KeyPath) + length(kp) ≤ length(x) && return all(==(x[i], kp[i]) for i in 1:length(kp)) + return false end - function zero_dense_params_2(l, ps, st, name) - if l isa Dense && occursin("c.layers.chain", name) - @set! ps.weight = zero.(ps.weight) - @set! ps.bias = zero.(ps.bias) + function zero_dense_params_1(l, ps, st, name) + if l isa Dense && occurs_in(KeyPath(:chain), name) + @set! ps.weight = zero(ps.weight) + @set! 
ps.bias = zero(ps.bias) end return l, ps, st end - function zero_dense_params_3(l, ps, st, name) + function zero_dense_params_2(l, ps, st, name) if l isa Dense - @set! ps.weight = zero.(ps.weight) - @set! ps.bias = zero.(ps.bias) + @set! ps.weight = zero(ps.weight) + @set! ps.bias = zero(ps.bias) end return l, ps, st end @@ -31,7 +28,7 @@ dense_3=Dense(5 => 1)) rng = StableRNG(12345) - ps, st = Lux.setup(rng, c) .|> dev + ps, st = Lux.setup(rng, c) |> dev c_, ps_, st_ = Lux.Experimental.layer_map(zero_dense_params_1, c, ps, st) @@ -40,32 +37,39 @@ @test all(iszero, ps_.chain.dense_2.weight) @test all(iszero, ps_.chain.dense_2.bias) @test !all(iszero, ps_.dense_3.weight) - @test all(iszero, ps_.dense_3.bias) - - c_, ps_, st_ = Lux.Experimental.@layer_map zero_dense_params_2 c ps st - - @test all(iszero, ps_.chain.dense_1.weight) - @test all(iszero, ps_.chain.dense_1.bias) - @test all(iszero, ps_.chain.dense_2.weight) - @test all(iszero, ps_.chain.dense_2.bias) - @test !all(iszero, ps_.dense_3.weight) - @test all(iszero, ps_.dense_3.bias) + @test !all(iszero, ps_.dense_3.bias) # Custom Layers -- See https://github.com/LuxDL/Lux.jl/issues/187 - struct SimpleCustom{L1, L2} <: Lux.AbstractExplicitContainerLayer{(:dense, :conv)} + struct SimpleCustom{L1, L2} <: Lux.AbstractLuxContainerLayer{(:dense, :conv)} dense::L1 conv::L2 end l = SimpleCustom(Dense(3 => 2), Conv((3,), 3 => 2)) - ps, st = Lux.setup(rng, l) .|> dev + ps, st = Lux.setup(rng, l) |> dev + + l_, ps_, st_ = Lux.Experimental.layer_map(zero_dense_params_2, l, ps, st) + + @test all(iszero, ps_.dense.weight) + @test all(iszero, ps_.dense.bias) + @test !all(iszero, ps_.conv.weight) + @test !all(iszero, ps_.conv.bias) + + # Custom Wrapper + struct SimpleWrapper{L} <: Lux.AbstractLuxWrapperLayer{:model} + model::L + end + + l = SimpleWrapper(SimpleCustom(Dense(3 => 2), Conv((3,), 3 => 2))) + + ps, st = Lux.setup(rng, l) |> dev - l_, ps_, st_ = Lux.Experimental.@layer_map zero_dense_params_3 l ps st + l_, ps_, st_ = Lux.Experimental.layer_map(zero_dense_params_2, l, ps, st) @test all(iszero, ps_.dense.weight) @test all(iszero, ps_.dense.bias) @test !all(iszero, ps_.conv.weight) - @test all(iszero, ps_.conv.bias) + @test !all(iszero, ps_.conv.bias) end end diff --git a/test/contrib/share_parameters_tests.jl b/test/contrib/share_parameters_tests.jl index f39f6bca79..874ddd2ffc 100644 --- a/test/contrib/share_parameters_tests.jl +++ b/test/contrib/share_parameters_tests.jl @@ -9,8 +9,6 @@ sharing = (("d2.l2", "d1"), ("d3", "d2.l1")) - @test_deprecated Lux.share_parameters(ps, sharing) - ps_1 = Lux.Experimental.share_parameters(ps, sharing) @test ps_1.d2.l2.weight == ps_1.d1.weight @@ -18,10 +16,8 @@ @test ps_1.d3.weight == ps_1.d2.l1.weight @test ps_1.d3.bias == ps_1.d2.l1.bias - ps_new_1 = (; weight=randn(rng, Float32, 4, 2), bias=randn(rng, Float32, 4, 1)) |> - dev - ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2, 1)) |> - dev + ps_new_1 = (; weight=randn(rng, Float32, 4, 2), bias=randn(rng, Float32, 4)) |> dev + ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2)) |> dev ps_2 = Lux.Experimental.share_parameters(ps, sharing, (ps_new_1, ps_new_2)) @@ -31,7 +27,7 @@ @test ps_2.d3.bias == ps_new_2.bias == ps_2.d2.l1.bias # Mix in ComponentArray - ps_new_ca_1 = ComponentArray(ps_new_1 |> LuxCPUDevice()) |> dev + ps_new_ca_1 = ComponentArray(ps_new_1 |> CPUDevice()) |> dev ps_3 = Lux.Experimental.share_parameters(ps, sharing, (ps_new_ca_1, ps_new_2)) @@ -48,15 +44,13 @@ ps, sharing, 
(ps_new_1,)) # Parameter Structure Mismatch - ps_new_1 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 4, 1)) |> - dev - ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2, 1)) |> - dev + ps_new_1 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 4)) |> dev + ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2)) |> dev @test_throws ArgumentError Lux.Experimental.share_parameters( ps, sharing, (ps_new_1, ps_new_2)) - ps_new_ca_1 = ComponentArray(ps_new_1 |> LuxCPUDevice()) |> dev + ps_new_ca_1 = ComponentArray(ps_new_1 |> CPUDevice()) |> dev @test_throws ArgumentError Lux.Experimental.share_parameters( ps, sharing, (ps_new_ca_1, ps_new_2)) diff --git a/test/distributed/common_distributedtest.jl b/test/distributed/common_distributedtest.jl index 4c64927d09..231078b6b6 100644 --- a/test/distributed/common_distributedtest.jl +++ b/test/distributed/common_distributedtest.jl @@ -9,8 +9,8 @@ if input_args[1] == "amdgpu" end const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend -const dev = input_args[1] == "cpu" ? LuxCPUDevice() : - (input_args[1] == "cuda" ? LuxCUDADevice() : LuxAMDGPUDevice()) +const dev = input_args[1] == "cpu" ? CPUDevice() : + (input_args[1] == "cuda" ? CUDADevice() : AMDGPUDevice()) const aType = input_args[1] == "cpu" ? Array : (input_args[1] == "cuda" ? CuArray : ROCArray) diff --git a/test/distributed/data_distributedtest.jl b/test/distributed/data_distributedtest.jl index d2eb08de78..c2f78adf57 100644 --- a/test/distributed/data_distributedtest.jl +++ b/test/distributed/data_distributedtest.jl @@ -14,8 +14,8 @@ if input_args[1] == "amdgpu" end const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend -const dev = input_args[1] == "cpu" ? LuxCPUDevice() : - (input_args[1] == "cuda" ? LuxCUDADevice() : LuxAMDGPUDevice()) +const dev = input_args[1] == "cpu" ? CPUDevice() : + (input_args[1] == "cuda" ? CUDADevice() : AMDGPUDevice()) rng = Xoshiro(1234) diff --git a/test/distributed/optimizer_distributedtest.jl b/test/distributed/optimizer_distributedtest.jl index 122761f194..6a3992a43a 100644 --- a/test/distributed/optimizer_distributedtest.jl +++ b/test/distributed/optimizer_distributedtest.jl @@ -9,8 +9,8 @@ if input_args[1] == "amdgpu" end const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend -const dev = input_args[1] == "cpu" ? LuxCPUDevice() : - (input_args[1] == "cuda" ? LuxCUDADevice() : LuxAMDGPUDevice()) +const dev = input_args[1] == "cpu" ? CPUDevice() : + (input_args[1] == "cuda" ? CUDADevice() : AMDGPUDevice()) DistributedUtils.initialize(backend_type) backend = DistributedUtils.get_distributed_backend(backend_type) diff --git a/test/distributed/synchronize_distributedtest.jl b/test/distributed/synchronize_distributedtest.jl index 403cab1d53..388755881e 100644 --- a/test/distributed/synchronize_distributedtest.jl +++ b/test/distributed/synchronize_distributedtest.jl @@ -9,8 +9,8 @@ if input_args[1] == "amdgpu" end const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend -const dev = input_args[1] == "cpu" ? LuxCPUDevice() : - (input_args[1] == "cuda" ? LuxCUDADevice() : LuxAMDGPUDevice()) +const dev = input_args[1] == "cpu" ? CPUDevice() : + (input_args[1] == "cuda" ? CUDADevice() : AMDGPUDevice()) function __get_array_based_on_rank(backend, dims; root) DistributedUtils.local_rank(backend) == root && return ones(dims...) 
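# --- Editor's note: illustrative sketch, not part of the diff. ---
# The four distributed test scripts above all repeat the same device-selection ternary;
# the hunks only rename the device types (LuxCPUDevice / LuxCUDADevice / LuxAMDGPUDevice
# become CPUDevice / CUDADevice / AMDGPUDevice). A minimal standalone version of that
# pattern could look as follows; the helper name `select_device` is hypothetical, and the
# import assumes the device types come from MLDataDevices (the test scripts use them
# unqualified, so adjust the import to match the actual test setup).
using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice

function select_device(backend::AbstractString)
    backend == "cpu" && return CPUDevice()
    backend == "cuda" && return CUDADevice()
    backend == "amdgpu" && return AMDGPUDevice()
    throw(ArgumentError("unknown backend: $backend"))
end

# Usage mirroring the scripts above (invoked e.g. as `julia script.jl cuda nccl`):
# const dev = select_device(ARGS[1])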
diff --git a/test/enzyme_tests.jl b/test/enzyme_tests.jl index b3e5c7b7b8..5d7ac76cf0 100644 --- a/test/enzyme_tests.jl +++ b/test/enzyme_tests.jl @@ -2,7 +2,7 @@ # able to remove this, but this file is still helpful to catch errors in a localized way. @testsetup module EnzymeTestSetup using LuxTestUtils, Enzyme, Zygote, Test -using Lux +using Lux, NNlib using LuxTestUtils: check_approx generic_loss_function(model, x, ps, st) = sum(abs2, first(model(x, ps, st))) diff --git a/test/helpers/compact_tests.jl b/test/helpers/compact_tests.jl index 48c28fffb5..1844e937b2 100644 --- a/test/helpers/compact_tests.jl +++ b/test/helpers/compact_tests.jl @@ -104,14 +104,14 @@ ps, st = Lux.setup(rng, model) |> dev @test size(ps.w1.weight) == (128, 1) - @test size(ps.w1.bias) == (128, 1) + @test size(ps.w1.bias) == (128,) @test length(ps.w2) == nlayers for i in 1:nlayers @test size(ps.w2[i].weight) == (128, 128) - @test size(ps.w2[i].bias) == (128, 1) + @test size(ps.w2[i].bias) == (128,) end @test size(ps.w3.weight) == (1, 128) - @test size(ps.w3.bias) == (1, 1) + @test size(ps.w3.bias) == (1,) x = randn(n_in, 32) |> aType diff --git a/test/helpers/size_propagator_test.jl b/test/helpers/size_propagator_test.jl index 9dde070ced..7825cb75d3 100644 --- a/test/helpers/size_propagator_test.jl +++ b/test/helpers/size_propagator_test.jl @@ -5,10 +5,10 @@ lenet = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)), Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(3), Dense(256 => 120, relu), Dense(120 => 84, relu), Dense(84 => 10)) - ps, st = Lux.setup(rng, lenet) - @test Lux.compute_output_size(lenet, (28, 28, 1, 3), ps, st) == (10,) - @test Lux.compute_output_size(lenet, (28, 28, 1, 12), ps, st) == (10,) + for x in (randn(rng, Float32, 28, 28, 1, 3), randn(rng, Float32, 28, 28, 1, 12)) + @test Lux.outputsize(lenet, x, rng) == (10,) + end end @testset "Chain with BatchNorm" begin @@ -17,35 +17,24 @@ MaxPool((2, 2)), FlattenLayer(3), Dense(256 => 120, relu), BatchNorm(120, relu), Dense(120 => 84, relu), Dropout(0.5f0), BatchNorm(84, relu), Dense(84 => 10), BatchNorm(10, relu)) - ps, st = Lux.setup(rng, lenet) - @test Lux.compute_output_size(lenet, (28, 28, 1, 3), ps, st) == (10,) - @test Lux.compute_output_size(lenet, (28, 28, 1, 12), ps, st) == (10,) + for x in (randn(rng, Float32, 28, 28, 1, 3), randn(rng, Float32, 28, 28, 1, 12)) + @test Lux.outputsize(lenet, x, rng) == (10,) + end end - @testset "Normalization Layers" begin - layer = BatchNorm(3, relu) - ps, st = Lux.setup(rng, layer) - - @test Lux.compute_output_size(layer, (4, 4, 3, 2), ps, st) == (4, 4, 3) - @test Lux.compute_output_size(layer, (3, 3), ps, st) == (3,) - - layer = GroupNorm(6, 3, relu) - ps, st = Lux.setup(rng, layer) - - @test Lux.compute_output_size(layer, (4, 4, 6, 2), ps, st) == (4, 4, 6) - @test Lux.compute_output_size(layer, (6, 3), ps, st) == (6,) - - layer = InstanceNorm(3, relu) - ps, st = Lux.setup(rng, layer) - - @test Lux.compute_output_size(layer, (4, 4, 3, 2), ps, st) == (4, 4, 3) - @test Lux.compute_output_size(layer, (4, 3, 2), ps, st) == (4, 3) - - layer = LayerNorm((2, 1, 3), relu) - ps, st = Lux.setup(rng, layer) - - @test Lux.compute_output_size(layer, (2, 4, 3, 2), ps, st) == (2, 4, 3) - @test Lux.compute_output_size(layer, (2, 1, 3, 3), ps, st) == (2, 1, 3) + norm_layer = [ + (BatchNorm(3, relu), [randn(rng, Float32, 4, 4, 3, 2), randn(rng, Float32, 3, 3)]), + (GroupNorm(6, 3, relu), + [randn(rng, Float32, 4, 4, 6, 2), randn(rng, Float32, 6, 3)]), + (InstanceNorm(3, relu), + [randn(rng, Float32, 4, 4, 3, 2), 
randn(rng, Float32, 4, 3, 2)]), + (LayerNorm((2, 1, 3), relu), + [randn(rng, Float32, 2, 4, 3, 2), randn(rng, Float32, 2, 1, 3, 3)])] + + @testset "Normalization: $(nameof(typeof(layer)))" for (layer, xs) in norm_layer + for x in xs + @test Lux.outputsize(layer, x, rng) == size(x)[1:(end - 1)] + end end end diff --git a/test/helpers/size_propagator_tests.jl b/test/helpers/size_propagator_tests.jl new file mode 100644 index 0000000000..7ce8e2f572 --- /dev/null +++ b/test/helpers/size_propagator_tests.jl @@ -0,0 +1,40 @@ +@testitem "Size Propagator" setup=[SharedTestSetup] tags=[:helpers] begin + rng = StableRNG(12345) + + @testset "Simple Chain (LeNet)" begin + lenet = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)), + Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(), + Dense(256 => 120, relu), Dense(120 => 84, relu), Dense(84 => 10)) + + for x in (randn(rng, Float32, 28, 28, 1, 3), randn(rng, Float32, 28, 28, 1, 12)) + @test Lux.outputsize(lenet, x, rng) == (10,) + end + end + + @testset "Chain with BatchNorm" begin + lenet = Chain(Conv((5, 5), 1 => 6, relu), BatchNorm(6, relu), MaxPool((2, 2)), + Conv((5, 5), 6 => 16, relu), BatchNorm(16, relu), + MaxPool((2, 2)), FlattenLayer(), Dense(256 => 120, relu), + BatchNorm(120, relu), Dense(120 => 84, relu), Dropout(0.5f0), + BatchNorm(84, relu), Dense(84 => 10), BatchNorm(10, relu)) + + for x in (randn(rng, Float32, 28, 28, 1, 3), randn(rng, Float32, 28, 28, 1, 12)) + @test Lux.outputsize(lenet, x, rng) == (10,) + end + end + + norm_layer = [ + (BatchNorm(3, relu), [randn(rng, Float32, 4, 4, 3, 2), randn(rng, Float32, 3, 3)]), + (GroupNorm(6, 3, relu), + [randn(rng, Float32, 4, 4, 6, 2), randn(rng, Float32, 6, 3)]), + (InstanceNorm(3, relu), + [randn(rng, Float32, 4, 4, 3, 2), randn(rng, Float32, 4, 3, 2)]), + (LayerNorm((2, 1, 3), relu), + [randn(rng, Float32, 2, 4, 3, 2), randn(rng, Float32, 2, 1, 3, 3)])] + + @testset "Normalization: $(nameof(typeof(layer)))" for (layer, xs) in norm_layer + for x in xs + @test Lux.outputsize(layer, x, rng) == size(x)[1:(end - 1)] + end + end +end diff --git a/test/helpers/stateful_tests.jl b/test/helpers/stateful_tests.jl index ba3c24691b..cc2b4e4afb 100644 --- a/test/helpers/stateful_tests.jl +++ b/test/helpers/stateful_tests.jl @@ -3,7 +3,7 @@ rng = StableRNG(12345) - struct NotFixedStateModel <: Lux.AbstractExplicitLayer end + struct NotFixedStateModel <: Lux.AbstractLuxLayer end (m::NotFixedStateModel)(x, ps, st) = (x, (; s=1)) @@ -12,8 +12,6 @@ @test st isa NamedTuple{()} - @test_deprecated StatefulLuxLayer(model, ps, st) - smodel = StatefulLuxLayer{false}(model, ps, st) display(smodel) @test smodel(1) isa Any diff --git a/test/helpers/training_tests.jl b/test/helpers/training_tests.jl index 7982ed6df3..67897b17fc 100644 --- a/test/helpers/training_tests.jl +++ b/test/helpers/training_tests.jl @@ -6,24 +6,21 @@ @testset "$mode" for (mode, aType, dev, ongpu) in MODES model = Dense(3, 2) opt = Adam(0.01f0) + ps, st = Lux.setup(Lux.replicate(rng), model) |> dev - tstate = Lux.Experimental.TrainState(Lux.replicate(rng), model, opt) + tstate = Training.TrainState(model, ps, st, opt) x = randn(Lux.replicate(rng), Float32, (3, 1)) - - ps, st = Lux.setup(Lux.replicate(rng), model) opt_st = Optimisers.setup(opt, tstate.parameters) @test check_approx(tstate.model, model) - @test check_approx(tstate.parameters, ps) - @test check_approx(tstate.states, st) @test check_approx(tstate.optimizer_state, opt_st) @test tstate.step == 0 end end @testitem "AbstractADTypes" setup=[SharedTestSetup] 
tags=[:helpers] begin - using ADTypes, Optimisers, Enzyme + using ADTypes, Optimisers function _loss_function(model, ps, st, data) y, st = model(data, ps, st) @@ -35,9 +32,9 @@ end @testset "$mode" for (mode, aType, dev, ongpu) in MODES model = Dense(3, 2) opt = Adam(0.01f0) + ps, st = Lux.setup(rng, model) |> dev - tstate = Lux.Experimental.TrainState( - Lux.replicate(rng), model, opt; transform_variables=dev) + tstate = Training.TrainState(model, ps, st, opt) x = randn(Lux.replicate(rng), Float32, (3, 1)) |> aType @@ -45,9 +42,8 @@ end ongpu && (ad isa AutoReverseDiff || ad isa AutoEnzyme) && continue !LuxTestUtils.ENZYME_TESTING_ENABLED && ad isa AutoEnzyme && continue - grads, _, _, _ = Lux.Experimental.compute_gradients( - ad, _loss_function, x, tstate) - tstate_ = Lux.Experimental.apply_gradients(tstate, grads) + grads, _, _, _ = Training.compute_gradients(ad, _loss_function, x, tstate) + tstate_ = Training.apply_gradients(tstate, grads) @test tstate_.step == 1 @test tstate != tstate_ end @@ -56,7 +52,6 @@ end @testitem "Training API" setup=[SharedTestSetup] tags=[:helpers] begin using ADTypes, Optimisers - import Enzyme, Tracker, ReverseDiff, Zygote mse = MSELoss() @@ -78,42 +73,34 @@ end ongpu && (ad isa AutoReverseDiff || ad isa AutoEnzyme) && continue !LuxTestUtils.ENZYME_TESTING_ENABLED && ad isa AutoEnzyme && continue - @test_deprecated Lux.Experimental.TrainState( - Lux.replicate(rng), model, opt; transform_variables=dev) - - tstate = Lux.Training.TrainState( - Lux.replicate(rng), model, opt; transform_variables=dev) + ps, st = Lux.setup(rng, model) |> dev + tstate = Training.TrainState(model, ps, st, opt) initial_loss = first(mse(model, tstate.parameters, tstate.states, dataset_[1])) - for epoch in 1:100, (x, y) in dataset_ + for epoch in 1:1000, (x, y) in dataset_ grads, loss, _, tstate = allow_unstable() do - Lux.Experimental.compute_gradients(ad, mse, (x, y), tstate) + Training.compute_gradients(ad, mse, (x, y), tstate) end - tstate = Lux.Experimental.apply_gradients!(tstate, grads) - end - - (x, y) = first(dataset_) - allow_unstable() do - @test_deprecated Lux.Experimental.compute_gradients(ad, mse, (x, y), tstate) + tstate = Training.apply_gradients!(tstate, grads) end - grads, loss, _, tstate = allow_unstable() do - Lux.Experimental.compute_gradients(ad, mse, (x, y), tstate) - end - @test_deprecated Lux.Experimental.apply_gradients(tstate, grads) - for epoch in 1:100, (x, y) in dataset_ + for epoch in 1:1000, (x, y) in dataset_ grads, loss, _, tstate = allow_unstable() do - Lux.Experimental.single_train_step!(ad, mse, (x, y), tstate) + Training.single_train_step!(ad, mse, (x, y), tstate) end end - for epoch in 1:100, (x, y) in dataset_ + for epoch in 1:1000, (x, y) in dataset_ grads, loss, _, tstate = allow_unstable() do - Lux.Experimental.single_train_step(ad, mse, (x, y), tstate) + Training.single_train_step(ad, mse, (x, y), tstate) end end + final_loss = first(mse(model, tstate.parameters, tstate.states, dataset_[1])) + + @test final_loss * 100 < initial_loss + # Test the adjust API tstate = Optimisers.adjust(tstate, 0.1f0) @test tstate.optimizer_state.layer_1.weight.rule.eta ≈ 0.1f0 @@ -126,67 +113,87 @@ end Optimisers.adjust!(tstate; eta=0.11f0) @test tstate.optimizer_state.layer_1.weight.rule.eta ≈ 0.11f0 - - final_loss = first(mse(model, tstate.parameters, tstate.states, dataset_[1])) - - @test final_loss * 100 < initial_loss end struct AutoCustomAD <: ADTypes.AbstractADType end - tstate = Lux.Experimental.TrainState( - Lux.replicate(rng), model, opt; 
transform_variables=dev) + ps, st = Lux.setup(rng, model) |> dev + tstate = Training.TrainState(model, ps, st, opt) - @test_throws ArgumentError Lux.Experimental.compute_gradients( + @test_throws ArgumentError Training.compute_gradients( AutoCustomAD(), mse, dataset_[1], tstate) end end -@testitem "Enzyme: Invalidate Cache on State Update" setup=[SharedTestSetup] tags=[:helpers] begin +@testitem "Enzyme: Invalidate Cache on State Update" setup=[SharedTestSetup] tags=[:helpers] skip=:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin using ADTypes, Optimisers - using Enzyme - if LuxTestUtils.ENZYME_TESTING_ENABLED - Enzyme.API.runtimeActivity!(true) + mse = MSELoss() - mse = MSELoss() - function mse2(model, ps, st, (x, y)) - z, st = model(x, ps, st) - return sum(abs2, z .- y), st, () - end + function mse2(model, ps, st, (x, y)) + z, st = model(x, ps, st) + return sum(abs2, z .- y), st, () + end - rng = StableRNG(12345) + rng = StableRNG(12345) - model = Chain(Dense(4 => 3), VariationalHiddenDropout(0.5f0), Dense(3 => 4)) - ps, st = Lux.setup(rng, model) - x = randn(rng, Float32, 4, 32) - opt = Adam(0.001f0) + model = Chain(Dense(4 => 3), VariationalHiddenDropout(0.5f0), Dense(3 => 4)) + ps, st = Lux.setup(rng, model) + x = randn(rng, Float32, 4, 32) + opt = Adam(0.001f0) - tstate = Lux.Experimental.TrainState(model, ps, st, opt) + tstate = Training.TrainState(model, ps, st, opt) - _, _, _, tstate_new = @inferred Lux.Experimental.compute_gradients( - AutoEnzyme(), mse, (x, x), tstate) + _, _, _, tstate_new = @inferred Training.compute_gradients( + AutoEnzyme(), mse, (x, x), tstate) - @test tstate_new.states !== tstate.states + @test tstate_new.states !== tstate.states - model = Chain(Dense(4 => 3), Dense(3 => 4)) - ps, st = Lux.setup(rng, model) + model = Chain(Dense(4 => 3), Dense(3 => 4)) + ps, st = Lux.setup(rng, model) - tstate = Lux.Experimental.TrainState(model, ps, st, opt) + tstate = Training.TrainState(model, ps, st, opt) - _, _, _, tstate_new = @inferred Lux.Experimental.compute_gradients( - AutoEnzyme(), mse, (x, x), tstate) + _, _, _, tstate_new = @inferred Training.compute_gradients( + AutoEnzyme(), mse, (x, x), tstate) - @test @inferred(Lux.Experimental.compute_gradients( - AutoEnzyme(), mse, (x, x), tstate_new)) isa Any + @test @inferred(Training.compute_gradients(AutoEnzyme(), mse, (x, x), tstate_new)) isa + Any - _, _, _, tstate_new2 = @inferred Lux.Experimental.compute_gradients( - AutoEnzyme(), mse2, (x, x), tstate_new) - @test hasfield(typeof(tstate_new2.cache.extras), :forward) - @test hasfield(typeof(tstate_new2.cache.extras), :reverse) - else - @test_broken false - end + _, _, _, tstate_new2 = @inferred Training.compute_gradients( + AutoEnzyme(), mse2, (x, x), tstate_new) + @test hasfield(typeof(tstate_new2.cache.extras), :forward) + @test hasfield(typeof(tstate_new2.cache.extras), :reverse) + + rng = StableRNG(12345) + + model = Chain(Dense(4 => 3), VariationalHiddenDropout(0.5f0), Dense(3 => 4)) + ps, st = Lux.setup(rng, model) + x = randn(rng, Float32, 4, 32) + opt = Adam(0.001f0) + + tstate = Training.TrainState(model, ps, st, opt) + + _, _, _, tstate_new = @inferred Training.compute_gradients( + AutoEnzyme(), mse, (x, x), tstate) + + @test tstate_new.states !== tstate.states + + model = Chain(Dense(4 => 3), Dense(3 => 4)) + ps, st = Lux.setup(rng, model) + + tstate = Training.TrainState(model, ps, st, opt) + + _, _, _, tstate_new = @inferred Training.compute_gradients( + AutoEnzyme(), mse, (x, x), tstate) + + @test 
@inferred(Training.compute_gradients(AutoEnzyme(), mse, (x, x), tstate_new)) isa + Any + + _, _, _, tstate_new2 = @inferred Training.compute_gradients( + AutoEnzyme(), mse2, (x, x), tstate_new) + @test hasfield(typeof(tstate_new2.cache.extras), :forward) + @test hasfield(typeof(tstate_new2.cache.extras), :reverse) end @testitem "Compiled ReverseDiff" setup=[SharedTestSetup] tags=[:helpers] begin @@ -207,22 +214,22 @@ end Dense(32, 32, tanh), BatchNorm(32), Dense(32, 4)) ps, st = Lux.setup(rng, model) - tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0)) + tstate = Training.TrainState(model, ps, st, Adam(0.001f0)) # Stateful models are not supported - @test_throws ArgumentError Lux.Experimental.compute_gradients( + @test_throws ArgumentError Training.compute_gradients( AutoReverseDiff(; compile=true), mse1, dataset[1], tstate) model = Chain(Dense(4, 32, tanh), Dense(32, 32, tanh), Dense(32, 4)) ps, st = Lux.setup(rng, model) - tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0)) + tstate = Training.TrainState(model, ps, st, Adam(0.001f0)) # Loss functions that return non-empty `stats` are not supported - @test_throws ArgumentError Lux.Experimental.compute_gradients( + @test_throws ArgumentError Training.compute_gradients( AutoReverseDiff(; compile=true), mse2, dataset[1], tstate) - struct StrangeModel <: Lux.AbstractExplicitLayer end + struct StrangeModel <: Lux.AbstractLuxLayer end function (m::StrangeModel)(x, ps, st) return x, (; new_state=0.0) @@ -231,23 +238,23 @@ end model = StrangeModel() ps, st = Lux.setup(rng, model) - tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0)) + tstate = Training.TrainState(model, ps, st, Adam(0.001f0)) # Stateful models are not supported - @test_throws ArgumentError Lux.Experimental.compute_gradients( + @test_throws ArgumentError Training.compute_gradients( AutoReverseDiff(; compile=true), mse1, dataset[1], tstate) end model = Chain(Dense(4, 32, tanh), Dense(32, 32, tanh), Dense(32, 4)) ps, st = Lux.setup(rng, model) - tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0)) + tstate = Training.TrainState(model, ps, st, Adam(0.001f0)) loss_initial = first(mse1(model, ps, st, dataset[1])) for i in 1:100 for (x, y) in dataset _, _, _, tstate = allow_unstable() do - Lux.Experimental.single_train_step!( + Training.single_train_step!( AutoReverseDiff(; compile=true), mse1, (x, y), tstate) end end diff --git a/test/layers/basic_tests.jl b/test/layers/basic_tests.jl index 7f1acd4880..7c06662971 100644 --- a/test/layers/basic_tests.jl +++ b/test/layers/basic_tests.jl @@ -9,7 +9,7 @@ x = randn(rng, 6, 3) |> aType @test size(layer(x, ps, st)[1]) == (2, 3, 3) - @test Lux.outputsize(layer) == (2, 3) + @test Lux.outputsize(layer, x, rng) == (2, 3) @jet layer(x, ps, st) @@ -103,45 +103,6 @@ __f = x -> sum(first(layer(x, ps, st))) test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) - - f11(x) = x .* x - - layer = WrappedFunction{:runtime_check}(f11) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - x = randn(rng, 2, 3) |> aType - - @test layer(x, ps, st)[1] ≈ x .* x - @test @inferred(layer(x, ps, st)) isa Any - - f12(x, ps, st) = x .+ 1, st - - layer = WrappedFunction{:runtime_check}(f12) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - x = randn(rng, 2, 3) |> aType - - @test layer(x, ps, st)[1] ≈ x .+ 1 - @test @inferred(layer(x, ps, st)) isa Any - end - - @testset "PeriodicEmbedding" begin - layer = PeriodicEmbedding([2, 3], [4.0, π / 5]) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - x 
= randn(rng, 6, 4, 3, 2) |> aType - Δx = [0.0, 12.0, -2π / 5, 0.0, 0.0, 0.0] |> aType - - val = layer(x, ps, st)[1] |> Array - shifted_val = layer(x .+ Δx, ps, st)[1] |> Array - - @test all(val[1:4, :, :, :] .== shifted_val[1:4, :, :, :]) && all(isapprox.( - val[5:8, :, :, :], shifted_val[5:8, :, :, :]; atol=5 * eps(Float32))) - - @jet layer(x, ps, st) - - __f = x -> sum(first(layer(x, ps, st))) - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, broken_backends=[AutoEnzyme()]) end end end @@ -155,7 +116,7 @@ end ps, st = Lux.setup(rng, layer) |> dev @test size(ps.weight) == (100, 10) - @test size(ps.bias) == (100, 1) + @test size(ps.bias) == (100,) @test layer.activation == identity layer = Dense(10, 100, relu; use_bias=false) @@ -165,13 +126,6 @@ end @test layer.activation == relu end - @testset "allow fast activation" begin - layer = Dense(10, 10, tanh) - @test layer.activation == tanh_fast - layer = Dense(10, 10, tanh; allow_fast_activation=false) - @test layer.activation == tanh - end - @testset "dimensions" begin layer = Dense(10, 5) ps, st = Lux.setup(rng, layer) @@ -179,42 +133,36 @@ end @test size(first(Lux.apply(layer, randn(10), ps, st))) == (5,) @test size(first(Lux.apply(layer, randn(10, 2), ps, st))) == (5, 2) - @test LuxCore.outputsize(layer) == (5,) + @test LuxCore.outputsize(layer, randn(10), rng) == (5,) end @testset "zeros" begin @test begin - layer = Dense(10, 1, identity; - init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...)) + layer = Dense(10, 1, identity; init_weight=ones32, init_bias=zeros32) first(Lux.apply( layer, ones(10, 1) |> aType, dev.(Lux.setup(rng, layer))...)) end == 10 * aType(ones(1, 1)) @test begin - layer = Dense(10, 1, identity; - init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...)) + layer = Dense(10, 1, identity; init_weight=ones32, init_bias=zeros32) first(Lux.apply( layer, ones(10, 2) |> aType, dev.(Lux.setup(rng, layer))...)) end == 10 * aType(ones(1, 2)) @test begin - layer = Dense(10, 2, identity; - init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...)) + layer = Dense(10, 2, identity; init_weight=ones32, init_bias=zeros32) first(Lux.apply( layer, ones(10, 1) |> aType, dev.(Lux.setup(rng, layer))...)) end == 10 * aType(ones(2, 1)) @test begin - layer = Dense(10, 2, identity; - init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...)) + layer = Dense(10, 2, identity; init_weight=ones32, init_bias=zeros32) first(Lux.apply(layer, aType([ones(10, 1) 2 * ones(10, 1)]), dev.(Lux.setup(rng, layer))...)) end == aType([10 20; 10 20]) @test begin - layer = Dense(10, 2, identity; - init_weight=(rng, args...; kwargs...) 
-> ones(args...; kwargs...), - use_bias=false) + layer = Dense(10, 2, identity; init_weight=ones32, use_bias=false) first(Lux.apply(layer, aType([ones(10, 1) 2 * ones(10, 1)]), dev.(Lux.setup(rng, layer))...)) end == aType([10 20; 10 20]) @@ -247,13 +195,6 @@ end @test layer.activation == relu end - @testset "allow fast activation" begin - layer = Scale(10, 5, tanh) - @test layer.activation == tanh_fast - layer = Scale(10, 5, tanh; allow_fast_activation=false) - @test layer.activation == tanh - end - @testset "dimensions" begin layer = Scale(10, 5) ps, st = Lux.setup(rng, layer) |> dev @@ -262,28 +203,24 @@ end @test size(first(Lux.apply(layer, randn(10, 5, 2) |> aType, ps, st))) == (10, 5, 2) - @test LuxCore.outputsize(layer) == (10, 5) + @test LuxCore.outputsize(layer, randn(10), rng) == (10, 5) end @testset "zeros" begin @test begin - layer = Scale(10, 1, identity; - init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...)) + layer = Scale(10, 1, identity; init_weight=ones32) first(Lux.apply( layer, ones(10, 1) |> aType, dev.(Lux.setup(rng, layer))...)) end == aType(ones(10, 1)) @test begin - layer = Scale(10, 1, identity; - init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...)) + layer = Scale(10, 1, identity; init_weight=ones32) first(Lux.apply( layer, ones(10, 2) |> aType, dev.(Lux.setup(rng, layer))...)) end == aType(ones(10, 2)) @test begin - layer = Scale(2, identity; - init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...), - init_bias=(rng, args...; kwargs...) -> ones(args...; kwargs...)) + layer = Scale(2, identity; init_weight=ones32, init_bias=ones32) first(Lux.apply( layer, [1 2; 3 4] |> aType, dev.(Lux.setup(rng, layer))...)) end == aType([2.0 3.0; 4.0 5.0]) @@ -355,14 +292,14 @@ end @testset "Two-streams zero sum" begin x = zeros(Float32, 2, 1) |> aType y = zeros(Float32, 1, 1) |> aType - layer = Bilinear((2, 1) => 3) + layer = Bilinear((2, 1) => 3; init_bias=zeros32) display(layer) ps, st = Lux.setup(rng, layer) |> dev @test size(layer((x, y), ps, st)[1]) == (3, 1) @test sum(abs2, layer((x, y), ps, st)[1]) == 0.0f0 - @test LuxCore.outputsize(layer) == (3,) + @test LuxCore.outputsize(layer, (x, y), rng) == (3,) @jet layer((x, y), ps, st) @@ -410,7 +347,7 @@ end @test size(ps.weight) == (embed_size, vocab_size) - @test LuxCore.outputsize(layer) == (4,) + @test LuxCore.outputsize(layer, nothing, rng) == (4,) x = rand(1:vocab_size, 1)[1] y, st_ = layer(x, ps, st) @@ -442,7 +379,7 @@ end @test size(ps.weight) == (embed_size, vocab_size...) 
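# --- Editor's note: illustrative sketch, not part of the diff. ---
# The `outputsize` assertions in this file (including the hunk immediately below) move
# from the old zero-argument `LuxCore.outputsize(layer)` to the three-argument form that
# takes a sample input and an RNG and, per the size-propagator tests above, returns the
# output size without the trailing batch dimension. A minimal usage sketch, assuming only
# Lux and Random are loaded:
using Lux, Random

layer = Dense(10 => 5)
x = randn(Float32, 10, 3)  # three samples with ten features each
@assert Lux.outputsize(layer, x, Random.default_rng()) == (5,)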
- @test LuxCore.outputsize(layer) == (4,) + @test LuxCore.outputsize(layer, nothing, rng) == (4,) x = (rand(1:vocab_size[1], 1)[1], rand(1:vocab_size[2], 1)[1]) y, st_ = layer(x, ps, st) diff --git a/test/layers/containers_tests.jl b/test/layers/containers_tests.jl index 6d1eaaed7f..e58296acc1 100644 --- a/test/layers/containers_tests.jl +++ b/test/layers/containers_tests.jl @@ -4,8 +4,7 @@ @testset "$mode" for (mode, aType, dev, ongpu) in MODES @testset "zero sum" begin - layer = SkipConnection( - WrappedFunction{:direct_call}(Broadcast.BroadcastFunction(zero)), .+) + layer = SkipConnection(WrappedFunction(Broadcast.BroadcastFunction(zero)), .+) display(layer) ps, st = Lux.setup(rng, layer) |> dev x = randn(rng, Float32, 10, 10, 10, 10) |> aType @@ -42,9 +41,7 @@ end @testset "$mode" for (mode, aType, dev, ongpu) in MODES @testset "zero sum" begin layer = Parallel( - +, WrappedFunction{:direct_call}(Broadcast.BroadcastFunction(zero)), - NoOpLayer()) - @test :layer_1 in keys(layer) && :layer_2 in keys(layer) + +, WrappedFunction(Broadcast.BroadcastFunction(zero)), NoOpLayer()) display(layer) ps, st = Lux.setup(rng, layer) |> dev x = randn(rng, 10, 10, 10, 10) |> aType @@ -134,7 +131,7 @@ end x::X end - struct L1 <: Lux.AbstractExplicitLayer end + struct L1 <: Lux.AbstractLuxLayer end (::L1)(x, ps, st) = (ps.x * x, st) Lux.initialparameters(rng::AbstractRNG, ::L1) = (x=randn(rng, Float32, 3, 3),) Base.:*(a::AbstractArray, b::Input) = a * b.x @@ -265,7 +262,7 @@ end x = rand(Float32, 10, 1) |> aType y, _ = layer(x, ps, st) @test size(y) == (1, 1) - @test Lux.outputsize(layer) == (1,) + @test Lux.outputsize(layer, x, rng) == (1,) @jet layer(x, ps, st) @@ -293,7 +290,7 @@ end x = rand(Float32, 10, 1) |> aType y, _ = layer(x, ps, st) @test size(y) == (2, 1) - @test Lux.outputsize(layer) == (2,) + @test Lux.outputsize(layer, x, rng) == (2,) @jet layer(x, ps, st) @@ -308,7 +305,7 @@ end x = rand(Float32, 10, 1) |> aType y, _ = layer(x, ps, st) @test size(y) == (2, 1) - @test Lux.outputsize(layer) == (2,) + @test Lux.outputsize(layer, x, rng) == (2,) @jet layer(x, ps, st) @@ -323,17 +320,13 @@ end x = rand(Float32, 10, 1) |> aType y, _ = layer(x, ps, st) @test size(y) == (5, 1) - @test Lux.outputsize(layer) == (5,) + @test Lux.outputsize(layer, x, rng) == (5,) @jet layer(x, ps, st) __f = (x, ps) -> sum(first(layer(x, ps, st))) test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) - @test_throws ArgumentError Chain(; - l1=Dense(10 => 5, sigmoid), d52=Dense(5 => 2, tanh), - d21=Dense(2 => 1), d2=Dense(2 => 1), disable_optimizations=false) - @testset "indexing and field access" begin encoder = Chain(Dense(10 => 5, sigmoid), Dense(5 => 2, tanh)) decoder = Chain(Dense(2 => 5, tanh), Dense(5 => 10, sigmoid)) @@ -342,8 +335,6 @@ end @test encoder[2] == encoder.layer_2 @test autoencoder[1] == autoencoder.encoder @test autoencoder[2] == autoencoder.decoder - @test keys(encoder) == (:layer_1, :layer_2) - @test keys(autoencoder) == (:encoder, :decoder) @test autoencoder.layers isa NamedTuple @test autoencoder.encoder isa Chain @test_throws ArgumentError autoencoder.layer_1 @@ -351,13 +342,11 @@ end end @testset "constructors" begin - @test Chain([Dense(10 => 5, sigmoid)]) == Dense(10 => 5, sigmoid) - - f1(x, ps, st::NamedTuple) = (x .+ 1, st) + f1(x) = x .+ 1 f2(x) = x .+ 2 model = Chain((Dense(2 => 3), Dense(3 => 2)), f1, f2, NoOpLayer()) - @test length(model) == 4 + @test length(model) == 5 x = rand(Float32, 2, 5) ps, st = Lux.setup(rng, model) diff --git a/test/layers/conv_tests.jl 
b/test/layers/conv_tests.jl index eb6b0974cf..18a4dd4b39 100644 --- a/test/layers/conv_tests.jl +++ b/test/layers/conv_tests.jl @@ -1,101 +1,3 @@ -@testitem "Pooling" setup=[SharedTestSetup] tags=[:core_layers] begin - rng = StableRNG(12345) - - @testset "$mode" for (mode, aType, dev, ongpu) in MODES - x = randn(rng, Float32, 10, 10, 3, 2) |> aType - y = randn(rng, Float32, 20, 20, 3, 2) |> aType - - layer = AdaptiveMaxPool((5, 5)) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - - @test layer(x, ps, st)[1] == maxpool(x, PoolDims(x, 2)) - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) - - layer = AdaptiveMeanPool((5, 5)) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - - @test layer(x, ps, st)[1] == meanpool(x, PoolDims(x, 2)) - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) - - layer = AdaptiveMaxPool((10, 5)) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - - @test layer(y, ps, st)[1] == maxpool(y, PoolDims(y, (2, 4))) - @jet layer(y, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) - - layer = AdaptiveMeanPool((10, 5)) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - - @test layer(y, ps, st)[1] == meanpool(y, PoolDims(y, (2, 4))) - @jet layer(y, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) - - layer = GlobalMaxPool() - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - - @test size(layer(x, ps, st)[1]) == (1, 1, 3, 2) - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) - - layer = GlobalMeanPool() - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - - @test size(layer(x, ps, st)[1]) == (1, 1, 3, 2) - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) - - layer = MaxPool((2, 2)) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - - @test layer(x, ps, st)[1] == maxpool(x, PoolDims(x, 2)) - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) - - layer = MeanPool((2, 2)) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - - @test layer(x, ps, st)[1] == meanpool(x, PoolDims(x, 2)) - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) - - @testset "$ltype SamePad windowsize $k" for ltype in (MeanPool, MaxPool), - k in ((1,), (2,), (3,), (4, 5), (6, 7, 8)) - - x = ones(Float32, (k .+ 3)..., 1, 1) |> aType - - layer = ltype(k; pad=Lux.SamePad()) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - - @test size(layer(x, ps, st)[1])[1:(end - 2)] == cld.(size(x)[1:(end - 2)], k) - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - - soft_fail = ltype == MaxPool ? 
[AutoFiniteDiff()] : [] - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, soft_fail) - end - end -end - @testitem "CNN" setup=[SharedTestSetup] tags=[:core_layers] begin rng = StableRNG(12345) @@ -141,12 +43,8 @@ end test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) # Test that we cannot ask for non-integer multiplication factors - layer = Conv((2, 2), 3 => 10; groups=2) - display(layer) - @test_throws DimensionMismatch Lux.setup(rng, layer) - layer = Conv((2, 2), 2 => 9; groups=2) - display(layer) - @test_throws DimensionMismatch Lux.setup(rng, layer) + @test_throws DimensionMismatch Conv((2, 2), 3 => 10; groups=2) + @test_throws DimensionMismatch Conv((2, 2), 2 => 9; groups=2) @testset "Segfault Test LuxDL/Lux.jl#386" begin layer = Conv((5,), 32 => 32, tanh; groups=32) @@ -189,7 +87,7 @@ end display(layer) ps, st = Lux.setup(rng, layer) @test ps.weight isa aType{Float64, 4} - @test ps.bias isa aType{Float16, 4} + @test ps.bias isa aType{Float16, 1} end @testset "Depthwise Conv" begin @@ -228,9 +126,7 @@ end test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) # Test that we cannot ask for non-integer multiplication factors - layer = Conv((2, 2), 3 => 10; groups=3) - display(layer) - @test_throws DimensionMismatch Lux.setup(rng, layer) + @test_throws DimensionMismatch Conv((2, 2), 3 => 10; groups=3) end @testset "Conv SamePad kernelsize $k" for k in ((1,), (2,), (3,), (2, 3), (1, 2, 3)) @@ -261,7 +157,7 @@ end x[4, 4, 1, 1] = 1 x = x |> aType - layer = Conv((3, 3), 1 => 1) + layer = Conv((3, 3), 1 => 1; use_bias=false) display(layer) ps, st = Lux.setup(rng, layer) |> dev @@ -271,7 +167,7 @@ end @jet layer(x, ps, st) - layer = Conv((3, 1), 1 => 1) + layer = Conv((3, 1), 1 => 1; use_bias=false) display(layer) ps, st = Lux.setup(rng, layer) |> dev @@ -281,7 +177,7 @@ end @jet layer(x, ps, st) - layer = Conv((1, 3), 1 => 1) + layer = Conv((1, 3), 1 => 1; use_bias=false) display(layer) ps, st = Lux.setup(rng, layer) |> dev @@ -291,7 +187,7 @@ end @jet layer(x, ps, st) - layer = Conv((1, 3), 1 => 1; init_weight=Lux.glorot_normal) + layer = Conv((1, 3), 1 => 1; init_weight=Lux.glorot_normal, use_bias=false) display(layer) ps, st = Lux.setup(rng, layer) |> dev @@ -301,13 +197,6 @@ end @jet layer(x, ps, st) end - - @testset "allow fast activation" begin - layer = Conv((3, 3), 1 => 1, tanh) - @test layer.activation == tanh_fast - layer = Conv((3, 3), 1 => 1, tanh; allow_fast_activation=false) - @test layer.activation == tanh - end end end @@ -417,12 +306,12 @@ end end end -@testitem "CrossCor" setup=[SharedTestSetup] tags=[:core_layers] begin +@testitem "Conv(cross_correlation=true)" setup=[SharedTestSetup] tags=[:core_layers] begin rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES @testset "Asymmetric Padding" begin - layer = CrossCor((3, 3), 1 => 1, relu; pad=(0, 1, 1, 2)) + layer = Conv((3, 3), 1 => 1, relu; pad=(0, 1, 1, 2), cross_correlation=true) display(layer) x = ones(Float32, 28, 28, 1, 1) |> aType ps, st = Lux.setup(rng, layer) |> dev @@ -443,23 +332,23 @@ end end @testset "Variable BitWidth Parameters FluxML/Flux.jl#1421" begin - layer = CrossCor((5, 5), 10 => 20, identity; + layer = Conv((5, 5), 10 => 20, identity; init_weight=(rng, dims...) -> aType(randn(rng, Float64, dims...)), - init_bias=(rng, dims...) -> aType(randn(rng, Float16, dims...))) + init_bias=(rng, dims...) 
-> aType(randn(rng, Float16, dims...)), + cross_correlation=true) display(layer) ps, st = Lux.setup(rng, layer) @test ps.weight isa aType{Float64, 4} - @test ps.bias isa aType{Float16, 4} + @test ps.bias isa aType{Float16, 1} end - @testset "CrossCor SamePad kernelsize $k" for k in ( - (1,), (2,), (3,), (2, 3), (1, 2, 3)) + @testset "SamePad kernelsize $k" for k in ((1,), (2,), (3,), (2, 3), (1, 2, 3)) x = ones(Float32, (k .+ 3)..., 1, 1) |> aType @testset "Kwargs: $kwarg" for kwarg in ( (; stride=1), (; dilation=max.(k .÷ 2, 1), stride=1), (; stride=3), (; stride=1, use_bias=false)) - layer = CrossCor(k, 1 => 1; pad=Lux.SamePad(), kwarg...) + layer = Conv(k, 1 => 1; pad=Lux.SamePad(), kwarg..., cross_correlation=true) display(layer) ps, st = Lux.setup(rng, layer) |> dev @@ -476,13 +365,6 @@ end test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) end end - - @testset "allow fast activation" begin - layer = CrossCor((3, 3), 1 => 1, tanh) - @test layer.activation == tanh_fast - layer = CrossCor((3, 3), 1 => 1, tanh; allow_fast_activation=false) - @test layer.activation == tanh - end end end @@ -490,152 +372,191 @@ end rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES - x = randn(Float32, 5, 5, 1, 1) |> aType - layer = Conv((3, 3), 1 => 1) - ps, st = Lux.setup(rng, layer) |> dev - y = layer(x, ps, st)[1] - - layer = ConvTranspose((3, 3), 1 => 1) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - - @jet layer(y, ps, st) + @testset for cross_correlation in (true, false) + x = randn(Float32, 5, 5, 1, 1) |> aType + layer = Conv((3, 3), 1 => 1) + ps, st = Lux.setup(rng, layer) |> dev + y = layer(x, ps, st)[1] - x_hat1 = layer(y, ps, st)[1] + layer = ConvTranspose((3, 3), 1 => 1; cross_correlation) + display(layer) + ps, st = Lux.setup(rng, layer) |> dev - layer = ConvTranspose((3, 3), 1 => 1; use_bias=false) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev + @jet layer(y, ps, st) - @jet layer(y, ps, st) + x_hat1 = layer(y, ps, st)[1] - x_hat2 = layer(y, ps, st)[1] + layer = ConvTranspose((3, 3), 1 => 1; use_bias=false, cross_correlation) + display(layer) + ps, st = Lux.setup(rng, layer) |> dev - @test size(x_hat1) == size(x_hat2) == size(x) + @jet layer(y, ps, st) - layer = ConvTranspose((3, 3), 1 => 1) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - x = rand(Float32, 5, 5, 1, 1) |> aType + x_hat2 = layer(y, ps, st)[1] - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) + @test size(x_hat1) == size(x_hat2) == size(x) - x = rand(Float32, 5, 5, 2, 4) |> aType - layer = ConvTranspose((3, 3), 2 => 3) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev + layer = ConvTranspose((3, 3), 1 => 1; cross_correlation) + display(layer) + ps, st = Lux.setup(rng, layer) |> dev + x = rand(Float32, 5, 5, 1, 1) |> aType - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) - - # test ConvTranspose supports groups argument - x = randn(Float32, 10, 10, 2, 3) |> aType - layer1 = ConvTranspose((3, 3), 2 => 4; pad=SamePad()) - display(layer1) - ps1, st1 = Lux.setup(rng, layer1) |> dev - @test size(ps1.weight) == (3, 3, 4, 2) - @test size(layer1(x, ps1, st1)[1]) == (10, 10, 4, 3) - - layer2 = ConvTranspose((3, 3), 2 => 4; groups=2, pad=SamePad()) - display(layer2) - ps2, st2 = Lux.setup(rng, layer2) |> dev - @test size(ps2.weight) == (3, 3, 2, 2) - @test size(layer1(x, ps1, st1)[1]) == size(layer2(x, ps2, st2)[1]) - 
- __f = (x, ps) -> sum(first(layer1(x, ps, st1))) - test_gradients(__f, x, ps1; atol=1.0f-3, rtol=1.0f-3) - - __f = (x, ps) -> sum(first(layer2(x, ps, st2))) - test_gradients(__f, x, ps2; atol=1.0f-3, rtol=1.0f-3) - - x = randn(Float32, 10, 2, 1) |> aType - layer = ConvTranspose((3,), 2 => 4; pad=SamePad(), groups=2) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) - @jet layer(x, ps, st) + x = rand(Float32, 5, 5, 2, 4) |> aType + layer = ConvTranspose((3, 3), 2 => 3; cross_correlation) + display(layer) + ps, st = Lux.setup(rng, layer) |> dev - @test size(layer(x, ps, st)[1]) == (10, 4, 1) - @test length(ps.weight) == 3 * (2 * 4) / 2 + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) + # test ConvTranspose supports groups argument + x = randn(Float32, 10, 10, 2, 3) |> aType + layer1 = ConvTranspose((3, 3), 2 => 4; pad=SamePad(), cross_correlation) + display(layer1) + ps1, st1 = Lux.setup(rng, layer1) |> dev + @test size(ps1.weight) == (3, 3, 4, 2) + @test size(layer1(x, ps1, st1)[1]) == (10, 10, 4, 3) + + layer2 = ConvTranspose( + (3, 3), 2 => 4; groups=2, pad=SamePad(), cross_correlation) + display(layer2) + ps2, st2 = Lux.setup(rng, layer2) |> dev + @test size(ps2.weight) == (3, 3, 2, 2) + @test size(layer1(x, ps1, st1)[1]) == size(layer2(x, ps2, st2)[1]) + + __f = (x, ps) -> sum(first(layer1(x, ps, st1))) + test_gradients(__f, x, ps1; atol=1.0f-3, rtol=1.0f-3) + + __f = (x, ps) -> sum(first(layer2(x, ps, st2))) + test_gradients(__f, x, ps2; atol=1.0f-3, rtol=1.0f-3) + + x = randn(Float32, 10, 2, 1) |> aType + layer = ConvTranspose((3,), 2 => 4; pad=SamePad(), groups=2, cross_correlation) + display(layer) + ps, st = Lux.setup(rng, layer) |> dev - x = randn(Float32, 10, 11, 4, 2) |> aType - layer = ConvTranspose((3, 5), 4 => 4; pad=SamePad(), groups=4) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev + @jet layer(x, ps, st) - @jet layer(x, ps, st) + @test size(layer(x, ps, st)[1]) == (10, 4, 1) + @test length(ps.weight) == 3 * (2 * 4) / 2 - @test size(layer(x, ps, st)[1]) == (10, 11, 4, 2) - @test length(ps.weight) == (3 * 5) * (4 * 4) / 4 + __f = (x, ps) -> sum(first(layer(x, ps, st))) + test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) + x = randn(Float32, 10, 11, 4, 2) |> aType + layer = ConvTranspose( + (3, 5), 4 => 4; pad=SamePad(), groups=4, cross_correlation) + display(layer) + ps, st = Lux.setup(rng, layer) |> dev - x = randn(Float32, 10, 11, 4, 2) |> aType - layer = ConvTranspose((3, 5), 4 => 4, tanh; pad=SamePad(), groups=4) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev + @jet layer(x, ps, st) - @jet layer(x, ps, st) - @test size(layer(x, ps, st)[1]) == (10, 11, 4, 2) - @test length(ps.weight) == (3 * 5) * (4 * 4) / 4 + @test size(layer(x, ps, st)[1]) == (10, 11, 4, 2) + @test length(ps.weight) == (3 * 5) * (4 * 4) / 4 - __f = (x, ps) -> sum(first(layer(x, ps, st))) - test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) - x = randn(Float32, 10, 11, 12, 3, 2) |> aType - layer = ConvTranspose((3, 5, 3), 3 => 6; pad=SamePad(), groups=3) - 
display(layer) - ps, st = Lux.setup(rng, layer) |> dev + x = randn(Float32, 10, 11, 4, 2) |> aType + layer = ConvTranspose( + (3, 5), 4 => 4, tanh; pad=SamePad(), groups=4, cross_correlation) + display(layer) + ps, st = Lux.setup(rng, layer) |> dev - @jet layer(x, ps, st) - @test size(layer(x, ps, st)[1]) == (10, 11, 12, 6, 2) - @test length(ps.weight) == (3 * 5 * 3) * (3 * 6) / 3 + @jet layer(x, ps, st) + @test size(layer(x, ps, st)[1]) == (10, 11, 4, 2) + @test length(ps.weight) == (3 * 5) * (4 * 4) / 4 - x = randn(Float32, 10, 11, 12, 3, 2) |> aType - layer = ConvTranspose((3, 5, 3), 3 => 6, tanh; pad=SamePad(), groups=3) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev + __f = (x, ps) -> sum(first(layer(x, ps, st))) + test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) - @jet layer(x, ps, st) - @test size(layer(x, ps, st)[1]) == (10, 11, 12, 6, 2) - @test length(ps.weight) == (3 * 5 * 3) * (3 * 6) / 3 + x = randn(Float32, 10, 11, 12, 3, 2) |> aType + layer = ConvTranspose( + (3, 5, 3), 3 => 6; pad=SamePad(), groups=3, cross_correlation) + display(layer) + ps, st = Lux.setup(rng, layer) |> dev - @test occursin("groups=2", sprint(show, ConvTranspose((3, 3), 2 => 4; groups=2))) - @test occursin("2 => 4", sprint(show, ConvTranspose((3, 3), 2 => 4; groups=2))) + @jet layer(x, ps, st) + @test size(layer(x, ps, st)[1]) == (10, 11, 12, 6, 2) + @test length(ps.weight) == (3 * 5 * 3) * (3 * 6) / 3 - @testset "SamePad size mismatch LuxDL/Lux.jl#534" begin - layer = ConvTranspose((3,), 2 => 1; pad=SamePad(), stride=2) + x = randn(Float32, 10, 11, 12, 3, 2) |> aType + layer = ConvTranspose( + (3, 5, 3), 3 => 6, tanh; pad=SamePad(), groups=3, cross_correlation) display(layer) - x = ones(Float32, 2, 2, 1) |> aType ps, st = Lux.setup(rng, layer) |> dev - y = first(layer(x, ps, st)) - @test size(y) == (4, 1, 1) @jet layer(x, ps, st) - end + @test size(layer(x, ps, st)[1]) == (10, 11, 12, 6, 2) + @test length(ps.weight) == (3 * 5 * 3) * (3 * 6) / 3 - @testset "Catch Channel Mismatch Early: LuxDL/Lux.jl#455" begin - layer = ConvTranspose((4, 4), 42 => 16; stride=2, pad=1) + @test occursin("groups=2", + sprint(show, ConvTranspose((3, 3), 2 => 4; groups=2, cross_correlation))) + @test occursin("2 => 4", + sprint(show, ConvTranspose((3, 3), 2 => 4; groups=2, cross_correlation))) - x = randn(Float32, 28, 28, 42, 3) |> aType - ps, st = Lux.setup(rng, layer) |> dev + @testset "SamePad size mismatch LuxDL/Lux.jl#534" begin + layer = ConvTranspose( + (3,), 2 => 1; pad=SamePad(), stride=2, cross_correlation) + display(layer) + x = ones(Float32, 2, 2, 1) |> aType + ps, st = Lux.setup(rng, layer) |> dev + + y = first(layer(x, ps, st)) + @test size(y) == (4, 1, 1) + @jet layer(x, ps, st) + end + + @testset "Catch Channel Mismatch Early: LuxDL/Lux.jl#455" begin + layer = ConvTranspose((4, 4), 42 => 16; stride=2, pad=1, cross_correlation) + + x = randn(Float32, 28, 28, 42, 3) |> aType + ps, st = Lux.setup(rng, layer) |> dev - @test layer(x, ps, st) isa Any + @test layer(x, ps, st) isa Any - x = randn(Float32, 28, 28, 46, 3) |> aType + x = randn(Float32, 28, 28, 46, 3) |> aType - @test_throws DimensionMismatch layer(x, ps, st) + @test_throws DimensionMismatch layer(x, ps, st) - x = randn(Float32, 28, 28, 23, 3) |> aType + x = randn(Float32, 28, 28, 23, 3) |> aType - @test_throws DimensionMismatch layer(x, ps, st) + @test_throws DimensionMismatch layer(x, ps, st) + end + + @testset "with Output Padding" begin + m1 = ConvTranspose((3, 5), 3 => 6; stride=3, cross_correlation) + display(m1) + m2 = ConvTranspose( + 
(3, 5), 3 => 6; stride=3, outpad=(1, 0), cross_correlation) + display(m2) + + ps1, st1 = Lux.setup(rng, m1) |> dev + ps2, st2 = Lux.setup(rng, m2) |> dev + + x = randn(Float32, 10, 11, 3, 2) |> aType + @test size(m1(x, ps1, st1)[1])[1:2] .+ (1, 0) == + size(m2(x, ps2, st2)[1])[1:2] + + m1 = ConvTranspose((3, 5, 3), 3 => 6; stride=3, cross_correlation) + display(m1) + m2 = ConvTranspose( + (3, 5, 3), 3 => 6; stride=3, outpad=(1, 0, 1), cross_correlation) + display(m2) + + ps1, st1 = Lux.setup(rng, m1) |> dev + ps2, st2 = Lux.setup(rng, m2) |> dev + + x = randn(Float32, 10, 11, 12, 3, 2) |> aType + + @test size(m1(x, ps1, st1)[1])[1:3] .+ (1, 0, 1) == + size(m2(x, ps2, st2)[1])[1:3] + end end end end diff --git a/test/layers/dropout_tests.jl b/test/layers/dropout_tests.jl index 5d377d9b6e..62ecafcd0b 100644 --- a/test/layers/dropout_tests.jl +++ b/test/layers/dropout_tests.jl @@ -64,9 +64,6 @@ end end @testitem "VariationalHiddenDropout" setup=[SharedTestSetup] tags=[:normalize_layers] begin - using Enzyme - Enzyme.API.runtimeActivity!(true) - rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES diff --git a/test/layers/dynamic_expressions_tests.jl b/test/layers/dynamic_expressions_tests.jl deleted file mode 100644 index fae956c365..0000000000 --- a/test/layers/dynamic_expressions_tests.jl +++ /dev/null @@ -1,60 +0,0 @@ -@testitem "Dynamic Expressions" setup=[SharedTestSetup] tags=[:others] begin - using DynamicExpressions, ForwardDiff, ComponentArrays, Bumper - - operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos]) - - x1 = Node(; feature=1) - x2 = Node(; feature=2) - - expr_1 = x1 * cos(x2 - 3.2) - expr_2 = x2 - x1 * x2 + 2.5 - 1.0 * x1 - - for exprs in ((expr_1,), (expr_1, expr_2), ([expr_1, expr_2],)), - turbo in (Val(false), Val(true)), - bumper in (Val(false), Val(true)) - - layer = DynamicExpressionsLayer(operators, exprs...; turbo, bumper) - ps, st = Lux.setup(Random.default_rng(), layer) - - x = [1.0f0 2.0f0 3.0f0 - 4.0f0 5.0f0 6.0f0] - - y, st_ = layer(x, ps, st) - @test eltype(y) == Float32 - __f = (x, p) -> sum(abs2, first(layer(x, p, st))) - test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3, skip_backends=[AutoEnzyme()]) - - # Particular ForwardDiff dispatches - ps_ca = ComponentArray(ps) - dps_ca = ForwardDiff.gradient(ps_ca) do ps_ - sum(abs2, first(layer(x, ps_, st))) - end - dx = ForwardDiff.gradient(x) do x_ - sum(abs2, first(layer(x_, ps, st))) - end - dxps = ForwardDiff.gradient(ComponentArray(; x=x, ps=ps)) do ca - sum(abs2, first(layer(ca.x, ca.ps, st))) - end - - @test dx≈dxps.x atol=1.0f-3 rtol=1.0f-3 - @test dps_ca≈dxps.ps atol=1.0f-3 rtol=1.0f-3 - - x = Float64.(x) - y, st_ = layer(x, ps, st) - @test eltype(y) == Float64 - __f = (x, p) -> sum(abs2, first(layer(x, p, st))) - test_gradients(__f, x, ps; atol=1.0e-3, rtol=1.0e-3, skip_backends=[AutoEnzyme()]) - end - - @testset "$(mode)" for (mode, aType, dev, ongpu) in MODES - layer = DynamicExpressionsLayer(operators, expr_1) - ps, st = Lux.setup(Random.default_rng(), layer) |> dev - - x = [1.0f0 2.0f0 3.0f0 - 4.0f0 5.0f0 6.0f0] |> aType - - if ongpu - @test_throws ArgumentError layer(x, ps, st) - end - end -end diff --git a/test/layers/normalize_tests.jl b/test/layers/normalize_tests.jl index a598ae8972..46474dbf1a 100644 --- a/test/layers/normalize_tests.jl +++ b/test/layers/normalize_tests.jl @@ -15,7 +15,7 @@ @test ps.scale == [1, 1] |> aType # init_scale(2) y, st_ = pullback(m, x, ps, st)[1] - st_ = st_ |> LuxCPUDevice() + st_ = st_ |> CPUDevice() @test 
check_approx(Array(y), [-1.22474 0 1.22474; -1.22474 0 1.22474]; atol=1.0e-5) # julia> x # 2×3 Array{Float64,2}: @@ -39,7 +39,7 @@ 0.1 .* var(Array(x); dims=2, corrected=false) .* (3 / 2) .+ 0.9 .* [1.0, 1.0]) st_ = Lux.testmode(st_) |> device - x_ = m(x, ps, st_)[1] |> LuxCPUDevice() + x_ = m(x, ps, st_)[1] |> CPUDevice() @test check_approx(x_[1], (1 .- 0.3) / sqrt(1.3), atol=1.0e-5) @jet m(x, ps, st) @@ -101,13 +101,6 @@ @jet m(x, ps, st) end - - @testset "allow fast activation" begin - layer = BatchNorm(10, tanh) - @test layer.activation == tanh_fast - layer = BatchNorm(10, tanh; allow_fast_activation=false) - @test layer.activation == tanh - end end end @@ -135,7 +128,7 @@ end __f = let m = m, x = x, st = st ps -> sum(first(m(x, ps, st))) end - test_gradients(__f, ps; atol=1.0f-3, rtol=1.0f-3, broken_backends=[AutoEnzyme()]) + test_gradients(__f, ps; atol=1.0f-3, rtol=1.0f-3) @testset "affine: $affine" for affine in (true, false) m = GroupNorm(2, 2; affine) @@ -192,13 +185,6 @@ end end @test_throws ArgumentError GroupNorm(5, 2) - - @testset "allow fast activation" begin - layer = GroupNorm(10, 2, tanh) - @test layer.activation == tanh_fast - layer = GroupNorm(10, 2, tanh; allow_fast_activation=false) - @test layer.activation == tanh - end end end @@ -302,7 +288,7 @@ end # See https://github.com/LuxDL/Lux.jl/issues/95 @testset "Normalizing Zero Parameters" begin - c = Conv((3, 3), 3 => 3) + c = Conv((3, 3), 3 => 3; init_bias=zeros32) wn = WeightNorm(c, (:weight, :bias)) @test_throws ArgumentError Lux.setup(rng, wn) @@ -371,13 +357,6 @@ end end end end - - @testset "allow fast activation" begin - layer = LayerNorm((3, 1), tanh) - @test layer.activation == tanh_fast - layer = LayerNorm((3, 1), tanh; allow_fast_activation=false) - @test layer.activation == tanh - end end end @@ -388,10 +367,10 @@ end for x in (randn(rng, Float32, 3, 3, 3, 2), randn(rng, Float32, 3, 3, 2), randn(rng, Float32, 3, 3, 3, 3, 2)) x = x |> aType - for affine in (true, false) - layer = InstanceNorm(3; affine) + for affine in (true, false), track_stats in (true, false) + layer = InstanceNorm(3; affine, track_stats) display(layer) - ps, st = Lux.setup(rng, layer) .|> device + ps, st = Lux.setup(rng, layer) |> device y, st_ = layer(x, ps, st) @@ -408,9 +387,9 @@ end end for act in (sigmoid, tanh) - layer = InstanceNorm(3, act; affine) + layer = InstanceNorm(3, act; affine, track_stats) display(layer) - ps, st = Lux.setup(rng, layer) .|> device + ps, st = Lux.setup(rng, layer) |> device y, st_ = layer(x, ps, st) @@ -428,12 +407,5 @@ end end end end - - @testset "allow fast activation" begin - layer = InstanceNorm(3, tanh) - @test layer.activation == tanh_fast - layer = InstanceNorm(3, tanh; allow_fast_activation=false) - @test layer.activation == tanh - end end end diff --git a/test/layers/pooling_tests.jl b/test/layers/pooling_tests.jl new file mode 100644 index 0000000000..2824de30ee --- /dev/null +++ b/test/layers/pooling_tests.jl @@ -0,0 +1,78 @@ +@testitem "Pooling" setup=[SharedTestSetup] tags=[:core_layers] begin + rng = StableRNG(12345) + + nnlib_op = Dict(:LPPool => (args...) -> lpnormpool(args...; p=2), + :MeanPool => meanpool, :MaxPool => maxpool) + + @testset "$mode" for (mode, aType, dev, ongpu) in MODES + @testset for ltype in (:LPPool, :MeanPool, :MaxPool) + if ongpu && ltype == :LPPool + @test_broken false + continue + end + + broken_backends = ltype == :LPPool ? 
[AutoTracker(), AutoEnzyme()] : [] + + adaptive_ltype = Symbol(:Adaptive, ltype) + global_ltype = Symbol(:Global, ltype) + + x = randn(rng, Float32, 10, 10, 3, 2) |> aType + y = randn(rng, Float32, 20, 20, 3, 2) |> aType + + layer = getfield(Lux, adaptive_ltype)((5, 5)) + display(layer) + ps, st = Lux.setup(rng, layer) |> dev + + @test size(layer(x, ps, st)[1]) == (5, 5, 3, 2) + @test layer(x, ps, st)[1] == nnlib_op[ltype](x, PoolDims(x, 2)) + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, broken_backends) + + layer = getfield(Lux, adaptive_ltype)((10, 5)) + display(layer) + ps, st = Lux.setup(rng, layer) |> dev + + @test size(layer(y, ps, st)[1]) == (10, 5, 3, 2) + @test layer(y, ps, st)[1] == nnlib_op[ltype](y, PoolDims(y, (2, 4))) + @jet layer(y, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, broken_backends) + + layer = getfield(Lux, global_ltype)() + display(layer) + ps, st = Lux.setup(rng, layer) |> dev + + @test size(layer(x, ps, st)[1]) == (1, 1, 3, 2) + @test layer(x, ps, st)[1] == nnlib_op[ltype](x, PoolDims(x, size(x)[1:2])) + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, broken_backends) + + layer = getfield(Lux, ltype)((2, 2)) + display(layer) + ps, st = Lux.setup(rng, layer) |> dev + + @test layer(x, ps, st)[1] == nnlib_op[ltype](x, PoolDims(x, 2)) + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, broken_backends) + + @testset "SamePad windowsize $k" for k in ((1,), (2,), (3,), (4, 5), (6, 7, 8)) + x = ones(Float32, (k .+ 3)..., 1, 1) |> aType + + layer = getfield(Lux, ltype)(k; pad=Lux.SamePad()) + display(layer) + ps, st = Lux.setup(rng, layer) |> dev + + @test size(layer(x, ps, st)[1])[1:(end - 2)] == + cld.(size(x)[1:(end - 2)], k) + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + + soft_fail = ltype == :MaxPool ? [AutoFiniteDiff()] : [] + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, soft_fail, broken_backends) + end + end + end +end diff --git a/test/layers/recurrent_tests.jl b/test/layers/recurrent_tests.jl index e2970027bb..91b5ace684 100644 --- a/test/layers/recurrent_tests.jl +++ b/test/layers/recurrent_tests.jl @@ -1,15 +1,14 @@ @testitem "RNNCell" setup=[SharedTestSetup] tags=[:recurrent_layers] begin - Enzyme.API.runtimeActivity!(true) rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES - for rnncell in (RNNCell(3 => 5, identity), RNNCell(3 => 5, tanh), + @testset for rnncell in (RNNCell(3 => 5, identity), RNNCell(3 => 5, tanh), RNNCell(3 => 5, tanh; use_bias=false), RNNCell(3 => 5, identity; use_bias=false), RNNCell(3 => 5, identity; use_bias=false, train_state=false)) display(rnncell) - ps, st = Lux.setup(rng, rnncell) .|> dev - for x_size in ((3, 2), (3,)) + ps, st = Lux.setup(rng, rnncell) |> dev + @testset for x_size in ((3, 2), (3,)) x = randn(rng, Float32, x_size...) 
|> aType (y, carry), st_ = Lux.apply(rnncell, x, ps, st) @@ -18,7 +17,7 @@ function loss_loop_rnncell(p) (y, carry), st_ = rnncell(x, p, st) - for i in 1:10 + for _ in 1:10 (y, carry), st_ = rnncell((x, carry), p, st_) end return sum(abs2, y) @@ -26,20 +25,21 @@ @test_throws ErrorException ps.train_state - test_gradients(loss_loop_rnncell, ps; atol=1.0f-3, - rtol=1.0f-3, broken_backends=[AutoEnzyme()]) + test_gradients(loss_loop_rnncell, ps; atol=1.0f-3, rtol=1.0f-3, + soft_fail=[AutoFiniteDiff()], broken_backends=[AutoEnzyme()]) end end @testset "Trainable hidden states" begin - for rnncell in (RNNCell(3 => 5, identity; use_bias=false, train_state=true), + @testset for rnncell in ( + RNNCell(3 => 5, identity; use_bias=false, train_state=true), RNNCell(3 => 5, identity; use_bias=true, train_state=true)) rnn_no_trainable_state = RNNCell( 3 => 5, identity; use_bias=false, train_state=false) - _ps, _st = Lux.setup(rng, rnn_no_trainable_state) .|> dev + _ps, _st = Lux.setup(rng, rnn_no_trainable_state) |> dev rnncell = RNNCell(3 => 5, identity; use_bias=false, train_state=true) - ps, st = Lux.setup(rng, rnncell) .|> dev + ps, st = Lux.setup(rng, rnncell) |> dev ps = merge(_ps, (hidden_state=ps.hidden_state,)) for x_size in ((3, 2), (3,)) @@ -60,16 +60,15 @@ end @testitem "LSTMCell" setup=[SharedTestSetup] tags=[:recurrent_layers] begin - Enzyme.API.runtimeActivity!(true) rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES - for lstmcell in (LSTMCell(3 => 5), LSTMCell(3 => 5; use_bias=true), + @testset for lstmcell in (LSTMCell(3 => 5), LSTMCell(3 => 5; use_bias=true), LSTMCell(3 => 5; use_bias=false)) display(lstmcell) - ps, st = Lux.setup(rng, lstmcell) .|> dev + ps, st = Lux.setup(rng, lstmcell) |> dev - for x_size in ((3, 2), (3,)) + @testset for x_size in ((3, 2), (3,)) x = randn(rng, Float32, x_size...) |> aType (y, carry), st_ = Lux.apply(lstmcell, x, ps, st) @@ -84,8 +83,8 @@ end return sum(abs2, y) end - test_gradients(loss_loop_lstmcell, ps; atol=1.0f-3, - rtol=1.0f-3, broken_backends=[AutoEnzyme()]) + test_gradients(loss_loop_lstmcell, ps; atol=1.0f-3, rtol=1.0f-3, + soft_fail=[AutoFiniteDiff()], broken_backends=[AutoEnzyme()]) @test_throws ErrorException ps.train_state @test_throws ErrorException ps.train_memory @@ -93,73 +92,77 @@ end end @testset "Trainable hidden states" begin - for x_size in ((3, 2), (3,)) + @testset for x_size in ((3, 2), (3,)) x = randn(rng, Float32, x_size...) 
|> aType _lstm = LSTMCell( 3 => 5; use_bias=false, train_state=false, train_memory=false) - _ps, _st = Lux.setup(rng, _lstm) .|> dev + _ps, _st = Lux.setup(rng, _lstm) |> dev (_y, _carry), _ = Lux.apply(_lstm, x, _ps, _st) lstm = LSTMCell( 3 => 5; use_bias=false, train_state=false, train_memory=false) - ps, st = Lux.setup(rng, lstm) .|> dev + ps, st = Lux.setup(rng, lstm) |> dev ps = _ps (y, carry), _ = Lux.apply(lstm, x, ps, st) @test carry == _carry l, back = Zygote.pullback( p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) gs = back(one(l))[1] - @test_throws ErrorException gs.bias + @test_throws ErrorException gs.bias_ih + @test_throws ErrorException gs.bias_hh @test_throws ErrorException gs.hidden_state @test_throws ErrorException gs.memory lstm = LSTMCell( 3 => 5; use_bias=false, train_state=true, train_memory=false) - ps, st = Lux.setup(rng, lstm) .|> dev + ps, st = Lux.setup(rng, lstm) |> dev ps = merge(_ps, (hidden_state=ps.hidden_state,)) (y, carry), _ = Lux.apply(lstm, x, ps, st) @test carry == _carry l, back = Zygote.pullback( p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) gs = back(one(l))[1] - @test_throws ErrorException gs.bias + @test_throws ErrorException gs.bias_ih + @test_throws ErrorException gs.bias_hh @test !isnothing(gs.hidden_state) @test_throws ErrorException gs.memory lstm = LSTMCell( 3 => 5; use_bias=false, train_state=false, train_memory=true) - ps, st = Lux.setup(rng, lstm) .|> dev + ps, st = Lux.setup(rng, lstm) |> dev ps = merge(_ps, (memory=ps.memory,)) (y, carry), _ = Lux.apply(lstm, x, ps, st) @test carry == _carry l, back = Zygote.pullback( p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) gs = back(one(l))[1] - @test_throws ErrorException gs.bias + @test_throws ErrorException gs.bias_ih + @test_throws ErrorException gs.bias_hh @test_throws ErrorException gs.hidden_state @test !isnothing(gs.memory) lstm = LSTMCell(3 => 5; use_bias=false, train_state=true, train_memory=true) - ps, st = Lux.setup(rng, lstm) .|> dev + ps, st = Lux.setup(rng, lstm) |> dev ps = merge(_ps, (hidden_state=ps.hidden_state, memory=ps.memory)) (y, carry), _ = Lux.apply(lstm, x, ps, st) @test carry == _carry l, back = Zygote.pullback( p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) gs = back(one(l))[1] - @test_throws ErrorException gs.bias + @test_throws ErrorException gs.bias_ih + @test_throws ErrorException gs.bias_hh @test !isnothing(gs.hidden_state) @test !isnothing(gs.memory) lstm = LSTMCell(3 => 5; use_bias=true, train_state=true, train_memory=true) - ps, st = Lux.setup(rng, lstm) .|> dev - ps = merge( - _ps, (bias=ps.bias, hidden_state=ps.hidden_state, memory=ps.memory)) + ps, st = Lux.setup(rng, lstm) |> dev + ps = merge(_ps, (; ps.bias_ih, ps.bias_hh, ps.hidden_state, ps.memory)) (y, carry), _ = Lux.apply(lstm, x, ps, st) l, back = Zygote.pullback( p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) gs = back(one(l))[1] - @test !isnothing(gs.bias) + @test !isnothing(gs.bias_ih) + @test !isnothing(gs.bias_hh) @test !isnothing(gs.hidden_state) @test !isnothing(gs.memory) end @@ -168,16 +171,15 @@ end end @testitem "GRUCell" setup=[SharedTestSetup] tags=[:recurrent_layers] begin - Enzyme.API.runtimeActivity!(true) rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES - for grucell in (GRUCell(3 => 5), GRUCell(3 => 5; use_bias=true), + @testset for grucell in (GRUCell(3 => 5), GRUCell(3 => 5; use_bias=true), GRUCell(3 => 5; use_bias=false)) display(grucell) - ps, st = Lux.setup(rng, grucell) .|> dev + ps, st = Lux.setup(rng, grucell) |> dev 
- for x_size in ((3, 2), (3,)) + @testset for x_size in ((3, 2), (3,)) x = randn(rng, Float32, x_size...) |> aType (y, carry), st_ = Lux.apply(grucell, x, ps, st) @@ -192,32 +194,33 @@ end return sum(abs2, y) end - test_gradients(loss_loop_grucell, ps; atol=1e-3, - rtol=1e-3, broken_backends=[AutoEnzyme()]) + test_gradients(loss_loop_grucell, ps; atol=1e-3, rtol=1e-3, + soft_fail=[AutoFiniteDiff()], broken_backends=[AutoEnzyme()]) @test_throws ErrorException ps.train_state end end @testset "Trainable hidden states" begin - for x_size in ((3, 2), (3,)) + @testset for x_size in ((3, 2), (3,)) x = randn(rng, Float32, x_size...) |> aType _gru = GRUCell(3 => 5; use_bias=false, train_state=false) - _ps, _st = Lux.setup(rng, _gru) .|> dev + _ps, _st = Lux.setup(rng, _gru) |> dev (_y, _carry), _ = Lux.apply(_gru, x, _ps, _st) gru = GRUCell(3 => 5; use_bias=false, train_state=false) - ps, st = Lux.setup(rng, gru) .|> dev + ps, st = Lux.setup(rng, gru) |> dev ps = _ps (y, carry), _ = Lux.apply(gru, x, ps, st) @test carry == _carry l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(gru(x, p, st)[1][1])), ps) gs = back(one(l))[1] - @test_throws ErrorException gs.bias + @test_throws ErrorException gs.bias_ih + @test_throws ErrorException gs.bias_hh @test_throws ErrorException gs.hidden_state gru = GRUCell(3 => 5; use_bias=false, train_state=true) - ps, st = Lux.setup(rng, gru) .|> dev + ps, st = Lux.setup(rng, gru) |> dev ps = merge(_ps, (hidden_state=ps.hidden_state,)) (y, carry), _ = Lux.apply(gru, x, ps, st) @test carry == _carry @@ -226,9 +229,8 @@ end @test !isnothing(gs.hidden_state) gru = GRUCell(3 => 5; use_bias=true, train_state=true) - ps, st = Lux.setup(rng, gru) .|> dev - ps = merge( - _ps, (bias_h=ps.bias_h, bias_i=ps.bias_i, hidden_state=ps.hidden_state)) + ps, st = Lux.setup(rng, gru) |> dev + ps = merge(_ps, (; ps.bias_ih, ps.bias_hh, ps.hidden_state)) (y, carry), _ = Lux.apply(gru, x, ps, st) @test carry == _carry l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(gru(x, p, st)[1][1])), ps) @@ -240,7 +242,6 @@ end end @testitem "StatefulRecurrentCell" setup=[SharedTestSetup] tags=[:recurrent_layers] begin - Enzyme.API.runtimeActivity!(true) rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES @@ -254,7 +255,7 @@ end for x_size in ((3, 2), (3,)) x = randn(rng, Float32, x_size...) 
|> aType - ps, st = Lux.setup(rng, rnn) .|> dev + ps, st = Lux.setup(rng, rnn) |> dev y, st_ = rnn(x, ps, st) @@ -280,15 +281,14 @@ end return sum(abs2, y) end - test_gradients( - loss_loop_rnn, ps; atol=1e-3, rtol=1e-3, broken_backends=[AutoEnzyme()]) + test_gradients(loss_loop_rnn, ps; atol=1e-3, rtol=1e-3, + broken_backends=[AutoEnzyme()], soft_fail=[AutoFiniteDiff()]) end end end end @testitem "Recurrence" setup=[SharedTestSetup] tags=[:recurrent_layers] begin - Enzyme.API.runtimeActivity!(true) rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES @@ -314,7 +314,7 @@ end (ntuple(identity, ndims(x) - 2)..., ndims(x), ndims(x) - 1)) end - ps, st = Lux.setup(rng, rnn) .|> dev + ps, st = Lux.setup(rng, rnn) |> dev y, st_ = rnn(x, ps, st) y_, st__ = rnn_seq(x, ps, st) @@ -326,12 +326,12 @@ end @test all(x -> size(x) == (5, 2), y_) __f = p -> sum(first(rnn(x, p, st))) - test_gradients( - __f, ps; atol=1e-3, rtol=1e-3, skip_backends=[AutoEnzyme()]) + test_gradients(__f, ps; atol=1e-3, rtol=1e-3, + skip_backends=[AutoEnzyme()], soft_fail=[AutoFiniteDiff()]) __f = p -> sum(Base.Fix1(sum, abs2), first(rnn_seq(x, p, st))) - test_gradients( - __f, ps; atol=1e-3, rtol=1e-3, skip_backends=[AutoEnzyme()]) + test_gradients(__f, ps; atol=1e-3, rtol=1e-3, + skip_backends=[AutoEnzyme()], soft_fail=[AutoFiniteDiff()]) end # Batched Time Series without data batches @@ -339,7 +339,7 @@ end randn(rng, Float32, 3, 4) |> aType, Tuple(randn(rng, Float32, 3) for _ in 1:4) .|> aType, [randn(rng, Float32, 3) for _ in 1:4] .|> aType) - ps, st = Lux.setup(rng, rnn) .|> dev + ps, st = Lux.setup(rng, rnn) |> dev y, st_ = rnn(x, ps, st) y_, st__ = rnn_seq(x, ps, st) @@ -361,12 +361,12 @@ end end __f = p -> sum(first(rnn(x, p, st))) - test_gradients( - __f, ps; atol=1e-3, rtol=1e-3, skip_backends=[AutoEnzyme()]) + test_gradients(__f, ps; atol=1e-3, rtol=1e-3, + skip_backends=[AutoEnzyme()], soft_fail=[AutoFiniteDiff()]) __f = p -> sum(Base.Fix1(sum, abs2), first(rnn_seq(x, p, st))) - test_gradients( - __f, ps; atol=1e-3, rtol=1e-3, skip_backends=[AutoEnzyme()]) + test_gradients(__f, ps; atol=1e-3, rtol=1e-3, + skip_backends=[AutoEnzyme()], soft_fail=[AutoFiniteDiff()]) end end end @@ -379,7 +379,7 @@ end init_state=(rng, args...; kwargs...) -> zeros(args...; kwargs...), init_bias=(rng, args...; kwargs...) 
-> zeros(args...; kwargs...)); return_sequence=true) - ps, st = Lux.setup(rng, encoder) .|> dev + ps, st = Lux.setup(rng, encoder) |> dev m2 = reshape([0.5, 0.0, 0.7, 0.8], 1, :, 1) |> aType res, _ = encoder(m2, ps, st) @@ -388,7 +388,6 @@ end end @testitem "Bidirectional" setup=[SharedTestSetup] tags=[:recurrent_layers] begin - Enzyme.API.runtimeActivity!(true) rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES @@ -400,7 +399,7 @@ end # Batched Time Series x = randn(rng, Float32, 3, 4, 2) |> aType - ps, st = Lux.setup(rng, bi_rnn) .|> dev + ps, st = Lux.setup(rng, bi_rnn) |> dev y, st_ = bi_rnn(x, ps, st) y_, st__ = bi_rnn_no_merge(x, ps, st) @@ -436,7 +435,7 @@ end # Batched Time Series x = randn(rng, Float32, 3, 4, 2) |> aType - ps, st = Lux.setup(rng, bi_rnn) .|> dev + ps, st = Lux.setup(rng, bi_rnn) |> dev y, st_ = bi_rnn(x, ps, st) y_, st__ = bi_rnn_no_merge(x, ps, st) diff --git a/test/qa_tests.jl b/test/qa_tests.jl index 06e20f31e4..074f464b09 100644 --- a/test/qa_tests.jl +++ b/test/qa_tests.jl @@ -1,12 +1,13 @@ @testitem "Aqua: Quality Assurance" tags=[:others] begin using Aqua, ChainRulesCore, ForwardDiff - Aqua.test_all(Lux; ambiguities=false) + Aqua.test_all(Lux; ambiguities=false, piracies=false) Aqua.test_ambiguities(Lux; exclude=[ForwardDiff.jacobian, ForwardDiff.gradient, Lux.AutoDiffInternalImpl.batched_jacobian, Lux.AutoDiffInternalImpl.jacobian_vector_product, Lux.AutoDiffInternalImpl.jacobian_vector_product_impl]) + Aqua.test_piracies(Lux; treat_as_own=[Lux.outputsize]) end @testitem "Explicit Imports: Quality Assurance" setup=[SharedTestSetup] tags=[:others] begin @@ -16,10 +17,10 @@ end # Skip our own packages @test check_no_implicit_imports( - Lux; skip=(Base, Core, LuxCore, LuxDeviceUtils, LuxLib, WeightInitializers)) === + Lux; skip=(Base, Core, LuxCore, MLDataDevices, LuxLib, WeightInitializers)) === nothing @test check_no_stale_explicit_imports( - Lux; ignore=(:inputsize, :setup, :testmode, :trainmode, :update_state)) === nothing + Lux; ignore=(:setup, :testmode, :trainmode, :update_state)) === nothing @test check_no_self_qualified_accesses(Lux) === nothing @test check_all_explicit_imports_via_owners(Lux) === nothing @test check_all_qualified_accesses_via_owners( @@ -34,8 +35,7 @@ end doctestexpr = quote using SimpleChains: static - using DynamicExpressions - using Adapt, Lux, Random, Optimisers, Zygote + using Adapt, Lux, Random, Optimisers, Zygote, NNlib end DocMeta.setdocmeta!(Lux, :DocTestSetup, doctestexpr; recursive=true) diff --git a/test/runtests.jl b/test/runtests.jl index fb0950818d..a6160b119d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -71,10 +71,6 @@ Lux.set_dispatch_doctor_preferences!(; luxcore="error", luxlib="error") @test !Lux.is_extension_loaded(Val(:Zygote)) using Zygote @test Lux.is_extension_loaded(Val(:Zygote)) - - @test !Lux.is_extension_loaded(Val(:DynamicExpressions)) - using DynamicExpressions - @test Lux.is_extension_loaded(Val(:DynamicExpressions)) end # These need to be run before MPI or NCCL is ever loaded @@ -97,7 +93,7 @@ const RETESTITEMS_NWORKERS = parse( @info "Running tests for group: [$(i)/$(length(LUX_TEST_GROUP))] $tag" ReTestItems.runtests(Lux; tags=(tag == "all" ? 
nothing : [Symbol(tag)]), - nworkers=RETESTITEMS_NWORKERS, testitem_timeout=1800, retries=1) + nworkers=RETESTITEMS_NWORKERS, testitem_timeout=2400, retries=1) end end diff --git a/test/setup_modes.jl b/test/setup_modes.jl index 88b88a247b..1617179a5b 100644 --- a/test/setup_modes.jl +++ b/test/setup_modes.jl @@ -25,9 +25,9 @@ end const MODES = begin # Mode, Array Type, Device Function, GPU? modes = [] - cpu_testing() && push!(modes, ("cpu", Array, LuxCPUDevice(), false)) - cuda_testing() && push!(modes, ("cuda", CuArray, LuxCUDADevice(), true)) - amdgpu_testing() && push!(modes, ("amdgpu", ROCArray, LuxAMDGPUDevice(), true)) + cpu_testing() && push!(modes, ("cpu", Array, CPUDevice(), false)) + cuda_testing() && push!(modes, ("cuda", CuArray, CUDADevice(), true)) + amdgpu_testing() && push!(modes, ("amdgpu", ROCArray, AMDGPUDevice(), true)) modes end diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl index 4ba455da8f..85abd32042 100644 --- a/test/shared_testsetup.jl +++ b/test/shared_testsetup.jl @@ -1,15 +1,20 @@ @testsetup module SharedTestSetup +using Enzyme +Enzyme.API.runtimeActivity!(true) + include("setup_modes.jl") import Reexport: @reexport using Lux, Functors +using Setfield: @set using DispatchDoctor: allow_unstable @reexport using ComponentArrays, LuxCore, LuxLib, LuxTestUtils, Random, StableRNGs, Test, Zygote, Statistics, Enzyme, LinearAlgebra, ForwardDiff using MLDataDevices: default_device_rng, CPUDevice, CUDADevice, AMDGPUDevice using LuxTestUtils: check_approx +using Static: True LuxTestUtils.jet_target_modules!(["Lux", "LuxCore", "LuxLib"]) LinearAlgebra.BLAS.set_num_threads(Threads.nthreads()) @@ -24,9 +29,8 @@ end maybe_rewrite_to_crosscor(layer) = layer function maybe_rewrite_to_crosscor(layer::Conv) - return CrossCor(layer.activation, layer.in_chs, layer.out_chs, layer.kernel_size, - layer.stride, layer.pad, layer.dilation, layer.groups, - layer.init_weight, layer.init_bias, layer.use_bias) + @set layer.cross_correlation = True() + return layer end function maybe_rewrite_to_crosscor(mode, model) diff --git a/test/transform/flux_tests.jl b/test/transform/flux_tests.jl index 7c6cc0ed09..fe7e5fceef 100644 --- a/test/transform/flux_tests.jl +++ b/test/transform/flux_tests.jl @@ -1,9 +1,9 @@ @testitem "FromFluxAdaptor" setup=[SharedTestSetup] tags=[:fluxcompat] begin import Flux - from_flux = fdev(::Lux.LuxCPUDevice) = Flux.cpu - fdev(::Lux.LuxCUDADevice) = Base.Fix1(Flux.gpu, Flux.FluxCUDAAdaptor()) - fdev(::Lux.LuxAMDGPUDevice) = Base.Fix1(Flux.gpu, Flux.FluxAMDAdaptor()) + from_flux = fdev(::Lux.CPUDevice) = Flux.cpu + fdev(::Lux.CUDADevice) = Base.Fix1(Flux.gpu, Flux.FluxCUDAAdaptor()) + fdev(::Lux.AMDGPUDevice) = Base.Fix1(Flux.gpu, Flux.FluxAMDAdaptor()) toluxpsst = FromFluxAdaptor(; preserve_ps_st=true) tolux = FromFluxAdaptor() @@ -327,53 +327,20 @@ @testset "Recurrent" begin @testset "RNNCell" begin model = Flux.RNNCell(2 => 3) |> fdev(dev) - x = rand(Float32, 2, 4) |> aType - - model_lux = tolux(model) - ps, st = Lux.setup(StableRNG(12345), model_lux) .|> dev - - @test size(model_lux(x, ps, st)[1][1]) == (3, 4) - + @test_throws Lux.FluxModelConversionException tolux(model) @test_throws Lux.FluxModelConversionException toluxforce(model) - - model_lux = toluxpsst(model) - ps, st = Lux.setup(StableRNG(12345), model_lux) .|> dev - - @test size(model_lux(x, ps, st)[1][1]) == (3, 4) end @testset "LSTMCell" begin model = Flux.LSTMCell(2 => 3) |> fdev(dev) - x = rand(Float32, 2, 4) |> aType - - model_lux = tolux(model) - ps, st = Lux.setup(StableRNG(12345), 
model_lux) .|> dev - - @test size(model_lux(x, ps, st)[1][1]) == (3, 4) - + @test_throws Lux.FluxModelConversionException tolux(model) @test_throws Lux.FluxModelConversionException toluxforce(model) - - model_lux = toluxpsst(model) - ps, st = Lux.setup(StableRNG(12345), model_lux) .|> dev - - @test size(model_lux(x, ps, st)[1][1]) == (3, 4) end @testset "GRUCell" begin model = Flux.GRUCell(2 => 3) |> fdev(dev) - x = rand(Float32, 2, 4) |> aType - - model_lux = tolux(model) - ps, st = Lux.setup(StableRNG(12345), model_lux) .|> dev - - @test size(model_lux(x, ps, st)[1][1]) == (3, 4) - + @test_throws Lux.FluxModelConversionException tolux(model) @test_throws Lux.FluxModelConversionException toluxforce(model) - - model_lux = toluxpsst(model) - ps, st = Lux.setup(StableRNG(12345), model_lux) .|> dev - - @test size(model_lux(x, ps, st)[1][1]) == (3, 4) end end @@ -503,7 +470,7 @@ @testset "Functions" begin @test tolux(Flux.flatten) isa Lux.FlattenLayer @test tolux(identity) isa Lux.NoOpLayer - @test tolux(+) isa Lux.WrappedFunction{:direct_call} + @test tolux(+) isa Lux.WrappedFunction end @testset "Unsupported Layers" begin diff --git a/test/transform/simple_chains_tests.jl b/test/transform/simple_chains_tests.jl index 5f22bea523..97428139df 100644 --- a/test/transform/simple_chains_tests.jl +++ b/test/transform/simple_chains_tests.jl @@ -29,7 +29,8 @@ @test length(gs[2].params) == length(ps.params) # See https://github.com/LuxDL/Lux.jl/issues/644 - test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3, broken_backends=[AutoEnzyme()]) + test_gradients( + __f, x, ps; atol=1.0f-3, rtol=1.0f-3, broken_backends=[AutoEnzyme(), AutoTracker()]) x = randn(Float32, 28, 28, 1, 15) @test size(first(simple_chains_model(x, ps, st))) == (10, 15) @@ -41,7 +42,8 @@ @test length(gs[2].params) == length(ps.params) # See https://github.com/LuxDL/Lux.jl/issues/644 - test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3, broken_backends=[AutoEnzyme()]) + test_gradients( + __f, x, ps; atol=1.0f-3, rtol=1.0f-3, broken_backends=[AutoEnzyme(), AutoTracker()]) @testset "Array Output" begin adaptor = ToSimpleChainsAdaptor((static(28), static(28), static(1)), true) @@ -87,7 +89,7 @@ for dims in (static(10), (static(10),)) adaptor = ToSimpleChainsAdaptor(dims) - simple_chains_model = @test_warn "The model provided is not a `Chain`. Trying to wrap it into a `Chain` but this might fail. Please consider using `Chain` directly (potentially with `disable_optimizations = true`)." adaptor(lux_model) + simple_chains_model = @test_warn "The model provided is not a `Chain`. Trying to wrap it into a `Chain` but this might fail. Please consider using `Chain` directly." 
adaptor(lux_model) ps, st = Lux.setup(Random.default_rng(), simple_chains_model) @@ -101,9 +103,9 @@ @test length(gs[2].params) == length(ps.params) # See https://github.com/LuxDL/Lux.jl/issues/644 - test_gradients( - __f, x, ps; atol=1.0f-3, rtol=1.0f-3, broken_backends=[AutoEnzyme()], - soft_fail=[AutoForwardDiff(), AutoFiniteDiff(), AutoTracker()]) + test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3, + broken_backends=[AutoEnzyme(), AutoTracker()], + soft_fail=[AutoForwardDiff(), AutoFiniteDiff()]) end end diff --git a/test/utils_tests.jl b/test/utils_tests.jl index d017b11dfa..fe7811cffa 100644 --- a/test/utils_tests.jl +++ b/test/utils_tests.jl @@ -31,21 +31,6 @@ end @test eltype(ComponentArray(Any[:a, 1], (FlatAxis(),))) == Any end -@testitem "Deprecations" tags=[:others] begin - using Functors - - @test_deprecated Lux.disable_stacktrace_truncation!() - @test_deprecated Lux.cpu(rand(2)) - @test_deprecated Lux.gpu(rand(2)) - - model = NoOpLayer() - @test_deprecated Lux.Experimental.StatefulLuxLayer(model, (;), (;)) - - @test_deprecated Lux.Experimental.DebugLayer(model; location="model") - dmodel = Lux.Experimental.DebugLayer(model; location="model") - @test dmodel.location == KeyPath(:model) -end - @testitem "multigate" setup=[SharedTestSetup] tags=[:others] begin rng = StableRNG(12345) @@ -128,13 +113,13 @@ end @test length(Zygote.gradient(l2reg, ps)) == 1 end -@testitem "Utils.init_hidden_state" setup=[SharedTestSetup] tags=[:recurrent_layers] begin +@testitem "Utils.init_rnn_hidden_state" setup=[SharedTestSetup] tags=[:recurrent_layers] begin rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES rnn = RNNCell(3 => 5; init_state=Lux.zeros32) x = randn(rng, Float32, 3, 2, 2) - @test Lux.Utils.init_hidden_state(rng, rnn, view(dev(x), :, 1, :)) == + @test Lux.Utils.init_rnn_hidden_state(rng, rnn, view(dev(x), :, 1, :)) == aType(zeros(Float32, 5, 2)) end end @@ -143,8 +128,8 @@ end rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES - model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), - BatchNorm(1); disable_optimizations=true) + model = Chain( + Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), BatchNorm(1)) for (f, ftype) in zip((f16, f32, f64), (Float16, Float32, Float64)) ps, st = Lux.setup(rng, model) |> dev |> f diff --git a/test/zygote_type_stability.jl b/test/zygote_type_stability.jl index 517ba590f1..b8d0d22c3d 100644 --- a/test/zygote_type_stability.jl +++ b/test/zygote_type_stability.jl @@ -77,7 +77,7 @@ include("setup_modes.jl") @test @inferred(model(x, ps, st)) isa Any @test @inferred(loss_function(model, x, ps, st)) isa Any - if mode == "amdgpu" && (model isa Conv || model isa CrossCor) + if mode == "amdgpu" && model isa Conv @test_broken @inferred(Zygote.gradient(loss_function, model, x, ps, st)) isa Any else