1
0
mirror of https://github.com/owncloud/ocis.git synced 2025-04-18 23:44:07 +03:00

Bump github.com/nats-io/nats-server/v2 from 2.10.22 to 2.11.1

Bumps [github.com/nats-io/nats-server/v2](https://github.com/nats-io/nats-server) from 2.10.22 to 2.11.1.
- [Release notes](https://github.com/nats-io/nats-server/releases)
- [Changelog](https://github.com/nats-io/nats-server/blob/main/.goreleaser.yml)
- [Commits](https://github.com/nats-io/nats-server/compare/v2.10.22...v2.11.1)

---
updated-dependencies:
- dependency-name: github.com/nats-io/nats-server/v2
  dependency-version: 2.11.1
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
This commit is contained in:
dependabot[bot] 2025-04-16 07:00:30 +00:00 committed by GitHub
parent 20f32d4f2a
commit b8b7a36554
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
179 changed files with 19979 additions and 4072 deletions

13
go.mod
View File

@ -57,8 +57,8 @@ require (
github.com/mitchellh/mapstructure v1.5.0
github.com/mna/pigeon v1.3.0
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826
github.com/nats-io/nats-server/v2 v2.10.22
github.com/nats-io/nats.go v1.37.0
github.com/nats-io/nats-server/v2 v2.11.1
github.com/nats-io/nats.go v1.39.1
github.com/oklog/run v1.1.0
github.com/olekukonko/tablewriter v0.0.5
github.com/onsi/ginkgo v1.16.5
@ -224,6 +224,7 @@ require (
github.com/gomodule/redigo v1.9.2 // indirect
github.com/google/flatbuffers v2.0.8+incompatible // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/go-tpm v0.9.3 // indirect
github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db // indirect
github.com/google/renameio/v2 v2.0.0 // indirect
github.com/gookit/color v1.5.4 // indirect
@ -244,7 +245,7 @@ require (
github.com/json-iterator/go v1.1.12 // indirect
github.com/juliangruber/go-intersect v1.1.0 // indirect
github.com/kevinburke/ssh_config v1.2.0 // indirect
github.com/klauspost/compress v1.17.11 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.8 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/libregraph/oidc-go v1.1.0 // indirect
@ -268,8 +269,8 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/mschoch/smat v0.2.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/nats-io/jwt/v2 v2.5.8 // indirect
github.com/nats-io/nkeys v0.4.7 // indirect
github.com/nats-io/jwt/v2 v2.7.3 // indirect
github.com/nats-io/nkeys v0.4.10 // indirect
github.com/nats-io/nuid v1.0.1 // indirect
github.com/nxadm/tail v1.4.8 // indirect
github.com/opencontainers/runtime-spec v1.1.0 // indirect
@ -328,7 +329,7 @@ require (
go.uber.org/zap v1.23.0 // indirect
golang.org/x/mod v0.21.0 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/time v0.7.0 // indirect
golang.org/x/time v0.11.0 // indirect
golang.org/x/tools v0.26.0 // indirect
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect
google.golang.org/genproto v0.0.0-20241015192408-796eee8c2d53 // indirect

28
go.sum
View File

@ -113,6 +113,8 @@ github.com/amoghe/go-crypt v0.0.0-20220222110647-20eada5f5964 h1:I9YN9WMo3SUh7p/
github.com/amoghe/go-crypt v0.0.0-20220222110647-20eada5f5964/go.mod h1:eFiR01PwTcpbzXtdMces7zxg6utvFM5puiWHpWB8D/k=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/antithesishq/antithesis-sdk-go v0.4.3-default-no-op h1:+OSa/t11TFhqfrX0EOSqQBDJ0YlpmK0rDSiB19dg9M0=
github.com/antithesishq/antithesis-sdk-go v0.4.3-default-no-op/go.mod h1:IUpT2DPAKh6i/YhSbt6Gl3v2yvUZjmKncl7U91fup7E=
github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
@ -533,6 +535,8 @@ github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
github.com/google/go-tika v0.3.1 h1:l+jr10hDhZjcgxFRfcQChRLo1bPXQeLFluMyvDhXTTA=
github.com/google/go-tika v0.3.1/go.mod h1:DJh5N8qxXIl85QkqmXknd+PeeRkUOTbvwyYf7ieDz6c=
github.com/google/go-tpm v0.9.3 h1:+yx0/anQuGzi+ssRqeD6WpXjW2L/V0dItUayO0i9sRc=
github.com/google/go-tpm v0.9.3/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
@ -689,8 +693,8 @@ github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvW
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU=
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM=
github.com/klauspost/cpuid/v2 v2.2.8/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
@ -829,14 +833,14 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8m
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/namedotcom/go v0.0.0-20180403034216-08470befbe04/go.mod h1:5sN+Lt1CaY4wsPvgQH/jsuJi4XO2ssZbdsIizr4CVC8=
github.com/nats-io/jwt/v2 v2.5.8 h1:uvdSzwWiEGWGXf+0Q+70qv6AQdvcvxrv9hPM0RiPamE=
github.com/nats-io/jwt/v2 v2.5.8/go.mod h1:ZdWS1nZa6WMZfFwwgpEaqBV8EPGVgOTDHN/wTbz0Y5A=
github.com/nats-io/nats-server/v2 v2.10.22 h1:Yt63BGu2c3DdMoBZNcR6pjGQwk/asrKU7VX846ibxDA=
github.com/nats-io/nats-server/v2 v2.10.22/go.mod h1:X/m1ye9NYansUXYFrbcDwUi/blHkrgHh2rgCJaakonk=
github.com/nats-io/nats.go v1.37.0 h1:07rauXbVnnJvv1gfIyghFEo6lUcYRY0WXc3x7x0vUxE=
github.com/nats-io/nats.go v1.37.0/go.mod h1:Ubdu4Nh9exXdSz0RVWRFBbRfrbSxOYd26oF0wkWclB8=
github.com/nats-io/nkeys v0.4.7 h1:RwNJbbIdYCoClSDNY7QVKZlyb/wfT6ugvFCiKy6vDvI=
github.com/nats-io/nkeys v0.4.7/go.mod h1:kqXRgRDPlGy7nGaEDMuYzmiJCIAAWDK0IMBtDmGD0nc=
github.com/nats-io/jwt/v2 v2.7.3 h1:6bNPK+FXgBeAqdj4cYQ0F8ViHRbi7woQLq4W29nUAzE=
github.com/nats-io/jwt/v2 v2.7.3/go.mod h1:GvkcbHhKquj3pkioy5put1wvPxs78UlZ7D/pY+BgZk4=
github.com/nats-io/nats-server/v2 v2.11.1 h1:LwdauqMqMNhTxTN3+WFTX6wGDOKntHljgZ+7gL5HCnk=
github.com/nats-io/nats-server/v2 v2.11.1/go.mod h1:leXySghbdtXSUmWem8K9McnJ6xbJOb0t9+NQ5HTRZjI=
github.com/nats-io/nats.go v1.39.1 h1:oTkfKBmz7W047vRxV762M67ZdXeOtUgvbBaNoQ+3PPk=
github.com/nats-io/nats.go v1.39.1/go.mod h1:MgRb8oOdigA6cYpEPhXJuRVH6UE/V4jblJ2jQ27IXYM=
github.com/nats-io/nkeys v0.4.10 h1:glmRrpCmYLHByYcePvnTBEAwawwapjCPMjy2huw20wc=
github.com/nats-io/nkeys v0.4.10/go.mod h1:OjRrnIKnWBFl+s4YK5ChQfvHP2fxqZexrKJoVVyWB3U=
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c=
github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms=
@ -1476,8 +1480,8 @@ golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxb
golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

202
vendor/github.com/google/go-tpm/LICENSE generated vendored Normal file
View File

@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

35
vendor/github.com/google/go-tpm/legacy/tpm2/README.md generated vendored Normal file
View File

@ -0,0 +1,35 @@
# TPM 2.0 client library
## Tests
This library contains unit tests in `github.com/google/go-tpm/tpm2`, which just
tests that various encoding and error checking functions work correctly. It also
contains more comprehensive integration tests in
`github.com/google/go-tpm/tpm2/test`, which run actual commands on a TPM.
By default, these integration tests are run against the
[`go-tpm-tools`](https://github.com/google/go-tpm-tools)
simulator, which is baesed on the
[Microsoft Reference TPM2 code](https://github.com/microsoft/ms-tpm-20-ref). To
run both the unit and integration tests, run (in this directory)
```bash
go test . ./test
```
These integration tests can also be run against a real TPM device. This is
slightly more complex as the tests often need to be built as a normal user and
then executed as root. For example,
```bash
# Build the test binary without running it
go test -c github.com/google/go-tpm/tpm2/test
# Execute the test binary as root
sudo ./test.test --tpm-path=/dev/tpmrm0
```
On Linux, The `--tpm-path` causes the integration tests to be run against a
real TPM located at that path (usually `/dev/tpmrm0` or `/dev/tpm0`). On Windows, the story is similar, execept that
the `--use-tbs` flag is used instead.
Tip: if your TPM host is remote and you don't want to install Go on it, this
same two-step process can be used. The test binary can be copied to a remote
host and run without extra installation (as the test binary has very few
*runtime* dependancies).

View File

@ -0,0 +1,575 @@
// Copyright (c) 2018, Google LLC All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tpm2
import (
"crypto"
"crypto/elliptic"
"fmt"
"strings"
// Register the relevant hash implementations to prevent a runtime failure.
_ "crypto/sha1"
_ "crypto/sha256"
_ "crypto/sha512"
"github.com/google/go-tpm/tpmutil"
)
var hashInfo = []struct {
alg Algorithm
hash crypto.Hash
}{
{AlgSHA1, crypto.SHA1},
{AlgSHA256, crypto.SHA256},
{AlgSHA384, crypto.SHA384},
{AlgSHA512, crypto.SHA512},
{AlgSHA3_256, crypto.SHA3_256},
{AlgSHA3_384, crypto.SHA3_384},
{AlgSHA3_512, crypto.SHA3_512},
}
// MAX_DIGEST_BUFFER is the maximum size of []byte request or response fields.
// Typically used for chunking of big blobs of data (such as for hashing or
// encryption).
const maxDigestBuffer = 1024
// Algorithm represents a TPM_ALG_ID value.
type Algorithm uint16
// HashToAlgorithm looks up the TPM2 algorithm corresponding to the provided crypto.Hash
func HashToAlgorithm(hash crypto.Hash) (Algorithm, error) {
for _, info := range hashInfo {
if info.hash == hash {
return info.alg, nil
}
}
return AlgUnknown, fmt.Errorf("go hash algorithm #%d has no TPM2 algorithm", hash)
}
// IsNull returns true if a is AlgNull or zero (unset).
func (a Algorithm) IsNull() bool {
return a == AlgNull || a == AlgUnknown
}
// UsesCount returns true if a signature algorithm uses count value.
func (a Algorithm) UsesCount() bool {
return a == AlgECDAA
}
// UsesHash returns true if the algorithm requires the use of a hash.
func (a Algorithm) UsesHash() bool {
return a == AlgOAEP
}
// Hash returns a crypto.Hash based on the given TPM_ALG_ID.
// An error is returned if the given algorithm is not a hash algorithm or is not available.
func (a Algorithm) Hash() (crypto.Hash, error) {
for _, info := range hashInfo {
if info.alg == a {
if !info.hash.Available() {
return crypto.Hash(0), fmt.Errorf("go hash algorithm #%d not available", info.hash)
}
return info.hash, nil
}
}
return crypto.Hash(0), fmt.Errorf("hash algorithm not supported: 0x%x", a)
}
func (a Algorithm) String() string {
var s strings.Builder
var err error
switch a {
case AlgUnknown:
_, err = s.WriteString("AlgUnknown")
case AlgRSA:
_, err = s.WriteString("RSA")
case AlgSHA1:
_, err = s.WriteString("SHA1")
case AlgHMAC:
_, err = s.WriteString("HMAC")
case AlgAES:
_, err = s.WriteString("AES")
case AlgKeyedHash:
_, err = s.WriteString("KeyedHash")
case AlgXOR:
_, err = s.WriteString("XOR")
case AlgSHA256:
_, err = s.WriteString("SHA256")
case AlgSHA384:
_, err = s.WriteString("SHA384")
case AlgSHA512:
_, err = s.WriteString("SHA512")
case AlgNull:
_, err = s.WriteString("AlgNull")
case AlgRSASSA:
_, err = s.WriteString("RSASSA")
case AlgRSAES:
_, err = s.WriteString("RSAES")
case AlgRSAPSS:
_, err = s.WriteString("RSAPSS")
case AlgOAEP:
_, err = s.WriteString("OAEP")
case AlgECDSA:
_, err = s.WriteString("ECDSA")
case AlgECDH:
_, err = s.WriteString("ECDH")
case AlgECDAA:
_, err = s.WriteString("ECDAA")
case AlgKDF2:
_, err = s.WriteString("KDF2")
case AlgECC:
_, err = s.WriteString("ECC")
case AlgSymCipher:
_, err = s.WriteString("SymCipher")
case AlgSHA3_256:
_, err = s.WriteString("SHA3_256")
case AlgSHA3_384:
_, err = s.WriteString("SHA3_384")
case AlgSHA3_512:
_, err = s.WriteString("SHA3_512")
case AlgCTR:
_, err = s.WriteString("CTR")
case AlgOFB:
_, err = s.WriteString("OFB")
case AlgCBC:
_, err = s.WriteString("CBC")
case AlgCFB:
_, err = s.WriteString("CFB")
case AlgECB:
_, err = s.WriteString("ECB")
default:
return fmt.Sprintf("Alg?<%d>", int(a))
}
if err != nil {
return fmt.Sprintf("Writing to string builder failed: %v", err)
}
return s.String()
}
// Supported Algorithms.
const (
AlgUnknown Algorithm = 0x0000
AlgRSA Algorithm = 0x0001
AlgSHA1 Algorithm = 0x0004
AlgHMAC Algorithm = 0x0005
AlgAES Algorithm = 0x0006
AlgKeyedHash Algorithm = 0x0008
AlgXOR Algorithm = 0x000A
AlgSHA256 Algorithm = 0x000B
AlgSHA384 Algorithm = 0x000C
AlgSHA512 Algorithm = 0x000D
AlgNull Algorithm = 0x0010
AlgRSASSA Algorithm = 0x0014
AlgRSAES Algorithm = 0x0015
AlgRSAPSS Algorithm = 0x0016
AlgOAEP Algorithm = 0x0017
AlgECDSA Algorithm = 0x0018
AlgECDH Algorithm = 0x0019
AlgECDAA Algorithm = 0x001A
AlgKDF2 Algorithm = 0x0021
AlgECC Algorithm = 0x0023
AlgSymCipher Algorithm = 0x0025
AlgSHA3_256 Algorithm = 0x0027
AlgSHA3_384 Algorithm = 0x0028
AlgSHA3_512 Algorithm = 0x0029
AlgCTR Algorithm = 0x0040
AlgOFB Algorithm = 0x0041
AlgCBC Algorithm = 0x0042
AlgCFB Algorithm = 0x0043
AlgECB Algorithm = 0x0044
)
// HandleType defines a type of handle.
type HandleType uint8
// Supported handle types
const (
HandleTypePCR HandleType = 0x00
HandleTypeNVIndex HandleType = 0x01
HandleTypeHMACSession HandleType = 0x02
HandleTypeLoadedSession HandleType = 0x02
HandleTypePolicySession HandleType = 0x03
HandleTypeSavedSession HandleType = 0x03
HandleTypePermanent HandleType = 0x40
HandleTypeTransient HandleType = 0x80
HandleTypePersistent HandleType = 0x81
)
// SessionType defines the type of session created in StartAuthSession.
type SessionType uint8
// Supported session types.
const (
SessionHMAC SessionType = 0x00
SessionPolicy SessionType = 0x01
SessionTrial SessionType = 0x03
)
// SessionAttributes represents an attribute of a session.
type SessionAttributes byte
// Session Attributes (Structures 8.4 TPMA_SESSION)
const (
AttrContinueSession SessionAttributes = 1 << iota
AttrAuditExclusive
AttrAuditReset
_ // bit 3 reserved
_ // bit 4 reserved
AttrDecrypt
AttrEcrypt
AttrAudit
)
// EmptyAuth represents the empty authorization value.
var EmptyAuth []byte
// KeyProp is a bitmask used in Attributes field of key templates. Individual
// flags should be OR-ed to form a full mask.
type KeyProp uint32
// Key properties.
const (
FlagFixedTPM KeyProp = 0x00000002
FlagStClear KeyProp = 0x00000004
FlagFixedParent KeyProp = 0x00000010
FlagSensitiveDataOrigin KeyProp = 0x00000020
FlagUserWithAuth KeyProp = 0x00000040
FlagAdminWithPolicy KeyProp = 0x00000080
FlagNoDA KeyProp = 0x00000400
FlagRestricted KeyProp = 0x00010000
FlagDecrypt KeyProp = 0x00020000
FlagSign KeyProp = 0x00040000
FlagSealDefault = FlagFixedTPM | FlagFixedParent
FlagSignerDefault = FlagSign | FlagRestricted | FlagFixedTPM |
FlagFixedParent | FlagSensitiveDataOrigin | FlagUserWithAuth
FlagStorageDefault = FlagDecrypt | FlagRestricted | FlagFixedTPM |
FlagFixedParent | FlagSensitiveDataOrigin | FlagUserWithAuth
)
// TPMProp represents a Property Tag (TPM_PT) used with calls to GetCapability(CapabilityTPMProperties).
type TPMProp uint32
// TPM Capability Properties, see TPM 2.0 Spec, Rev 1.38, Table 23.
// Fixed TPM Properties (PT_FIXED)
const (
FamilyIndicator TPMProp = 0x100 + iota
SpecLevel
SpecRevision
SpecDayOfYear
SpecYear
Manufacturer
VendorString1
VendorString2
VendorString3
VendorString4
VendorTPMType
FirmwareVersion1
FirmwareVersion2
InputMaxBufferSize
TransientObjectsMin
PersistentObjectsMin
LoadedObjectsMin
ActiveSessionsMax
PCRCount
PCRSelectMin
ContextGapMax
_ // (PT_FIXED + 21) is skipped
NVCountersMax
NVIndexMax
MemoryMethod
ClockUpdate
ContextHash
ContextSym
ContextSymSize
OrderlyCount
CommandMaxSize
ResponseMaxSize
DigestMaxSize
ObjectContextMaxSize
SessionContextMaxSize
PSFamilyIndicator
PSSpecLevel
PSSpecRevision
PSSpecDayOfYear
PSSpecYear
SplitSigningMax
TotalCommands
LibraryCommands
VendorCommands
NVMaxBufferSize
TPMModes
CapabilityMaxBufferSize
)
// Variable TPM Properties (PT_VAR)
const (
TPMAPermanent TPMProp = 0x200 + iota
TPMAStartupClear
HRNVIndex
HRLoaded
HRLoadedAvail
HRActive
HRActiveAvail
HRTransientAvail
CurrentPersistent
AvailPersistent
NVCounters
NVCountersAvail
AlgorithmSet
LoadedCurves
LockoutCounter
MaxAuthFail
LockoutInterval
LockoutRecovery
NVWriteRecovery
AuditCounter0
AuditCounter1
)
// Allowed ranges of different kinds of Handles (TPM_HANDLE)
// These constants have type TPMProp for backwards compatibility.
const (
PCRFirst TPMProp = 0x00000000
HMACSessionFirst TPMProp = 0x02000000
LoadedSessionFirst TPMProp = 0x02000000
PolicySessionFirst TPMProp = 0x03000000
ActiveSessionFirst TPMProp = 0x03000000
TransientFirst TPMProp = 0x80000000
PersistentFirst TPMProp = 0x81000000
PersistentLast TPMProp = 0x81FFFFFF
PlatformPersistent TPMProp = 0x81800000
NVIndexFirst TPMProp = 0x01000000
NVIndexLast TPMProp = 0x01FFFFFF
PermanentFirst TPMProp = 0x40000000
PermanentLast TPMProp = 0x4000010F
)
// Reserved Handles.
const (
HandleOwner tpmutil.Handle = 0x40000001 + iota
HandleRevoke
HandleTransport
HandleOperator
HandleAdmin
HandleEK
HandleNull
HandleUnassigned
HandlePasswordSession
HandleLockout
HandleEndorsement
HandlePlatform
)
// Capability identifies some TPM property or state type.
type Capability uint32
// TPM Capabilities.
const (
CapabilityAlgs Capability = iota
CapabilityHandles
CapabilityCommands
CapabilityPPCommands
CapabilityAuditCommands
CapabilityPCRs
CapabilityTPMProperties
CapabilityPCRProperties
CapabilityECCCurves
CapabilityAuthPolicies
)
// TPM Structure Tags. Tags are used to disambiguate structures, similar to Alg
// values: tag value defines what kind of data lives in a nested field.
const (
TagNull tpmutil.Tag = 0x8000
TagNoSessions tpmutil.Tag = 0x8001
TagSessions tpmutil.Tag = 0x8002
TagAttestCertify tpmutil.Tag = 0x8017
TagAttestQuote tpmutil.Tag = 0x8018
TagAttestCreation tpmutil.Tag = 0x801a
TagAuthSecret tpmutil.Tag = 0x8023
TagHashCheck tpmutil.Tag = 0x8024
TagAuthSigned tpmutil.Tag = 0x8025
)
// StartupType instructs the TPM on how to handle its state during Shutdown or
// Startup.
type StartupType uint16
// Startup types
const (
StartupClear StartupType = iota
StartupState
)
// EllipticCurve identifies specific EC curves.
type EllipticCurve uint16
// ECC curves supported by TPM 2.0 spec.
const (
CurveNISTP192 = EllipticCurve(iota + 1)
CurveNISTP224
CurveNISTP256
CurveNISTP384
CurveNISTP521
CurveBNP256 = EllipticCurve(iota + 10)
CurveBNP638
CurveSM2P256 = EllipticCurve(0x0020)
)
var toGoCurve = map[EllipticCurve]elliptic.Curve{
CurveNISTP224: elliptic.P224(),
CurveNISTP256: elliptic.P256(),
CurveNISTP384: elliptic.P384(),
CurveNISTP521: elliptic.P521(),
}
// Supported TPM operations.
const (
CmdNVUndefineSpaceSpecial tpmutil.Command = 0x0000011F
CmdEvictControl tpmutil.Command = 0x00000120
CmdUndefineSpace tpmutil.Command = 0x00000122
CmdClear tpmutil.Command = 0x00000126
CmdHierarchyChangeAuth tpmutil.Command = 0x00000129
CmdDefineSpace tpmutil.Command = 0x0000012A
CmdCreatePrimary tpmutil.Command = 0x00000131
CmdIncrementNVCounter tpmutil.Command = 0x00000134
CmdWriteNV tpmutil.Command = 0x00000137
CmdWriteLockNV tpmutil.Command = 0x00000138
CmdDictionaryAttackLockReset tpmutil.Command = 0x00000139
CmdDictionaryAttackParameters tpmutil.Command = 0x0000013A
CmdPCREvent tpmutil.Command = 0x0000013C
CmdPCRReset tpmutil.Command = 0x0000013D
CmdSequenceComplete tpmutil.Command = 0x0000013E
CmdStartup tpmutil.Command = 0x00000144
CmdShutdown tpmutil.Command = 0x00000145
CmdActivateCredential tpmutil.Command = 0x00000147
CmdCertify tpmutil.Command = 0x00000148
CmdCertifyCreation tpmutil.Command = 0x0000014A
CmdReadNV tpmutil.Command = 0x0000014E
CmdReadLockNV tpmutil.Command = 0x0000014F
CmdPolicySecret tpmutil.Command = 0x00000151
CmdCreate tpmutil.Command = 0x00000153
CmdECDHZGen tpmutil.Command = 0x00000154
CmdImport tpmutil.Command = 0x00000156
CmdLoad tpmutil.Command = 0x00000157
CmdQuote tpmutil.Command = 0x00000158
CmdRSADecrypt tpmutil.Command = 0x00000159
CmdSequenceUpdate tpmutil.Command = 0x0000015C
CmdSign tpmutil.Command = 0x0000015D
CmdUnseal tpmutil.Command = 0x0000015E
CmdPolicySigned tpmutil.Command = 0x00000160
CmdContextLoad tpmutil.Command = 0x00000161
CmdContextSave tpmutil.Command = 0x00000162
CmdECDHKeyGen tpmutil.Command = 0x00000163
CmdEncryptDecrypt tpmutil.Command = 0x00000164
CmdFlushContext tpmutil.Command = 0x00000165
CmdLoadExternal tpmutil.Command = 0x00000167
CmdMakeCredential tpmutil.Command = 0x00000168
CmdReadPublicNV tpmutil.Command = 0x00000169
CmdPolicyCommandCode tpmutil.Command = 0x0000016C
CmdPolicyOr tpmutil.Command = 0x00000171
CmdReadPublic tpmutil.Command = 0x00000173
CmdRSAEncrypt tpmutil.Command = 0x00000174
CmdStartAuthSession tpmutil.Command = 0x00000176
CmdGetCapability tpmutil.Command = 0x0000017A
CmdGetRandom tpmutil.Command = 0x0000017B
CmdHash tpmutil.Command = 0x0000017D
CmdPCRRead tpmutil.Command = 0x0000017E
CmdPolicyPCR tpmutil.Command = 0x0000017F
CmdReadClock tpmutil.Command = 0x00000181
CmdPCRExtend tpmutil.Command = 0x00000182
CmdEventSequenceComplete tpmutil.Command = 0x00000185
CmdHashSequenceStart tpmutil.Command = 0x00000186
CmdPolicyGetDigest tpmutil.Command = 0x00000189
CmdPolicyPassword tpmutil.Command = 0x0000018C
CmdEncryptDecrypt2 tpmutil.Command = 0x00000193
)
// Regular TPM 2.0 devices use 24-bit mask (3 bytes) for PCR selection.
const sizeOfPCRSelect = 3
const defaultRSAExponent = 1<<16 + 1
// NVAttr is a bitmask used in Attributes field of NV indexes. Individual
// flags should be OR-ed to form a full mask.
type NVAttr uint32
// NV Attributes
const (
AttrPPWrite NVAttr = 0x00000001
AttrOwnerWrite NVAttr = 0x00000002
AttrAuthWrite NVAttr = 0x00000004
AttrPolicyWrite NVAttr = 0x00000008
AttrPolicyDelete NVAttr = 0x00000400
AttrWriteLocked NVAttr = 0x00000800
AttrWriteAll NVAttr = 0x00001000
AttrWriteDefine NVAttr = 0x00002000
AttrWriteSTClear NVAttr = 0x00004000
AttrGlobalLock NVAttr = 0x00008000
AttrPPRead NVAttr = 0x00010000
AttrOwnerRead NVAttr = 0x00020000
AttrAuthRead NVAttr = 0x00040000
AttrPolicyRead NVAttr = 0x00080000
AttrNoDA NVAttr = 0x02000000
AttrOrderly NVAttr = 0x04000000
AttrClearSTClear NVAttr = 0x08000000
AttrReadLocked NVAttr = 0x10000000
AttrWritten NVAttr = 0x20000000
AttrPlatformCreate NVAttr = 0x40000000
AttrReadSTClear NVAttr = 0x80000000
)
var permMap = map[NVAttr]string{
AttrPPWrite: "PPWrite",
AttrOwnerWrite: "OwnerWrite",
AttrAuthWrite: "AuthWrite",
AttrPolicyWrite: "PolicyWrite",
AttrPolicyDelete: "PolicyDelete",
AttrWriteLocked: "WriteLocked",
AttrWriteAll: "WriteAll",
AttrWriteDefine: "WriteDefine",
AttrWriteSTClear: "WriteSTClear",
AttrGlobalLock: "GlobalLock",
AttrPPRead: "PPRead",
AttrOwnerRead: "OwnerRead",
AttrAuthRead: "AuthRead",
AttrPolicyRead: "PolicyRead",
AttrNoDA: "No Do",
AttrOrderly: "Oderly",
AttrClearSTClear: "ClearSTClear",
AttrReadLocked: "ReadLocked",
AttrWritten: "Writte",
AttrPlatformCreate: "PlatformCreate",
AttrReadSTClear: "ReadSTClear",
}
// String returns a textual representation of the set of NVAttr
func (p NVAttr) String() string {
var retString strings.Builder
for iterator, item := range permMap {
if (p & iterator) != 0 {
retString.WriteString(item + " + ")
}
}
if retString.String() == "" {
return "Permission/s not found"
}
return strings.TrimSuffix(retString.String(), " + ")
}

362
vendor/github.com/google/go-tpm/legacy/tpm2/error.go generated vendored Normal file
View File

@ -0,0 +1,362 @@
package tpm2
import (
"fmt"
"github.com/google/go-tpm/tpmutil"
)
type (
// RCFmt0 holds Format 0 error codes
RCFmt0 uint8
// RCFmt1 holds Format 1 error codes
RCFmt1 uint8
// RCWarn holds error codes used in warnings
RCWarn uint8
// RCIndex is used to reference arguments, handles and sessions in errors
RCIndex uint8
)
// Format 0 error codes.
const (
RCInitialize RCFmt0 = 0x00
RCFailure RCFmt0 = 0x01
RCSequence RCFmt0 = 0x03
RCPrivate RCFmt0 = 0x0B
RCHMAC RCFmt0 = 0x19
RCDisabled RCFmt0 = 0x20
RCExclusive RCFmt0 = 0x21
RCAuthType RCFmt0 = 0x24
RCAuthMissing RCFmt0 = 0x25
RCPolicy RCFmt0 = 0x26
RCPCR RCFmt0 = 0x27
RCPCRChanged RCFmt0 = 0x28
RCUpgrade RCFmt0 = 0x2D
RCTooManyContexts RCFmt0 = 0x2E
RCAuthUnavailable RCFmt0 = 0x2F
RCReboot RCFmt0 = 0x30
RCUnbalanced RCFmt0 = 0x31
RCCommandSize RCFmt0 = 0x42
RCCommandCode RCFmt0 = 0x43
RCAuthSize RCFmt0 = 0x44
RCAuthContext RCFmt0 = 0x45
RCNVRange RCFmt0 = 0x46
RCNVSize RCFmt0 = 0x47
RCNVLocked RCFmt0 = 0x48
RCNVAuthorization RCFmt0 = 0x49
RCNVUninitialized RCFmt0 = 0x4A
RCNVSpace RCFmt0 = 0x4B
RCNVDefined RCFmt0 = 0x4C
RCBadContext RCFmt0 = 0x50
RCCPHash RCFmt0 = 0x51
RCParent RCFmt0 = 0x52
RCNeedsTest RCFmt0 = 0x53
RCNoResult RCFmt0 = 0x54
RCSensitive RCFmt0 = 0x55
)
var fmt0Msg = map[RCFmt0]string{
RCInitialize: "TPM not initialized by TPM2_Startup or already initialized",
RCFailure: "commands not being accepted because of a TPM failure",
RCSequence: "improper use of a sequence handle",
RCPrivate: "not currently used",
RCHMAC: "not currently used",
RCDisabled: "the command is disabled",
RCExclusive: "command failed because audit sequence required exclusivity",
RCAuthType: "authorization handle is not correct for command",
RCAuthMissing: "5 command requires an authorization session for handle and it is not present",
RCPolicy: "policy failure in math operation or an invalid authPolicy value",
RCPCR: "PCR check fail",
RCPCRChanged: "PCR have changed since checked",
RCUpgrade: "TPM is in field upgrade mode unless called via TPM2_FieldUpgradeData(), then it is not in field upgrade mode",
RCTooManyContexts: "context ID counter is at maximum",
RCAuthUnavailable: "authValue or authPolicy is not available for selected entity",
RCReboot: "a _TPM_Init and Startup(CLEAR) is required before the TPM can resume operation",
RCUnbalanced: "the protection algorithms (hash and symmetric) are not reasonably balanced; the digest size of the hash must be larger than the key size of the symmetric algorithm",
RCCommandSize: "command commandSize value is inconsistent with contents of the command buffer; either the size is not the same as the octets loaded by the hardware interface layer or the value is not large enough to hold a command header",
RCCommandCode: "command code not supported",
RCAuthSize: "the value of authorizationSize is out of range or the number of octets in the Authorization Area is greater than required",
RCAuthContext: "use of an authorization session with a context command or another command that cannot have an authorization session",
RCNVRange: "NV offset+size is out of range",
RCNVSize: "Requested allocation size is larger than allowed",
RCNVLocked: "NV access locked",
RCNVAuthorization: "NV access authorization fails in command actions",
RCNVUninitialized: "an NV Index is used before being initialized or the state saved by TPM2_Shutdown(STATE) could not be restored",
RCNVSpace: "insufficient space for NV allocation",
RCNVDefined: "NV Index or persistent object already defined",
RCBadContext: "context in TPM2_ContextLoad() is not valid",
RCCPHash: "cpHash value already set or not correct for use",
RCParent: "handle for parent is not a valid parent",
RCNeedsTest: "some function needs testing",
RCNoResult: "returned when an internal function cannot process a request due to an unspecified problem; this code is usually related to invalid parameters that are not properly filtered by the input unmarshaling code",
RCSensitive: "the sensitive area did not unmarshal correctly after decryption",
}
// Format 1 error codes.
const (
RCAsymmetric = 0x01
RCAttributes = 0x02
RCHash = 0x03
RCValue = 0x04
RCHierarchy = 0x05
RCKeySize = 0x07
RCMGF = 0x08
RCMode = 0x09
RCType = 0x0A
RCHandle = 0x0B
RCKDF = 0x0C
RCRange = 0x0D
RCAuthFail = 0x0E
RCNonce = 0x0F
RCPP = 0x10
RCScheme = 0x12
RCSize = 0x15
RCSymmetric = 0x16
RCTag = 0x17
RCSelector = 0x18
RCInsufficient = 0x1A
RCSignature = 0x1B
RCKey = 0x1C
RCPolicyFail = 0x1D
RCIntegrity = 0x1F
RCTicket = 0x20
RCReservedBits = 0x21
RCBadAuth = 0x22
RCExpired = 0x23
RCPolicyCC = 0x24
RCBinding = 0x25
RCCurve = 0x26
RCECCPoint = 0x27
)
var fmt1Msg = map[RCFmt1]string{
RCAsymmetric: "asymmetric algorithm not supported or not correct",
RCAttributes: "inconsistent attributes",
RCHash: "hash algorithm not supported or not appropriate",
RCValue: "value is out of range or is not correct for the context",
RCHierarchy: "hierarchy is not enabled or is not correct for the use",
RCKeySize: "key size is not supported",
RCMGF: "mask generation function not supported",
RCMode: "mode of operation not supported",
RCType: "the type of the value is not appropriate for the use",
RCHandle: "the handle is not correct for the use",
RCKDF: "unsupported key derivation function or function not appropriate for use",
RCRange: "value was out of allowed range",
RCAuthFail: "the authorization HMAC check failed and DA counter incremented",
RCNonce: "invalid nonce size or nonce value mismatch",
RCPP: "authorization requires assertion of PP",
RCScheme: "unsupported or incompatible scheme",
RCSize: "structure is the wrong size",
RCSymmetric: "unsupported symmetric algorithm or key size, or not appropriate for instance",
RCTag: "incorrect structure tag",
RCSelector: "union selector is incorrect",
RCInsufficient: "the TPM was unable to unmarshal a value because there were not enough octets in the input buffer",
RCSignature: "the signature is not valid",
RCKey: "key fields are not compatible with the selected use",
RCPolicyFail: "a policy check failed",
RCIntegrity: "integrity check failed",
RCTicket: "invalid ticket",
RCReservedBits: "reserved bits not set to zero as required",
RCBadAuth: "authorization failure without DA implications",
RCExpired: "the policy has expired",
RCPolicyCC: "the commandCode in the policy is not the commandCode of the command or the command code in a policy command references a command that is not implemented",
RCBinding: "public and sensitive portions of an object are not cryptographically bound",
RCCurve: "curve not supported",
RCECCPoint: "point is not on the required curve",
}
// Warning codes.
const (
RCContextGap RCWarn = 0x01
RCObjectMemory RCWarn = 0x02
RCSessionMemory RCWarn = 0x03
RCMemory RCWarn = 0x04
RCSessionHandles RCWarn = 0x05
RCObjectHandles RCWarn = 0x06
RCLocality RCWarn = 0x07
RCYielded RCWarn = 0x08
RCCanceled RCWarn = 0x09
RCTesting RCWarn = 0x0A
RCReferenceH0 RCWarn = 0x10
RCReferenceH1 RCWarn = 0x11
RCReferenceH2 RCWarn = 0x12
RCReferenceH3 RCWarn = 0x13
RCReferenceH4 RCWarn = 0x14
RCReferenceH5 RCWarn = 0x15
RCReferenceH6 RCWarn = 0x16
RCReferenceS0 RCWarn = 0x18
RCReferenceS1 RCWarn = 0x19
RCReferenceS2 RCWarn = 0x1A
RCReferenceS3 RCWarn = 0x1B
RCReferenceS4 RCWarn = 0x1C
RCReferenceS5 RCWarn = 0x1D
RCReferenceS6 RCWarn = 0x1E
RCNVRate RCWarn = 0x20
RCLockout RCWarn = 0x21
RCRetry RCWarn = 0x22
RCNVUnavailable RCWarn = 0x23
)
var warnMsg = map[RCWarn]string{
RCContextGap: "gap for context ID is too large",
RCObjectMemory: "out of memory for object contexts",
RCSessionMemory: "out of memory for session contexts",
RCMemory: "out of shared object/session memory or need space for internal operations",
RCSessionHandles: "out of session handles",
RCObjectHandles: "out of object handles",
RCLocality: "bad locality",
RCYielded: "the TPM has suspended operation on the command; forward progress was made and the command may be retried",
RCCanceled: "the command was canceled",
RCTesting: "TPM is performing self-tests",
RCReferenceH0: "the 1st handle in the handle area references a transient object or session that is not loaded",
RCReferenceH1: "the 2nd handle in the handle area references a transient object or session that is not loaded",
RCReferenceH2: "the 3rd handle in the handle area references a transient object or session that is not loaded",
RCReferenceH3: "the 4th handle in the handle area references a transient object or session that is not loaded",
RCReferenceH4: "the 5th handle in the handle area references a transient object or session that is not loaded",
RCReferenceH5: "the 6th handle in the handle area references a transient object or session that is not loaded",
RCReferenceH6: "the 7th handle in the handle area references a transient object or session that is not loaded",
RCReferenceS0: "the 1st authorization session handle references a session that is not loaded",
RCReferenceS1: "the 2nd authorization session handle references a session that is not loaded",
RCReferenceS2: "the 3rd authorization session handle references a session that is not loaded",
RCReferenceS3: "the 4th authorization session handle references a session that is not loaded",
RCReferenceS4: "the 5th authorization session handle references a session that is not loaded",
RCReferenceS5: "the 6th authorization session handle references a session that is not loaded",
RCReferenceS6: "the 7th authorization session handle references a session that is not loaded",
RCNVRate: "the TPM is rate-limiting accesses to prevent wearout of NV",
RCLockout: "authorizations for objects subject to DA protection are not allowed at this time because the TPM is in DA lockout mode",
RCRetry: "the TPM was not able to start the command",
RCNVUnavailable: "the command may require writing of NV and NV is not current accessible",
}
// Indexes for arguments, handles and sessions.
const (
RC1 RCIndex = iota + 1
RC2
RC3
RC4
RC5
RC6
RC7
RC8
RC9
RCA
RCB
RCC
RCD
RCE
RCF
)
const unknownCode = "unknown error code"
// Error is returned for all Format 0 errors from the TPM. It is used for general
// errors not specific to a parameter, handle or session.
type Error struct {
Code RCFmt0
}
func (e Error) Error() string {
msg := fmt0Msg[e.Code]
if msg == "" {
msg = unknownCode
}
return fmt.Sprintf("error code 0x%x : %s", e.Code, msg)
}
// VendorError represents a vendor-specific error response. These types of responses
// are not decoded and Code contains the complete response code.
type VendorError struct {
Code uint32
}
func (e VendorError) Error() string {
return fmt.Sprintf("vendor error code 0x%x", e.Code)
}
// Warning is typically used to report transient errors.
type Warning struct {
Code RCWarn
}
func (w Warning) Error() string {
msg := warnMsg[w.Code]
if msg == "" {
msg = unknownCode
}
return fmt.Sprintf("warning code 0x%x : %s", w.Code, msg)
}
// ParameterError describes an error related to a parameter, and the parameter number.
type ParameterError struct {
Code RCFmt1
Parameter RCIndex
}
func (e ParameterError) Error() string {
msg := fmt1Msg[e.Code]
if msg == "" {
msg = unknownCode
}
return fmt.Sprintf("parameter %d, error code 0x%x : %s", e.Parameter, e.Code, msg)
}
// HandleError describes an error related to a handle, and the handle number.
type HandleError struct {
Code RCFmt1
Handle RCIndex
}
func (e HandleError) Error() string {
msg := fmt1Msg[e.Code]
if msg == "" {
msg = unknownCode
}
return fmt.Sprintf("handle %d, error code 0x%x : %s", e.Handle, e.Code, msg)
}
// SessionError describes an error related to a session, and the session number.
type SessionError struct {
Code RCFmt1
Session RCIndex
}
func (e SessionError) Error() string {
msg := fmt1Msg[e.Code]
if msg == "" {
msg = unknownCode
}
return fmt.Sprintf("session %d, error code 0x%x : %s", e.Session, e.Code, msg)
}
// Decode a TPM2 response code and return the appropriate error. Logic
// according to the "Response Code Evaluation" chart in Part 1 of the TPM 2.0
// spec.
func decodeResponse(code tpmutil.ResponseCode) error {
if code == tpmutil.RCSuccess {
return nil
}
if code&0x180 == 0 { // Bits 7:8 == 0 is a TPM1 error
return fmt.Errorf("response status 0x%x", code)
}
if code&0x80 == 0 { // Bit 7 unset
if code&0x400 > 0 { // Bit 10 set, vendor specific code
return VendorError{uint32(code)}
}
if code&0x800 > 0 { // Bit 11 set, warning with code in bit 0:6
return Warning{RCWarn(code & 0x7f)}
}
// error with code in bit 0:6
return Error{RCFmt0(code & 0x7f)}
}
if code&0x40 > 0 { // Bit 6 set, code in 0:5, parameter number in 8:11
return ParameterError{RCFmt1(code & 0x3f), RCIndex((code & 0xf00) >> 8)}
}
if code&0x800 == 0 { // Bit 11 unset, code in 0:5, handle in 8:10
return HandleError{RCFmt1(code & 0x3f), RCIndex((code & 0x700) >> 8)}
}
// Code in 0:5, Session in 8:10
return SessionError{RCFmt1(code & 0x3f), RCIndex((code & 0x700) >> 8)}
}

116
vendor/github.com/google/go-tpm/legacy/tpm2/kdf.go generated vendored Normal file
View File

@ -0,0 +1,116 @@
// Copyright (c) 2018, Google LLC All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tpm2
import (
"crypto"
"crypto/hmac"
"encoding/binary"
"hash"
)
// KDFa implements TPM 2.0's default key derivation function, as defined in
// section 11.4.9.2 of the TPM revision 2 specification part 1.
// See: https://trustedcomputinggroup.org/resource/tpm-library-specification/
// The key & label parameters must not be zero length.
// The label parameter is a non-null-terminated string.
// The contextU & contextV parameters are optional.
// Deprecated: Use KDFaHash.
func KDFa(hashAlg Algorithm, key []byte, label string, contextU, contextV []byte, bits int) ([]byte, error) {
h, err := hashAlg.Hash()
if err != nil {
return nil, err
}
return KDFaHash(h, key, label, contextU, contextV, bits), nil
}
// KDFe implements TPM 2.0's ECDH key derivation function, as defined in
// section 11.4.9.3 of the TPM revision 2 specification part 1.
// See: https://trustedcomputinggroup.org/resource/tpm-library-specification/
// The z parameter is the x coordinate of one party's private ECC key multiplied
// by the other party's public ECC point.
// The use parameter is a non-null-terminated string.
// The partyUInfo and partyVInfo are the x coordinates of the initiator's and
// Deprecated: Use KDFeHash.
func KDFe(hashAlg Algorithm, z []byte, use string, partyUInfo, partyVInfo []byte, bits int) ([]byte, error) {
h, err := hashAlg.Hash()
if err != nil {
return nil, err
}
return KDFeHash(h, z, use, partyUInfo, partyVInfo, bits), nil
}
// KDFaHash implements TPM 2.0's default key derivation function, as defined in
// section 11.4.9.2 of the TPM revision 2 specification part 1.
// See: https://trustedcomputinggroup.org/resource/tpm-library-specification/
// The key & label parameters must not be zero length.
// The label parameter is a non-null-terminated string.
// The contextU & contextV parameters are optional.
func KDFaHash(h crypto.Hash, key []byte, label string, contextU, contextV []byte, bits int) []byte {
mac := hmac.New(h.New, key)
out := kdf(mac, bits, func() {
mac.Write([]byte(label))
mac.Write([]byte{0}) // Terminating null character for C-string.
mac.Write(contextU)
mac.Write(contextV)
binary.Write(mac, binary.BigEndian, uint32(bits))
})
return out
}
// KDFeHash implements TPM 2.0's ECDH key derivation function, as defined in
// section 11.4.9.3 of the TPM revision 2 specification part 1.
// See: https://trustedcomputinggroup.org/resource/tpm-library-specification/
// The z parameter is the x coordinate of one party's private ECC key multiplied
// by the other party's public ECC point.
// The use parameter is a non-null-terminated string.
// The partyUInfo and partyVInfo are the x coordinates of the initiator's and
// the responder's ECC points, respectively.
func KDFeHash(h crypto.Hash, z []byte, use string, partyUInfo, partyVInfo []byte, bits int) []byte {
hash := h.New()
out := kdf(hash, bits, func() {
hash.Write(z)
hash.Write([]byte(use))
hash.Write([]byte{0}) // Terminating null character for C-string.
hash.Write(partyUInfo)
hash.Write(partyVInfo)
})
return out
}
func kdf(h hash.Hash, bits int, update func()) []byte {
bytes := (bits + 7) / 8
out := []byte{}
for counter := 1; len(out) < bytes; counter++ {
h.Reset()
binary.Write(h, binary.BigEndian, uint32(counter))
update()
out = h.Sum(out)
}
// out's length is a multiple of hash size, so there will be excess
// bytes if bytes isn't a multiple of hash size.
out = out[:bytes]
// As mentioned in the KDFa and KDFe specs mentioned above,
// the unused bits of the most significant octet are masked off.
if maskBits := uint8(bits % 8); maskBits > 0 {
out[0] &= (1 << maskBits) - 1
}
return out
}

View File

@ -0,0 +1,57 @@
//go:build !windows
// Copyright (c) 2019, Google LLC All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tpm2
import (
"errors"
"fmt"
"io"
"os"
"github.com/google/go-tpm/tpmutil"
)
// OpenTPM opens a channel to the TPM at the given path. If the file is a
// device, then it treats it like a normal TPM device, and if the file is a
// Unix domain socket, then it opens a connection to the socket.
//
// This function may also be invoked with no paths, as tpm2.OpenTPM(). In this
// case, the default paths on Linux (/dev/tpmrm0 then /dev/tpm0), will be used.
func OpenTPM(path ...string) (tpm io.ReadWriteCloser, err error) {
switch len(path) {
case 0:
tpm, err = tpmutil.OpenTPM("/dev/tpmrm0")
if errors.Is(err, os.ErrNotExist) {
tpm, err = tpmutil.OpenTPM("/dev/tpm0")
}
case 1:
tpm, err = tpmutil.OpenTPM(path[0])
default:
return nil, errors.New("cannot specify multiple paths to tpm2.OpenTPM")
}
if err != nil {
return nil, err
}
// Make sure this is a TPM 2.0
_, err = GetManufacturer(tpm)
if err != nil {
tpm.Close()
return nil, fmt.Errorf("open %s: device is not a TPM 2.0", path)
}
return tpm, nil
}

View File

@ -0,0 +1,39 @@
//go:build windows
// Copyright (c) 2018, Google LLC All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tpm2
import (
"fmt"
"io"
"github.com/google/go-tpm/tpmutil"
"github.com/google/go-tpm/tpmutil/tbs"
)
// OpenTPM opens a channel to the TPM.
func OpenTPM() (io.ReadWriteCloser, error) {
info, err := tbs.GetDeviceInfo()
if err != nil {
return nil, err
}
if info.TPMVersion != tbs.TPMVersion20 {
return nil, fmt.Errorf("openTPM: device is not a TPM 2.0")
}
return tpmutil.OpenTPM()
}

1112
vendor/github.com/google/go-tpm/legacy/tpm2/structures.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

2326
vendor/github.com/google/go-tpm/legacy/tpm2/tpm2.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

211
vendor/github.com/google/go-tpm/tpmutil/encoding.go generated vendored Normal file
View File

@ -0,0 +1,211 @@
// Copyright (c) 2018, Google LLC All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tpmutil
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"reflect"
)
var (
selfMarshalerType = reflect.TypeOf((*SelfMarshaler)(nil)).Elem()
handlesAreaType = reflect.TypeOf((*[]Handle)(nil))
)
// packWithHeader takes a header and a sequence of elements that are either of
// fixed length or slices of fixed-length types and packs them into a single
// byte array using binary.Write. It updates the CommandHeader to have the right
// length.
func packWithHeader(ch commandHeader, cmd ...interface{}) ([]byte, error) {
hdrSize := binary.Size(ch)
body, err := Pack(cmd...)
if err != nil {
return nil, fmt.Errorf("couldn't pack message body: %v", err)
}
bodySize := len(body)
ch.Size = uint32(hdrSize + bodySize)
header, err := Pack(ch)
if err != nil {
return nil, fmt.Errorf("couldn't pack message header: %v", err)
}
return append(header, body...), nil
}
// Pack encodes a set of elements into a single byte array, using
// encoding/binary. This means that all the elements must be encodeable
// according to the rules of encoding/binary.
//
// It has one difference from encoding/binary: it encodes byte slices with a
// prepended length, to match how the TPM encodes variable-length arrays. If
// you wish to add a byte slice without length prefix, use RawBytes.
func Pack(elts ...interface{}) ([]byte, error) {
buf := new(bytes.Buffer)
if err := packType(buf, elts...); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// tryMarshal attempts to use a TPMMarshal() method defined on the type
// to pack v into buf. True is returned if the method exists and the
// marshal was attempted.
func tryMarshal(buf io.Writer, v reflect.Value) (bool, error) {
t := v.Type()
if t.Implements(selfMarshalerType) {
if v.Kind() == reflect.Ptr && v.IsNil() {
return true, fmt.Errorf("cannot call TPMMarshal on a nil pointer of type %T", v)
}
return true, v.Interface().(SelfMarshaler).TPMMarshal(buf)
}
// We might have a non-pointer struct field, but we dont have a
// pointer with which to implement the interface.
// If the pointer of the type implements the interface, we should be
// able to construct a value to call TPMMarshal() with.
// TODO(awly): Try and avoid blowing away private data by using Addr() instead of Set()
if reflect.PtrTo(t).Implements(selfMarshalerType) {
tmp := reflect.New(t)
tmp.Elem().Set(v)
return true, tmp.Interface().(SelfMarshaler).TPMMarshal(buf)
}
return false, nil
}
func packValue(buf io.Writer, v reflect.Value) error {
if v.Type() == handlesAreaType {
v = v.Convert(reflect.TypeOf((*handleList)(nil)))
}
if canMarshal, err := tryMarshal(buf, v); canMarshal {
return err
}
switch v.Kind() {
case reflect.Ptr:
if v.IsNil() {
return fmt.Errorf("cannot pack nil %s", v.Type().String())
}
return packValue(buf, v.Elem())
case reflect.Struct:
for i := 0; i < v.NumField(); i++ {
f := v.Field(i)
if err := packValue(buf, f); err != nil {
return err
}
}
default:
return binary.Write(buf, binary.BigEndian, v.Interface())
}
return nil
}
func packType(buf io.Writer, elts ...interface{}) error {
for _, e := range elts {
if err := packValue(buf, reflect.ValueOf(e)); err != nil {
return err
}
}
return nil
}
// tryUnmarshal attempts to use TPMUnmarshal() to perform the
// unpack, if the given value implements SelfMarshaler.
// True is returned if v implements SelfMarshaler & TPMUnmarshal
// was called, along with an error returned from TPMUnmarshal.
func tryUnmarshal(buf io.Reader, v reflect.Value) (bool, error) {
t := v.Type()
if t.Implements(selfMarshalerType) {
if v.Kind() == reflect.Ptr && v.IsNil() {
return true, fmt.Errorf("cannot call TPMUnmarshal on a nil pointer")
}
return true, v.Interface().(SelfMarshaler).TPMUnmarshal(buf)
}
// We might have a non-pointer struct field, which is addressable,
// If the pointer of the type implements the interface, and the
// value is addressable, we should be able to call TPMUnmarshal().
if v.CanAddr() && reflect.PtrTo(t).Implements(selfMarshalerType) {
return true, v.Addr().Interface().(SelfMarshaler).TPMUnmarshal(buf)
}
return false, nil
}
// Unpack is a convenience wrapper around UnpackBuf. Unpack returns the number
// of bytes read from b to fill elts and error, if any.
func Unpack(b []byte, elts ...interface{}) (int, error) {
buf := bytes.NewBuffer(b)
err := UnpackBuf(buf, elts...)
read := len(b) - buf.Len()
return read, err
}
func unpackValue(buf io.Reader, v reflect.Value) error {
if v.Type() == handlesAreaType {
v = v.Convert(reflect.TypeOf((*handleList)(nil)))
}
if didUnmarshal, err := tryUnmarshal(buf, v); didUnmarshal {
return err
}
switch v.Kind() {
case reflect.Ptr:
if v.IsNil() {
return fmt.Errorf("cannot unpack nil %s", v.Type().String())
}
return unpackValue(buf, v.Elem())
case reflect.Struct:
for i := 0; i < v.NumField(); i++ {
f := v.Field(i)
if err := unpackValue(buf, f); err != nil {
return err
}
}
return nil
default:
// binary.Read can only set pointer values, so we need to take the address.
if !v.CanAddr() {
return fmt.Errorf("cannot unpack unaddressable leaf type %q", v.Type().String())
}
return binary.Read(buf, binary.BigEndian, v.Addr().Interface())
}
}
// UnpackBuf recursively unpacks types from a reader just as encoding/binary
// does under binary.BigEndian, but with one difference: it unpacks a byte
// slice by first reading an integer with lengthPrefixSize bytes, then reading
// that many bytes. It assumes that incoming values are pointers to values so
// that, e.g., underlying slices can be resized as needed.
func UnpackBuf(buf io.Reader, elts ...interface{}) error {
for _, e := range elts {
v := reflect.ValueOf(e)
if v.Kind() != reflect.Ptr {
return fmt.Errorf("non-pointer value %q passed to UnpackBuf", v.Type().String())
}
if v.IsNil() {
return errors.New("nil pointer passed to UnpackBuf")
}
if err := unpackValue(buf, v); err != nil {
return err
}
}
return nil
}

10
vendor/github.com/google/go-tpm/tpmutil/poll_other.go generated vendored Normal file
View File

@ -0,0 +1,10 @@
//go:build !linux && !darwin
package tpmutil
import (
"os"
)
// Not implemented on Windows.
func poll(_ *os.File) error { return nil }

32
vendor/github.com/google/go-tpm/tpmutil/poll_unix.go generated vendored Normal file
View File

@ -0,0 +1,32 @@
//go:build linux || darwin
package tpmutil
import (
"fmt"
"os"
"golang.org/x/sys/unix"
)
// poll blocks until the file descriptor is ready for reading or an error occurs.
func poll(f *os.File) error {
var (
fds = []unix.PollFd{{
Fd: int32(f.Fd()),
Events: 0x1, // POLLIN
}}
timeout = -1 // Indefinite timeout
)
if _, err := unix.Poll(fds, timeout); err != nil {
return err
}
// Revents is filled in by the kernel.
// If the expected event happened, Revents should match Events.
if fds[0].Revents != fds[0].Events {
return fmt.Errorf("unexpected poll Revents 0x%x", fds[0].Revents)
}
return nil
}

113
vendor/github.com/google/go-tpm/tpmutil/run.go generated vendored Normal file
View File

@ -0,0 +1,113 @@
// Copyright (c) 2018, Google LLC All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package tpmutil provides common utility functions for both TPM 1.2 and TPM
// 2.0 devices.
package tpmutil
import (
"errors"
"io"
"os"
"time"
)
// maxTPMResponse is the largest possible response from the TPM. We need to know
// this because we don't always know the length of the TPM response, and
// /dev/tpm insists on giving it all back in a single value rather than
// returning a header and a body in separate responses.
const maxTPMResponse = 4096
// RunCommandRaw executes the given raw command and returns the raw response.
// Does not check the response code except to execute retry logic.
func RunCommandRaw(rw io.ReadWriter, inb []byte) ([]byte, error) {
if rw == nil {
return nil, errors.New("nil TPM handle")
}
// f(t) = (2^t)ms, up to 2s
var backoffFac uint
var rh responseHeader
var outb []byte
for {
if _, err := rw.Write(inb); err != nil {
return nil, err
}
// If the TPM is a real device, it may not be ready for reading
// immediately after writing the command. Wait until the file
// descriptor is ready to be read from.
if f, ok := rw.(*os.File); ok {
if err := poll(f); err != nil {
return nil, err
}
}
outb = make([]byte, maxTPMResponse)
outlen, err := rw.Read(outb)
if err != nil {
return nil, err
}
// Resize the buffer to match the amount read from the TPM.
outb = outb[:outlen]
_, err = Unpack(outb, &rh)
if err != nil {
return nil, err
}
// If TPM is busy, retry the command after waiting a few ms.
if rh.Res == RCRetry {
if backoffFac < 11 {
dur := (1 << backoffFac) * time.Millisecond
time.Sleep(dur)
backoffFac++
} else {
return nil, err
}
} else {
break
}
}
return outb, nil
}
// RunCommand executes cmd with given tag and arguments. Returns TPM response
// body (without response header) and response code from the header. Returned
// error may be nil if response code is not RCSuccess; caller should check
// both.
func RunCommand(rw io.ReadWriter, tag Tag, cmd Command, in ...interface{}) ([]byte, ResponseCode, error) {
inb, err := packWithHeader(commandHeader{tag, 0, cmd}, in...)
if err != nil {
return nil, 0, err
}
outb, err := RunCommandRaw(rw, inb)
if err != nil {
return nil, 0, err
}
var rh responseHeader
read, err := Unpack(outb, &rh)
if err != nil {
return nil, 0, err
}
if rh.Res != RCSuccess {
return nil, rh.Res, nil
}
return outb[read:], rh.Res, nil
}

111
vendor/github.com/google/go-tpm/tpmutil/run_other.go generated vendored Normal file
View File

@ -0,0 +1,111 @@
//go:build !windows
// Copyright (c) 2018, Google LLC All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tpmutil
import (
"fmt"
"io"
"net"
"os"
)
// OpenTPM opens a channel to the TPM at the given path. If the file is a
// device, then it treats it like a normal TPM device, and if the file is a
// Unix domain socket, then it opens a connection to the socket.
func OpenTPM(path string) (io.ReadWriteCloser, error) {
// If it's a regular file, then open it
var rwc io.ReadWriteCloser
fi, err := os.Stat(path)
if err != nil {
return nil, err
}
if fi.Mode()&os.ModeDevice != 0 {
var f *os.File
f, err = os.OpenFile(path, os.O_RDWR, 0600)
if err != nil {
return nil, err
}
rwc = io.ReadWriteCloser(f)
} else if fi.Mode()&os.ModeSocket != 0 {
rwc = NewEmulatorReadWriteCloser(path)
} else {
return nil, fmt.Errorf("unsupported TPM file mode %s", fi.Mode().String())
}
return rwc, nil
}
// dialer abstracts the net.Dial call so test code can provide its own net.Conn
// implementation.
type dialer func(network, path string) (net.Conn, error)
// EmulatorReadWriteCloser manages connections with a TPM emulator over a Unix
// domain socket. These emulators often operate in a write/read/disconnect
// sequence, so the Write method always connects, and the Read method always
// closes. EmulatorReadWriteCloser is not thread safe.
type EmulatorReadWriteCloser struct {
path string
conn net.Conn
dialer dialer
}
// NewEmulatorReadWriteCloser stores information about a Unix domain socket to
// write to and read from.
func NewEmulatorReadWriteCloser(path string) *EmulatorReadWriteCloser {
return &EmulatorReadWriteCloser{
path: path,
dialer: net.Dial,
}
}
// Read implements io.Reader by reading from the Unix domain socket and closing
// it.
func (erw *EmulatorReadWriteCloser) Read(p []byte) (int, error) {
// Read is always the second operation in a Write/Read sequence.
if erw.conn == nil {
return 0, fmt.Errorf("must call Write then Read in an alternating sequence")
}
n, err := erw.conn.Read(p)
erw.conn.Close()
erw.conn = nil
return n, err
}
// Write implements io.Writer by connecting to the Unix domain socket and
// writing.
func (erw *EmulatorReadWriteCloser) Write(p []byte) (int, error) {
if erw.conn != nil {
return 0, fmt.Errorf("must call Write then Read in an alternating sequence")
}
var err error
erw.conn, err = erw.dialer("unix", erw.path)
if err != nil {
return 0, err
}
return erw.conn.Write(p)
}
// Close implements io.Closer by closing the Unix domain socket if one is open.
func (erw *EmulatorReadWriteCloser) Close() error {
if erw.conn == nil {
return fmt.Errorf("cannot call Close when no connection is open")
}
err := erw.conn.Close()
erw.conn = nil
return err
}

84
vendor/github.com/google/go-tpm/tpmutil/run_windows.go generated vendored Normal file
View File

@ -0,0 +1,84 @@
// Copyright (c) 2018, Google LLC All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tpmutil
import (
"io"
"github.com/google/go-tpm/tpmutil/tbs"
)
// winTPMBuffer is a ReadWriteCloser to access the TPM in Windows.
type winTPMBuffer struct {
context tbs.Context
outBuffer []byte
}
// Executes the TPM command specified by commandBuffer (at Normal Priority), returning the number
// of bytes in the command and any error code returned by executing the TPM command. Command
// response can be read by calling Read().
func (rwc *winTPMBuffer) Write(commandBuffer []byte) (int, error) {
// TPM spec defines longest possible response to be maxTPMResponse.
rwc.outBuffer = rwc.outBuffer[:maxTPMResponse]
outBufferLen, err := rwc.context.SubmitCommand(
tbs.NormalPriority,
commandBuffer,
rwc.outBuffer,
)
if err != nil {
rwc.outBuffer = rwc.outBuffer[:0]
return 0, err
}
// Shrink outBuffer so it is length of response.
rwc.outBuffer = rwc.outBuffer[:outBufferLen]
return len(commandBuffer), nil
}
// Provides TPM response from the command called in the last Write call.
func (rwc *winTPMBuffer) Read(responseBuffer []byte) (int, error) {
if len(rwc.outBuffer) == 0 {
return 0, io.EOF
}
lenCopied := copy(responseBuffer, rwc.outBuffer)
// Cut out the piece of slice which was just read out, maintaining original slice capacity.
rwc.outBuffer = append(rwc.outBuffer[:0], rwc.outBuffer[lenCopied:]...)
return lenCopied, nil
}
func (rwc *winTPMBuffer) Close() error {
return rwc.context.Close()
}
// OpenTPM creates a new instance of a ReadWriteCloser which can interact with a
// Windows TPM.
func OpenTPM() (io.ReadWriteCloser, error) {
tpmContext, err := tbs.CreateContext(tbs.TPMVersion20, tbs.IncludeTPM12|tbs.IncludeTPM20)
rwc := &winTPMBuffer{
context: tpmContext,
outBuffer: make([]byte, 0, maxTPMResponse),
}
return rwc, err
}
// FromContext creates a new instance of a ReadWriteCloser which can
// interact with a Windows TPM, using the specified TBS handle.
func FromContext(ctx tbs.Context) io.ReadWriteCloser {
return &winTPMBuffer{
context: ctx,
outBuffer: make([]byte, 0, maxTPMResponse),
}
}

195
vendor/github.com/google/go-tpm/tpmutil/structures.go generated vendored Normal file
View File

@ -0,0 +1,195 @@
// Copyright (c) 2018, Google LLC All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tpmutil
import (
"bytes"
"encoding/binary"
"fmt"
"io"
)
// maxBytesBufferSize sets a sane upper bound on the size of a U32Bytes
// buffer. This limit exists to prevent a maliciously large size prefix
// from resulting in a massive memory allocation, potentially causing
// an OOM condition on the system.
// We expect no buffer from a TPM to approach 1Mb in size.
const maxBytesBufferSize uint32 = 1024 * 1024 // 1Mb.
// RawBytes is for Pack and RunCommand arguments that are already encoded.
// Compared to []byte, RawBytes will not be prepended with slice length during
// encoding.
type RawBytes []byte
// U16Bytes is a byte slice with a 16-bit header
type U16Bytes []byte
// TPMMarshal packs U16Bytes
func (b *U16Bytes) TPMMarshal(out io.Writer) error {
size := len([]byte(*b))
if err := binary.Write(out, binary.BigEndian, uint16(size)); err != nil {
return err
}
n, err := out.Write(*b)
if err != nil {
return err
}
if n != size {
return fmt.Errorf("unable to write all contents of U16Bytes")
}
return nil
}
// TPMUnmarshal unpacks a U16Bytes
func (b *U16Bytes) TPMUnmarshal(in io.Reader) error {
var tmpSize uint16
if err := binary.Read(in, binary.BigEndian, &tmpSize); err != nil {
return err
}
size := int(tmpSize)
if len(*b) >= size {
*b = (*b)[:size]
} else {
*b = append(*b, make([]byte, size-len(*b))...)
}
n, err := in.Read(*b)
if err != nil {
return err
}
if n != size {
return io.ErrUnexpectedEOF
}
return nil
}
// U32Bytes is a byte slice with a 32-bit header
type U32Bytes []byte
// TPMMarshal packs U32Bytes
func (b *U32Bytes) TPMMarshal(out io.Writer) error {
size := len([]byte(*b))
if err := binary.Write(out, binary.BigEndian, uint32(size)); err != nil {
return err
}
n, err := out.Write(*b)
if err != nil {
return err
}
if n != size {
return fmt.Errorf("unable to write all contents of U32Bytes")
}
return nil
}
// TPMUnmarshal unpacks a U32Bytes
func (b *U32Bytes) TPMUnmarshal(in io.Reader) error {
var tmpSize uint32
if err := binary.Read(in, binary.BigEndian, &tmpSize); err != nil {
return err
}
if tmpSize > maxBytesBufferSize {
return bytes.ErrTooLarge
}
// We can now safely cast to an int on 32-bit or 64-bit machines
size := int(tmpSize)
if len(*b) >= size {
*b = (*b)[:size]
} else {
*b = append(*b, make([]byte, size-len(*b))...)
}
n, err := in.Read(*b)
if err != nil {
return err
}
if n != size {
return fmt.Errorf("unable to read all contents in to U32Bytes")
}
return nil
}
// Tag is a command tag.
type Tag uint16
// Command is an identifier of a TPM command.
type Command uint32
// A commandHeader is the header for a TPM command.
type commandHeader struct {
Tag Tag
Size uint32
Cmd Command
}
// ResponseCode is a response code returned by TPM.
type ResponseCode uint32
// RCSuccess is response code for successful command. Identical for TPM 1.2 and
// 2.0.
const RCSuccess ResponseCode = 0x000
// RCRetry is response code for TPM is busy.
const RCRetry ResponseCode = 0x922
// A responseHeader is a header for TPM responses.
type responseHeader struct {
Tag Tag
Size uint32
Res ResponseCode
}
// A Handle is a reference to a TPM object.
type Handle uint32
// HandleValue returns the handle value. This behavior is intended to satisfy
// an interface that can be implemented by other, more complex types as well.
func (h Handle) HandleValue() uint32 {
return uint32(h)
}
type handleList []Handle
func (l *handleList) TPMMarshal(_ io.Writer) error {
return fmt.Errorf("TPMMarhsal on []Handle is not supported yet")
}
func (l *handleList) TPMUnmarshal(in io.Reader) error {
var numHandles uint16
if err := binary.Read(in, binary.BigEndian, &numHandles); err != nil {
return err
}
// Make len(e) match size exactly.
size := int(numHandles)
if len(*l) >= size {
*l = (*l)[:size]
} else {
*l = append(*l, make([]Handle, size-len(*l))...)
}
return binary.Read(in, binary.BigEndian, *l)
}
// SelfMarshaler allows custom types to override default encoding/decoding
// behavior in Pack, Unpack and UnpackBuf.
type SelfMarshaler interface {
TPMMarshal(out io.Writer) error
TPMUnmarshal(in io.Reader) error
}

View File

@ -0,0 +1,267 @@
// Copyright (c) 2018, Google LLC All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package tbs provides an low-level interface directly mapping to Windows
// Tbs.dll system library commands:
// https://docs.microsoft.com/en-us/windows/desktop/TBS/tpm-base-services-portal
// Public field descriptions contain links to the high-level Windows documentation.
package tbs
import (
"fmt"
"syscall"
"unsafe"
)
// Context references the current TPM context
type Context uintptr
// Version of TPM being used by the application.
type Version uint32
// Flag indicates TPM versions that are supported by the application.
type Flag uint32
// CommandPriority is used to determine which pending command to submit whenever the TPM is free.
type CommandPriority uint32
// Command parameters:
// https://github.com/tpn/winsdk-10/blob/master/Include/10.0.10240.0/shared/tbs.h
const (
// https://docs.microsoft.com/en-us/windows/desktop/api/Tbs/ns-tbs-tdtbs_context_params2
// OR flags to use multiple.
RequestRaw Flag = 1 << iota // Add flag to request raw context
IncludeTPM12 // Add flag to support TPM 1.2
IncludeTPM20 // Add flag to support TPM 2
TPMVersion12 Version = 1 // For TPM 1.2 applications
TPMVersion20 Version = 2 // For TPM 2 applications or applications using multiple TPM versions
// https://docs.microsoft.com/en-us/windows/desktop/tbs/command-scheduling
// https://docs.microsoft.com/en-us/windows/desktop/api/Tbs/nf-tbs-tbsip_submit_command#parameters
LowPriority CommandPriority = 100 // For low priority application use
NormalPriority CommandPriority = 200 // For normal priority application use
HighPriority CommandPriority = 300 // For high priority application use
SystemPriority CommandPriority = 400 // For system tasks that access the TPM
commandLocalityZero uint32 = 0 // Windows currently only supports TBS_COMMAND_LOCALITY_ZERO.
)
// Error is the return type of all functions in this package.
type Error uint32
func (err Error) Error() string {
if description, ok := errorDescriptions[err]; ok {
return fmt.Sprintf("TBS Error 0x%X: %s", uint32(err), description)
}
return fmt.Sprintf("Unrecognized TBS Error 0x%X", uint32(err))
}
func getError(err uintptr) error {
// tbs.dll uses 0x0 as the return value for success.
if err == 0 {
return nil
}
return Error(err)
}
// TBS Return Codes:
// https://docs.microsoft.com/en-us/windows/desktop/TBS/tbs-return-codes
const (
ErrInternalError Error = 0x80284001
ErrBadParameter Error = 0x80284002
ErrInvalidOutputPointer Error = 0x80284003
ErrInvalidContext Error = 0x80284004
ErrInsufficientBuffer Error = 0x80284005
ErrIOError Error = 0x80284006
ErrInvalidContextParam Error = 0x80284007
ErrServiceNotRunning Error = 0x80284008
ErrTooManyTBSContexts Error = 0x80284009
ErrTooManyResources Error = 0x8028400A
ErrServiceStartPending Error = 0x8028400B
ErrPPINotSupported Error = 0x8028400C
ErrCommandCanceled Error = 0x8028400D
ErrBufferTooLarge Error = 0x8028400E
ErrTPMNotFound Error = 0x8028400F
ErrServiceDisabled Error = 0x80284010
ErrNoEventLog Error = 0x80284011
ErrAccessDenied Error = 0x80284012
ErrProvisioningNotAllowed Error = 0x80284013
ErrPPIFunctionUnsupported Error = 0x80284014
ErrOwnerauthNotFound Error = 0x80284015
)
var errorDescriptions = map[Error]string{
ErrInternalError: "An internal software error occurred.",
ErrBadParameter: "One or more parameter values are not valid.",
ErrInvalidOutputPointer: "A specified output pointer is bad.",
ErrInvalidContext: "The specified context handle does not refer to a valid context.",
ErrInsufficientBuffer: "The specified output buffer is too small.",
ErrIOError: "An error occurred while communicating with the TPM.",
ErrInvalidContextParam: "A context parameter that is not valid was passed when attempting to create a TBS context.",
ErrServiceNotRunning: "The TBS service is not running and could not be started.",
ErrTooManyTBSContexts: "A new context could not be created because there are too many open contexts.",
ErrTooManyResources: "A new virtual resource could not be created because there are too many open virtual resources.",
ErrServiceStartPending: "The TBS service has been started but is not yet running.",
ErrPPINotSupported: "The physical presence interface is not supported.",
ErrCommandCanceled: "The command was canceled.",
ErrBufferTooLarge: "The input or output buffer is too large.",
ErrTPMNotFound: "A compatible Trusted Platform Module (TPM) Security Device cannot be found on this computer.",
ErrServiceDisabled: "The TBS service has been disabled.",
ErrNoEventLog: "The TBS event log is not available.",
ErrAccessDenied: "The caller does not have the appropriate rights to perform the requested operation.",
ErrProvisioningNotAllowed: "The TPM provisioning action is not allowed by the specified flags.",
ErrPPIFunctionUnsupported: "The Physical Presence Interface of this firmware does not support the requested method.",
ErrOwnerauthNotFound: "The requested TPM OwnerAuth value was not found.",
}
// Tbs.dll provides an API for making calls to the TPM:
// https://docs.microsoft.com/en-us/windows/desktop/TBS/tpm-base-services-portal
var (
tbsDLL = syscall.NewLazyDLL("Tbs.dll")
tbsGetDeviceInfo = tbsDLL.NewProc("Tbsi_GetDeviceInfo")
tbsCreateContext = tbsDLL.NewProc("Tbsi_Context_Create")
tbsContextClose = tbsDLL.NewProc("Tbsip_Context_Close")
tbsSubmitCommand = tbsDLL.NewProc("Tbsip_Submit_Command")
tbsGetTCGLog = tbsDLL.NewProc("Tbsi_Get_TCG_Log")
)
// Returns the address of the beginning of a slice or 0 for a nil slice.
func sliceAddress(s []byte) uintptr {
if len(s) == 0 {
return 0
}
return uintptr(unsafe.Pointer(&(s[0])))
}
// DeviceInfo is TPM_DEVICE_INFO from tbs.h
type DeviceInfo struct {
StructVersion uint32
TPMVersion Version
TPMInterfaceType uint32
TPMImpRevision uint32
}
// GetDeviceInfo gets the DeviceInfo of the current TPM:
// https://docs.microsoft.com/en-us/windows/win32/api/tbs/nf-tbs-tbsi_getdeviceinfo
func GetDeviceInfo() (*DeviceInfo, error) {
info := DeviceInfo{}
// TBS_RESULT Tbsi_GetDeviceInfo(
// UINT32 Size,
// PVOID Info
// );
if err := tbsGetDeviceInfo.Find(); err != nil {
return nil, err
}
result, _, _ := tbsGetDeviceInfo.Call(
unsafe.Sizeof(info),
uintptr(unsafe.Pointer(&info)),
)
return &info, getError(result)
}
// CreateContext creates a new TPM context:
// https://docs.microsoft.com/en-us/windows/desktop/api/Tbs/nf-tbs-tbsi_context_create
func CreateContext(version Version, flag Flag) (Context, error) {
var context Context
params := struct {
Version
Flag
}{version, flag}
// TBS_RESULT Tbsi_Context_Create(
// _In_ PCTBS_CONTEXT_PARAMS pContextParams,
// _Out_ PTBS_HCONTEXT *phContext
// );
if err := tbsCreateContext.Find(); err != nil {
return context, err
}
result, _, _ := tbsCreateContext.Call(
uintptr(unsafe.Pointer(&params)),
uintptr(unsafe.Pointer(&context)),
)
return context, getError(result)
}
// Close closes an existing TPM context:
// https://docs.microsoft.com/en-us/windows/desktop/api/Tbs/nf-tbs-tbsip_context_close
func (context Context) Close() error {
// TBS_RESULT Tbsip_Context_Close(
// _In_ TBS_HCONTEXT hContext
// );
if err := tbsContextClose.Find(); err != nil {
return err
}
result, _, _ := tbsContextClose.Call(uintptr(context))
return getError(result)
}
// SubmitCommand sends commandBuffer to the TPM, returning the number of bytes
// written to responseBuffer. ErrInsufficientBuffer is returned if the
// responseBuffer is too short. ErrInvalidOutputPointer is returned if the
// responseBuffer is nil. On failure, the returned length is unspecified.
// https://docs.microsoft.com/en-us/windows/desktop/api/Tbs/nf-tbs-tbsip_submit_command
func (context Context) SubmitCommand(
priority CommandPriority,
commandBuffer []byte,
responseBuffer []byte,
) (uint32, error) {
responseBufferLen := uint32(len(responseBuffer))
// TBS_RESULT Tbsip_Submit_Command(
// _In_ TBS_HCONTEXT hContext,
// _In_ TBS_COMMAND_LOCALITY Locality,
// _In_ TBS_COMMAND_PRIORITY Priority,
// _In_ const PCBYTE *pabCommand,
// _In_ UINT32 cbCommand,
// _Out_ PBYTE *pabResult,
// _Inout_ UINT32 *pcbOutput
// );
if err := tbsSubmitCommand.Find(); err != nil {
return 0, err
}
result, _, _ := tbsSubmitCommand.Call(
uintptr(context),
uintptr(commandLocalityZero),
uintptr(priority),
sliceAddress(commandBuffer),
uintptr(len(commandBuffer)),
sliceAddress(responseBuffer),
uintptr(unsafe.Pointer(&responseBufferLen)),
)
return responseBufferLen, getError(result)
}
// GetTCGLog gets the system event log, returning the number of bytes written
// to logBuffer. If logBuffer is nil, the size of the TCG log is returned.
// ErrInsufficientBuffer is returned if the logBuffer is too short. On failure,
// the returned length is unspecified.
// https://docs.microsoft.com/en-us/windows/desktop/api/Tbs/nf-tbs-tbsi_get_tcg_log
func (context Context) GetTCGLog(logBuffer []byte) (uint32, error) {
logBufferLen := uint32(len(logBuffer))
// TBS_RESULT Tbsi_Get_TCG_Log(
// TBS_HCONTEXT hContext,
// PBYTE pOutputBuf,
// PUINT32 pOutputBufLen
// );
if err := tbsGetTCGLog.Find(); err != nil {
return 0, err
}
result, _, _ := tbsGetTCGLog.Call(
uintptr(context),
sliceAddress(logBuffer),
uintptr(unsafe.Pointer(&logBufferLen)),
)
return logBufferLen, getError(result)
}

View File

@ -14,8 +14,34 @@ This package provides various compression algorithms.
[![Go](https://github.com/klauspost/compress/actions/workflows/go.yml/badge.svg)](https://github.com/klauspost/compress/actions/workflows/go.yml)
[![Sourcegraph Badge](https://sourcegraph.com/github.com/klauspost/compress/-/badge.svg)](https://sourcegraph.com/github.com/klauspost/compress?badge)
# package usage
Use `go get github.com/klauspost/compress@latest` to add it to your project.
This package will support the current Go version and 2 versions back.
* Use the `nounsafe` tag to disable all use of the "unsafe" package.
* Use the `noasm` tag to disable all assembly across packages.
Use the links above for more information on each.
# changelog
* Feb 19th, 2025 - [1.18.0](https://github.com/klauspost/compress/releases/tag/v1.18.0)
* Add unsafe little endian loaders https://github.com/klauspost/compress/pull/1036
* fix: check `r.err != nil` but return a nil value error `err` by @alingse in https://github.com/klauspost/compress/pull/1028
* flate: Simplify L4-6 loading https://github.com/klauspost/compress/pull/1043
* flate: Simplify matchlen (remove asm) https://github.com/klauspost/compress/pull/1045
* s2: Improve small block compression speed w/o asm https://github.com/klauspost/compress/pull/1048
* flate: Fix matchlen L5+L6 https://github.com/klauspost/compress/pull/1049
* flate: Cleanup & reduce casts https://github.com/klauspost/compress/pull/1050
* Oct 11th, 2024 - [1.17.11](https://github.com/klauspost/compress/releases/tag/v1.17.11)
* zstd: Fix extra CRC written with multiple Close calls https://github.com/klauspost/compress/pull/1017
* s2: Don't use stack for index tables https://github.com/klauspost/compress/pull/1014
* gzhttp: No content-type on no body response code by @juliens in https://github.com/klauspost/compress/pull/1011
* gzhttp: Do not set the content-type when response has no body by @kevinpollet in https://github.com/klauspost/compress/pull/1013
* Sep 23rd, 2024 - [1.17.10](https://github.com/klauspost/compress/releases/tag/v1.17.10)
* gzhttp: Add TransportAlwaysDecompress option. https://github.com/klauspost/compress/pull/978
* gzhttp: Add supported decompress request body by @mirecl in https://github.com/klauspost/compress/pull/1002
@ -65,9 +91,9 @@ https://github.com/klauspost/compress/pull/919 https://github.com/klauspost/comp
* zstd: Fix rare *CORRUPTION* output in "best" mode. See https://github.com/klauspost/compress/pull/876
* Oct 14th, 2023 - [v1.17.1](https://github.com/klauspost/compress/releases/tag/v1.17.1)
* s2: Fix S2 "best" dictionary wrong encoding by @klauspost in https://github.com/klauspost/compress/pull/871
* s2: Fix S2 "best" dictionary wrong encoding https://github.com/klauspost/compress/pull/871
* flate: Reduce allocations in decompressor and minor code improvements by @fakefloordiv in https://github.com/klauspost/compress/pull/869
* s2: Fix EstimateBlockSize on 6&7 length input by @klauspost in https://github.com/klauspost/compress/pull/867
* s2: Fix EstimateBlockSize on 6&7 length input https://github.com/klauspost/compress/pull/867
* Sept 19th, 2023 - [v1.17.0](https://github.com/klauspost/compress/releases/tag/v1.17.0)
* Add experimental dictionary builder https://github.com/klauspost/compress/pull/853
@ -124,7 +150,7 @@ https://github.com/klauspost/compress/pull/919 https://github.com/klauspost/comp
<summary>See changes to v1.15.x</summary>
* Jan 21st, 2023 (v1.15.15)
* deflate: Improve level 7-9 by @klauspost in https://github.com/klauspost/compress/pull/739
* deflate: Improve level 7-9 https://github.com/klauspost/compress/pull/739
* zstd: Add delta encoding support by @greatroar in https://github.com/klauspost/compress/pull/728
* zstd: Various speed improvements by @greatroar https://github.com/klauspost/compress/pull/741 https://github.com/klauspost/compress/pull/734 https://github.com/klauspost/compress/pull/736 https://github.com/klauspost/compress/pull/744 https://github.com/klauspost/compress/pull/743 https://github.com/klauspost/compress/pull/745
* gzhttp: Add SuffixETag() and DropETag() options to prevent ETag collisions on compressed responses by @willbicks in https://github.com/klauspost/compress/pull/740
@ -167,7 +193,7 @@ https://github.com/klauspost/compress/pull/919 https://github.com/klauspost/comp
* zstd: Fix decoder crash on amd64 (no BMI) on invalid input https://github.com/klauspost/compress/pull/645
* zstd: Disable decoder extended memory copies (amd64) due to possible crashes https://github.com/klauspost/compress/pull/644
* zstd: Allow single segments up to "max decoded size" by @klauspost in https://github.com/klauspost/compress/pull/643
* zstd: Allow single segments up to "max decoded size" https://github.com/klauspost/compress/pull/643
* July 13, 2022 (v1.15.8)
@ -209,7 +235,7 @@ https://github.com/klauspost/compress/pull/919 https://github.com/klauspost/comp
* zstd: Speed up when WithDecoderLowmem(false) https://github.com/klauspost/compress/pull/599
* zstd: faster next state update in BMI2 version of decode by @WojciechMula in https://github.com/klauspost/compress/pull/593
* huff0: Do not check max size when reading table. https://github.com/klauspost/compress/pull/586
* flate: Inplace hashing for level 7-9 by @klauspost in https://github.com/klauspost/compress/pull/590
* flate: Inplace hashing for level 7-9 https://github.com/klauspost/compress/pull/590
* May 11, 2022 (v1.15.4)
@ -236,12 +262,12 @@ https://github.com/klauspost/compress/pull/919 https://github.com/klauspost/comp
* zstd: Add stricter block size checks in [#523](https://github.com/klauspost/compress/pull/523)
* Mar 3, 2022 (v1.15.0)
* zstd: Refactor decoder by @klauspost in [#498](https://github.com/klauspost/compress/pull/498)
* zstd: Add stream encoding without goroutines by @klauspost in [#505](https://github.com/klauspost/compress/pull/505)
* zstd: Refactor decoder [#498](https://github.com/klauspost/compress/pull/498)
* zstd: Add stream encoding without goroutines [#505](https://github.com/klauspost/compress/pull/505)
* huff0: Prevent single blocks exceeding 16 bits by @klauspost in[#507](https://github.com/klauspost/compress/pull/507)
* flate: Inline literal emission by @klauspost in [#509](https://github.com/klauspost/compress/pull/509)
* gzhttp: Add zstd to transport by @klauspost in [#400](https://github.com/klauspost/compress/pull/400)
* gzhttp: Make content-type optional by @klauspost in [#510](https://github.com/klauspost/compress/pull/510)
* flate: Inline literal emission [#509](https://github.com/klauspost/compress/pull/509)
* gzhttp: Add zstd to transport [#400](https://github.com/klauspost/compress/pull/400)
* gzhttp: Make content-type optional [#510](https://github.com/klauspost/compress/pull/510)
Both compression and decompression now supports "synchronous" stream operations. This means that whenever "concurrency" is set to 1, they will operate without spawning goroutines.
@ -258,7 +284,7 @@ While the release has been extensively tested, it is recommended to testing when
* flate: Fix rare huffman only (-2) corruption. [#503](https://github.com/klauspost/compress/pull/503)
* zip: Update deprecated CreateHeaderRaw to correctly call CreateRaw by @saracen in [#502](https://github.com/klauspost/compress/pull/502)
* zip: don't read data descriptor early by @saracen in [#501](https://github.com/klauspost/compress/pull/501) #501
* huff0: Use static decompression buffer up to 30% faster by @klauspost in [#499](https://github.com/klauspost/compress/pull/499) [#500](https://github.com/klauspost/compress/pull/500)
* huff0: Use static decompression buffer up to 30% faster [#499](https://github.com/klauspost/compress/pull/499) [#500](https://github.com/klauspost/compress/pull/500)
* Feb 17, 2022 (v1.14.3)
* flate: Improve fastest levels compression speed ~10% more throughput. [#482](https://github.com/klauspost/compress/pull/482) [#489](https://github.com/klauspost/compress/pull/489) [#490](https://github.com/klauspost/compress/pull/490) [#491](https://github.com/klauspost/compress/pull/491) [#494](https://github.com/klauspost/compress/pull/494) [#478](https://github.com/klauspost/compress/pull/478)
@ -565,12 +591,14 @@ While the release has been extensively tested, it is recommended to testing when
The packages are drop-in replacements for standard libraries. Simply replace the import path to use them:
| old import | new import | Documentation
|--------------------|-----------------------------------------|--------------------|
| `compress/gzip` | `github.com/klauspost/compress/gzip` | [gzip](https://pkg.go.dev/github.com/klauspost/compress/gzip?tab=doc)
| `compress/zlib` | `github.com/klauspost/compress/zlib` | [zlib](https://pkg.go.dev/github.com/klauspost/compress/zlib?tab=doc)
| `archive/zip` | `github.com/klauspost/compress/zip` | [zip](https://pkg.go.dev/github.com/klauspost/compress/zip?tab=doc)
| `compress/flate` | `github.com/klauspost/compress/flate` | [flate](https://pkg.go.dev/github.com/klauspost/compress/flate?tab=doc)
Typical speed is about 2x of the standard library packages.
| old import | new import | Documentation |
|------------------|---------------------------------------|-------------------------------------------------------------------------|
| `compress/gzip` | `github.com/klauspost/compress/gzip` | [gzip](https://pkg.go.dev/github.com/klauspost/compress/gzip?tab=doc) |
| `compress/zlib` | `github.com/klauspost/compress/zlib` | [zlib](https://pkg.go.dev/github.com/klauspost/compress/zlib?tab=doc) |
| `archive/zip` | `github.com/klauspost/compress/zip` | [zip](https://pkg.go.dev/github.com/klauspost/compress/zip?tab=doc) |
| `compress/flate` | `github.com/klauspost/compress/flate` | [flate](https://pkg.go.dev/github.com/klauspost/compress/flate?tab=doc) |
* Optimized [deflate](https://godoc.org/github.com/klauspost/compress/flate) packages which can be used as a dropin replacement for [gzip](https://godoc.org/github.com/klauspost/compress/gzip), [zip](https://godoc.org/github.com/klauspost/compress/zip) and [zlib](https://godoc.org/github.com/klauspost/compress/zlib).
@ -625,84 +653,6 @@ This will only use up to 4KB in memory when the writer is idle.
Compression is almost always worse than the fastest compression level
and each write will allocate (a little) memory.
# Performance Update 2018
It has been a while since we have been looking at the speed of this package compared to the standard library, so I thought I would re-do my tests and give some overall recommendations based on the current state. All benchmarks have been performed with Go 1.10 on my Desktop Intel(R) Core(TM) i7-2600 CPU @3.40GHz. Since I last ran the tests, I have gotten more RAM, which means tests with big files are no longer limited by my SSD.
The raw results are in my [updated spreadsheet](https://docs.google.com/spreadsheets/d/1nuNE2nPfuINCZJRMt6wFWhKpToF95I47XjSsc-1rbPQ/edit?usp=sharing). Due to cgo changes and upstream updates i could not get the cgo version of gzip to compile. Instead I included the [zstd](https://github.com/datadog/zstd) cgo implementation. If I get cgo gzip to work again, I might replace the results in the sheet.
The columns to take note of are: *MB/s* - the throughput. *Reduction* - the data size reduction in percent of the original. *Rel Speed* relative speed compared to the standard library at the same level. *Smaller* - how many percent smaller is the compressed output compared to stdlib. Negative means the output was bigger. *Loss* means the loss (or gain) in compression as a percentage difference of the input.
The `gzstd` (standard library gzip) and `gzkp` (this package gzip) only uses one CPU core. [`pgzip`](https://github.com/klauspost/pgzip), [`bgzf`](https://github.com/biogo/hts/tree/master/bgzf) uses all 4 cores. [`zstd`](https://github.com/DataDog/zstd) uses one core, and is a beast (but not Go, yet).
## Overall differences.
There appears to be a roughly 5-10% speed advantage over the standard library when comparing at similar compression levels.
The biggest difference you will see is the result of [re-balancing](https://blog.klauspost.com/rebalancing-deflate-compression-levels/) the compression levels. I wanted by library to give a smoother transition between the compression levels than the standard library.
This package attempts to provide a more smooth transition, where "1" is taking a lot of shortcuts, "5" is the reasonable trade-off and "9" is the "give me the best compression", and the values in between gives something reasonable in between. The standard library has big differences in levels 1-4, but levels 5-9 having no significant gains - often spending a lot more time than can be justified by the achieved compression.
There are links to all the test data in the [spreadsheet](https://docs.google.com/spreadsheets/d/1nuNE2nPfuINCZJRMt6wFWhKpToF95I47XjSsc-1rbPQ/edit?usp=sharing) in the top left field on each tab.
## Web Content
This test set aims to emulate typical use in a web server. The test-set is 4GB data in 53k files, and is a mixture of (mostly) HTML, JS, CSS.
Since level 1 and 9 are close to being the same code, they are quite close. But looking at the levels in-between the differences are quite big.
Looking at level 6, this package is 88% faster, but will output about 6% more data. For a web server, this means you can serve 88% more data, but have to pay for 6% more bandwidth. You can draw your own conclusions on what would be the most expensive for your case.
## Object files
This test is for typical data files stored on a server. In this case it is a collection of Go precompiled objects. They are very compressible.
The picture is similar to the web content, but with small differences since this is very compressible. Levels 2-3 offer good speed, but is sacrificing quite a bit of compression.
The standard library seems suboptimal on level 3 and 4 - offering both worse compression and speed than level 6 & 7 of this package respectively.
## Highly Compressible File
This is a JSON file with very high redundancy. The reduction starts at 95% on level 1, so in real life terms we are dealing with something like a highly redundant stream of data, etc.
It is definitely visible that we are dealing with specialized content here, so the results are very scattered. This package does not do very well at levels 1-4, but picks up significantly at level 5 and levels 7 and 8 offering great speed for the achieved compression.
So if you know you content is extremely compressible you might want to go slightly higher than the defaults. The standard library has a huge gap between levels 3 and 4 in terms of speed (2.75x slowdown), so it offers little "middle ground".
## Medium-High Compressible
This is a pretty common test corpus: [enwik9](http://mattmahoney.net/dc/textdata.html). It contains the first 10^9 bytes of the English Wikipedia dump on Mar. 3, 2006. This is a very good test of typical text based compression and more data heavy streams.
We see a similar picture here as in "Web Content". On equal levels some compression is sacrificed for more speed. Level 5 seems to be the best trade-off between speed and size, beating stdlib level 3 in both.
## Medium Compressible
I will combine two test sets, one [10GB file set](http://mattmahoney.net/dc/10gb.html) and a VM disk image (~8GB). Both contain different data types and represent a typical backup scenario.
The most notable thing is how quickly the standard library drops to very low compression speeds around level 5-6 without any big gains in compression. Since this type of data is fairly common, this does not seem like good behavior.
## Un-compressible Content
This is mainly a test of how good the algorithms are at detecting un-compressible input. The standard library only offers this feature with very conservative settings at level 1. Obviously there is no reason for the algorithms to try to compress input that cannot be compressed. The only downside is that it might skip some compressible data on false detections.
## Huffman only compression
This compression library adds a special compression level, named `HuffmanOnly`, which allows near linear time compression. This is done by completely disabling matching of previous data, and only reduce the number of bits to represent each character.
This means that often used characters, like 'e' and ' ' (space) in text use the fewest bits to represent, and rare characters like '¤' takes more bits to represent. For more information see [wikipedia](https://en.wikipedia.org/wiki/Huffman_coding) or this nice [video](https://youtu.be/ZdooBTdW5bM).
Since this type of compression has much less variance, the compression speed is mostly unaffected by the input data, and is usually more than *180MB/s* for a single core.
The downside is that the compression ratio is usually considerably worse than even the fastest conventional compression. The compression ratio can never be better than 8:1 (12.5%).
The linear time compression can be used as a "better than nothing" mode, where you cannot risk the encoder to slow down on some content. For comparison, the size of the "Twain" text is *233460 bytes* (+29% vs. level 1) and encode speed is 144MB/s (4.5x level 1). So in this case you trade a 30% size increase for a 4 times speedup.
For more information see my blog post on [Fast Linear Time Compression](http://blog.klauspost.com/constant-time-gzipzip-compression/).
This is implemented on Go 1.7 as "Huffman Only" mode, though not exposed for gzip.
# Other packages

View File

@ -6,8 +6,10 @@
package flate
import (
"encoding/binary"
"fmt"
"math/bits"
"github.com/klauspost/compress/internal/le"
)
type fastEnc interface {
@ -58,11 +60,11 @@ const (
)
func load3232(b []byte, i int32) uint32 {
return binary.LittleEndian.Uint32(b[i:])
return le.Load32(b, i)
}
func load6432(b []byte, i int32) uint64 {
return binary.LittleEndian.Uint64(b[i:])
return le.Load64(b, i)
}
type tableEntry struct {
@ -134,8 +136,8 @@ func hashLen(u uint64, length, mls uint8) uint32 {
// matchlen will return the match length between offsets and t in src.
// The maximum length returned is maxMatchLength - 4.
// It is assumed that s > t, that t >=0 and s < len(src).
func (e *fastGen) matchlen(s, t int32, src []byte) int32 {
if debugDecode {
func (e *fastGen) matchlen(s, t int, src []byte) int32 {
if debugDeflate {
if t >= s {
panic(fmt.Sprint("t >=s:", t, s))
}
@ -149,18 +151,34 @@ func (e *fastGen) matchlen(s, t int32, src []byte) int32 {
panic(fmt.Sprint(s, "-", t, "(", s-t, ") > maxMatchLength (", maxMatchOffset, ")"))
}
}
s1 := int(s) + maxMatchLength - 4
if s1 > len(src) {
s1 = len(src)
s1 := min(s+maxMatchLength-4, len(src))
left := s1 - s
n := int32(0)
for left >= 8 {
diff := le.Load64(src, s) ^ le.Load64(src, t)
if diff != 0 {
return n + int32(bits.TrailingZeros64(diff)>>3)
}
s += 8
t += 8
n += 8
left -= 8
}
// Extend the match to be as long as possible.
return int32(matchLen(src[s:s1], src[t:]))
a := src[s:s1]
b := src[t:]
for i := range a {
if a[i] != b[i] {
break
}
n++
}
return n
}
// matchlenLong will return the match length between offsets and t in src.
// It is assumed that s > t, that t >=0 and s < len(src).
func (e *fastGen) matchlenLong(s, t int32, src []byte) int32 {
func (e *fastGen) matchlenLong(s, t int, src []byte) int32 {
if debugDeflate {
if t >= s {
panic(fmt.Sprint("t >=s:", t, s))
@ -176,7 +194,28 @@ func (e *fastGen) matchlenLong(s, t int32, src []byte) int32 {
}
}
// Extend the match to be as long as possible.
return int32(matchLen(src[s:], src[t:]))
left := len(src) - s
n := int32(0)
for left >= 8 {
diff := le.Load64(src, s) ^ le.Load64(src, t)
if diff != 0 {
return n + int32(bits.TrailingZeros64(diff)>>3)
}
s += 8
t += 8
n += 8
left -= 8
}
a := src[s:]
b := src[t:]
for i := range a {
if a[i] != b[i] {
break
}
n++
}
return n
}
// Reset the encoding table.

View File

@ -5,10 +5,11 @@
package flate
import (
"encoding/binary"
"fmt"
"io"
"math"
"github.com/klauspost/compress/internal/le"
)
const (
@ -438,7 +439,7 @@ func (w *huffmanBitWriter) writeOutBits() {
n := w.nbytes
// We over-write, but faster...
binary.LittleEndian.PutUint64(w.bytes[n:], bits)
le.Store64(w.bytes[n:], bits)
n += 6
if n >= bufferFlushSize {
@ -854,7 +855,7 @@ func (w *huffmanBitWriter) writeTokens(tokens []token, leCodes, oeCodes []hcode)
bits |= c.code64() << (nbits & 63)
nbits += c.len()
if nbits >= 48 {
binary.LittleEndian.PutUint64(w.bytes[nbytes:], bits)
le.Store64(w.bytes[nbytes:], bits)
//*(*uint64)(unsafe.Pointer(&w.bytes[nbytes])) = bits
bits >>= 48
nbits -= 48
@ -882,7 +883,7 @@ func (w *huffmanBitWriter) writeTokens(tokens []token, leCodes, oeCodes []hcode)
bits |= c.code64() << (nbits & 63)
nbits += c.len()
if nbits >= 48 {
binary.LittleEndian.PutUint64(w.bytes[nbytes:], bits)
le.Store64(w.bytes[nbytes:], bits)
//*(*uint64)(unsafe.Pointer(&w.bytes[nbytes])) = bits
bits >>= 48
nbits -= 48
@ -905,7 +906,7 @@ func (w *huffmanBitWriter) writeTokens(tokens []token, leCodes, oeCodes []hcode)
bits |= uint64(extraLength) << (nbits & 63)
nbits += extraLengthBits
if nbits >= 48 {
binary.LittleEndian.PutUint64(w.bytes[nbytes:], bits)
le.Store64(w.bytes[nbytes:], bits)
//*(*uint64)(unsafe.Pointer(&w.bytes[nbytes])) = bits
bits >>= 48
nbits -= 48
@ -931,7 +932,7 @@ func (w *huffmanBitWriter) writeTokens(tokens []token, leCodes, oeCodes []hcode)
bits |= c.code64() << (nbits & 63)
nbits += c.len()
if nbits >= 48 {
binary.LittleEndian.PutUint64(w.bytes[nbytes:], bits)
le.Store64(w.bytes[nbytes:], bits)
//*(*uint64)(unsafe.Pointer(&w.bytes[nbytes])) = bits
bits >>= 48
nbits -= 48
@ -953,7 +954,7 @@ func (w *huffmanBitWriter) writeTokens(tokens []token, leCodes, oeCodes []hcode)
bits |= uint64((offset-(offsetComb>>8))&matchOffsetOnlyMask) << (nbits & 63)
nbits += uint8(offsetComb)
if nbits >= 48 {
binary.LittleEndian.PutUint64(w.bytes[nbytes:], bits)
le.Store64(w.bytes[nbytes:], bits)
//*(*uint64)(unsafe.Pointer(&w.bytes[nbytes])) = bits
bits >>= 48
nbits -= 48
@ -1107,7 +1108,7 @@ func (w *huffmanBitWriter) writeBlockHuff(eof bool, input []byte, sync bool) {
// We must have at least 48 bits free.
if nbits >= 8 {
n := nbits >> 3
binary.LittleEndian.PutUint64(w.bytes[nbytes:], bits)
le.Store64(w.bytes[nbytes:], bits)
bits >>= (n * 8) & 63
nbits -= n * 8
nbytes += n
@ -1136,7 +1137,7 @@ func (w *huffmanBitWriter) writeBlockHuff(eof bool, input []byte, sync bool) {
// Remaining...
for _, t := range input {
if nbits >= 48 {
binary.LittleEndian.PutUint64(w.bytes[nbytes:], bits)
le.Store64(w.bytes[nbytes:], bits)
//*(*uint64)(unsafe.Pointer(&w.bytes[nbytes])) = bits
bits >>= 48
nbits -= 48

View File

@ -1,9 +1,9 @@
package flate
import (
"encoding/binary"
"fmt"
"math/bits"
"github.com/klauspost/compress/internal/le"
)
// fastGen maintains the table for matches,
@ -77,6 +77,7 @@ func (e *fastEncL1) Encode(dst *tokens, src []byte) {
nextS := s
var candidate tableEntry
var t int32
for {
nextHash := hashLen(cv, tableBits, hashBytes)
candidate = e.table[nextHash]
@ -88,9 +89,8 @@ func (e *fastEncL1) Encode(dst *tokens, src []byte) {
now := load6432(src, nextS)
e.table[nextHash] = tableEntry{offset: s + e.cur}
nextHash = hashLen(now, tableBits, hashBytes)
offset := s - (candidate.offset - e.cur)
if offset < maxMatchOffset && uint32(cv) == load3232(src, candidate.offset-e.cur) {
t = candidate.offset - e.cur
if s-t < maxMatchOffset && uint32(cv) == load3232(src, t) {
e.table[nextHash] = tableEntry{offset: nextS + e.cur}
break
}
@ -103,8 +103,8 @@ func (e *fastEncL1) Encode(dst *tokens, src []byte) {
now >>= 8
e.table[nextHash] = tableEntry{offset: s + e.cur}
offset = s - (candidate.offset - e.cur)
if offset < maxMatchOffset && uint32(cv) == load3232(src, candidate.offset-e.cur) {
t = candidate.offset - e.cur
if s-t < maxMatchOffset && uint32(cv) == load3232(src, t) {
e.table[nextHash] = tableEntry{offset: nextS + e.cur}
break
}
@ -120,36 +120,10 @@ func (e *fastEncL1) Encode(dst *tokens, src []byte) {
// literal bytes prior to s.
// Extend the 4-byte match as long as possible.
t := candidate.offset - e.cur
var l = int32(4)
if false {
l = e.matchlenLong(s+4, t+4, src) + 4
} else {
// inlined:
a := src[s+4:]
b := src[t+4:]
for len(a) >= 8 {
if diff := binary.LittleEndian.Uint64(a) ^ binary.LittleEndian.Uint64(b); diff != 0 {
l += int32(bits.TrailingZeros64(diff) >> 3)
break
}
l += 8
a = a[8:]
b = b[8:]
}
if len(a) < 8 {
b = b[:len(a)]
for i := range a {
if a[i] != b[i] {
break
}
l++
}
}
}
l := e.matchlenLong(int(s+4), int(t+4), src) + 4
// Extend backwards
for t > 0 && s > nextEmit && src[t-1] == src[s-1] {
for t > 0 && s > nextEmit && le.Load8(src, t-1) == le.Load8(src, s-1) {
s--
t--
l++
@ -221,8 +195,8 @@ func (e *fastEncL1) Encode(dst *tokens, src []byte) {
candidate = e.table[currHash]
e.table[currHash] = tableEntry{offset: o + 2}
offset := s - (candidate.offset - e.cur)
if offset > maxMatchOffset || uint32(x) != load3232(src, candidate.offset-e.cur) {
t = candidate.offset - e.cur
if s-t > maxMatchOffset || uint32(x) != load3232(src, t) {
cv = x >> 8
s++
break

View File

@ -126,7 +126,7 @@ func (e *fastEncL2) Encode(dst *tokens, src []byte) {
// Extend the 4-byte match as long as possible.
t := candidate.offset - e.cur
l := e.matchlenLong(s+4, t+4, src) + 4
l := e.matchlenLong(int(s+4), int(t+4), src) + 4
// Extend backwards
for t > 0 && s > nextEmit && src[t-1] == src[s-1] {

View File

@ -135,7 +135,7 @@ func (e *fastEncL3) Encode(dst *tokens, src []byte) {
// Extend the 4-byte match as long as possible.
//
t := candidate.offset - e.cur
l := e.matchlenLong(s+4, t+4, src) + 4
l := e.matchlenLong(int(s+4), int(t+4), src) + 4
// Extend backwards
for t > 0 && s > nextEmit && src[t-1] == src[s-1] {

View File

@ -98,19 +98,19 @@ func (e *fastEncL4) Encode(dst *tokens, src []byte) {
e.bTable[nextHashL] = entry
t = lCandidate.offset - e.cur
if s-t < maxMatchOffset && uint32(cv) == load3232(src, lCandidate.offset-e.cur) {
if s-t < maxMatchOffset && uint32(cv) == load3232(src, t) {
// We got a long match. Use that.
break
}
t = sCandidate.offset - e.cur
if s-t < maxMatchOffset && uint32(cv) == load3232(src, sCandidate.offset-e.cur) {
if s-t < maxMatchOffset && uint32(cv) == load3232(src, t) {
// Found a 4 match...
lCandidate = e.bTable[hash7(next, tableBits)]
// If the next long is a candidate, check if we should use that instead...
lOff := nextS - (lCandidate.offset - e.cur)
if lOff < maxMatchOffset && load3232(src, lCandidate.offset-e.cur) == uint32(next) {
lOff := lCandidate.offset - e.cur
if nextS-lOff < maxMatchOffset && load3232(src, lOff) == uint32(next) {
l1, l2 := matchLen(src[s+4:], src[t+4:]), matchLen(src[nextS+4:], src[nextS-lOff+4:])
if l2 > l1 {
s = nextS
@ -127,7 +127,7 @@ func (e *fastEncL4) Encode(dst *tokens, src []byte) {
// them as literal bytes.
// Extend the 4-byte match as long as possible.
l := e.matchlenLong(s+4, t+4, src) + 4
l := e.matchlenLong(int(s+4), int(t+4), src) + 4
// Extend backwards
for t > 0 && s > nextEmit && src[t-1] == src[s-1] {

View File

@ -111,16 +111,16 @@ func (e *fastEncL5) Encode(dst *tokens, src []byte) {
t = lCandidate.Cur.offset - e.cur
if s-t < maxMatchOffset {
if uint32(cv) == load3232(src, lCandidate.Cur.offset-e.cur) {
if uint32(cv) == load3232(src, t) {
// Store the next match
e.table[nextHashS] = tableEntry{offset: nextS + e.cur}
eLong := &e.bTable[nextHashL]
eLong.Cur, eLong.Prev = tableEntry{offset: nextS + e.cur}, eLong.Cur
t2 := lCandidate.Prev.offset - e.cur
if s-t2 < maxMatchOffset && uint32(cv) == load3232(src, lCandidate.Prev.offset-e.cur) {
l = e.matchlen(s+4, t+4, src) + 4
ml1 := e.matchlen(s+4, t2+4, src) + 4
if s-t2 < maxMatchOffset && uint32(cv) == load3232(src, t2) {
l = e.matchlen(int(s+4), int(t+4), src) + 4
ml1 := e.matchlen(int(s+4), int(t2+4), src) + 4
if ml1 > l {
t = t2
l = ml1
@ -130,7 +130,7 @@ func (e *fastEncL5) Encode(dst *tokens, src []byte) {
break
}
t = lCandidate.Prev.offset - e.cur
if s-t < maxMatchOffset && uint32(cv) == load3232(src, lCandidate.Prev.offset-e.cur) {
if s-t < maxMatchOffset && uint32(cv) == load3232(src, t) {
// Store the next match
e.table[nextHashS] = tableEntry{offset: nextS + e.cur}
eLong := &e.bTable[nextHashL]
@ -140,9 +140,9 @@ func (e *fastEncL5) Encode(dst *tokens, src []byte) {
}
t = sCandidate.offset - e.cur
if s-t < maxMatchOffset && uint32(cv) == load3232(src, sCandidate.offset-e.cur) {
if s-t < maxMatchOffset && uint32(cv) == load3232(src, t) {
// Found a 4 match...
l = e.matchlen(s+4, t+4, src) + 4
l = e.matchlen(int(s+4), int(t+4), src) + 4
lCandidate = e.bTable[nextHashL]
// Store the next match
@ -153,8 +153,8 @@ func (e *fastEncL5) Encode(dst *tokens, src []byte) {
// If the next long is a candidate, use that...
t2 := lCandidate.Cur.offset - e.cur
if nextS-t2 < maxMatchOffset {
if load3232(src, lCandidate.Cur.offset-e.cur) == uint32(next) {
ml := e.matchlen(nextS+4, t2+4, src) + 4
if load3232(src, t2) == uint32(next) {
ml := e.matchlen(int(nextS+4), int(t2+4), src) + 4
if ml > l {
t = t2
s = nextS
@ -164,8 +164,8 @@ func (e *fastEncL5) Encode(dst *tokens, src []byte) {
}
// If the previous long is a candidate, use that...
t2 = lCandidate.Prev.offset - e.cur
if nextS-t2 < maxMatchOffset && load3232(src, lCandidate.Prev.offset-e.cur) == uint32(next) {
ml := e.matchlen(nextS+4, t2+4, src) + 4
if nextS-t2 < maxMatchOffset && load3232(src, t2) == uint32(next) {
ml := e.matchlen(int(nextS+4), int(t2+4), src) + 4
if ml > l {
t = t2
s = nextS
@ -185,9 +185,9 @@ func (e *fastEncL5) Encode(dst *tokens, src []byte) {
if l == 0 {
// Extend the 4-byte match as long as possible.
l = e.matchlenLong(s+4, t+4, src) + 4
l = e.matchlenLong(int(s+4), int(t+4), src) + 4
} else if l == maxMatchLength {
l += e.matchlenLong(s+l, t+l, src)
l += e.matchlenLong(int(s+l), int(t+l), src)
}
// Try to locate a better match by checking the end of best match...
@ -203,7 +203,7 @@ func (e *fastEncL5) Encode(dst *tokens, src []byte) {
s2 := s + skipBeginning
off := s2 - t2
if t2 >= 0 && off < maxMatchOffset && off > 0 {
if l2 := e.matchlenLong(s2, t2, src); l2 > l {
if l2 := e.matchlenLong(int(s2), int(t2), src); l2 > l {
t = t2
l = l2
s = s2
@ -423,14 +423,14 @@ func (e *fastEncL5Window) Encode(dst *tokens, src []byte) {
t = lCandidate.Cur.offset - e.cur
if s-t < maxMatchOffset {
if uint32(cv) == load3232(src, lCandidate.Cur.offset-e.cur) {
if uint32(cv) == load3232(src, t) {
// Store the next match
e.table[nextHashS] = tableEntry{offset: nextS + e.cur}
eLong := &e.bTable[nextHashL]
eLong.Cur, eLong.Prev = tableEntry{offset: nextS + e.cur}, eLong.Cur
t2 := lCandidate.Prev.offset - e.cur
if s-t2 < maxMatchOffset && uint32(cv) == load3232(src, lCandidate.Prev.offset-e.cur) {
if s-t2 < maxMatchOffset && uint32(cv) == load3232(src, t2) {
l = e.matchlen(s+4, t+4, src) + 4
ml1 := e.matchlen(s+4, t2+4, src) + 4
if ml1 > l {
@ -442,7 +442,7 @@ func (e *fastEncL5Window) Encode(dst *tokens, src []byte) {
break
}
t = lCandidate.Prev.offset - e.cur
if s-t < maxMatchOffset && uint32(cv) == load3232(src, lCandidate.Prev.offset-e.cur) {
if s-t < maxMatchOffset && uint32(cv) == load3232(src, t) {
// Store the next match
e.table[nextHashS] = tableEntry{offset: nextS + e.cur}
eLong := &e.bTable[nextHashL]
@ -452,7 +452,7 @@ func (e *fastEncL5Window) Encode(dst *tokens, src []byte) {
}
t = sCandidate.offset - e.cur
if s-t < maxMatchOffset && uint32(cv) == load3232(src, sCandidate.offset-e.cur) {
if s-t < maxMatchOffset && uint32(cv) == load3232(src, t) {
// Found a 4 match...
l = e.matchlen(s+4, t+4, src) + 4
lCandidate = e.bTable[nextHashL]
@ -465,7 +465,7 @@ func (e *fastEncL5Window) Encode(dst *tokens, src []byte) {
// If the next long is a candidate, use that...
t2 := lCandidate.Cur.offset - e.cur
if nextS-t2 < maxMatchOffset {
if load3232(src, lCandidate.Cur.offset-e.cur) == uint32(next) {
if load3232(src, t2) == uint32(next) {
ml := e.matchlen(nextS+4, t2+4, src) + 4
if ml > l {
t = t2
@ -476,7 +476,7 @@ func (e *fastEncL5Window) Encode(dst *tokens, src []byte) {
}
// If the previous long is a candidate, use that...
t2 = lCandidate.Prev.offset - e.cur
if nextS-t2 < maxMatchOffset && load3232(src, lCandidate.Prev.offset-e.cur) == uint32(next) {
if nextS-t2 < maxMatchOffset && load3232(src, t2) == uint32(next) {
ml := e.matchlen(nextS+4, t2+4, src) + 4
if ml > l {
t = t2

View File

@ -113,7 +113,7 @@ func (e *fastEncL6) Encode(dst *tokens, src []byte) {
t = lCandidate.Cur.offset - e.cur
if s-t < maxMatchOffset {
if uint32(cv) == load3232(src, lCandidate.Cur.offset-e.cur) {
if uint32(cv) == load3232(src, t) {
// Long candidate matches at least 4 bytes.
// Store the next match
@ -123,9 +123,9 @@ func (e *fastEncL6) Encode(dst *tokens, src []byte) {
// Check the previous long candidate as well.
t2 := lCandidate.Prev.offset - e.cur
if s-t2 < maxMatchOffset && uint32(cv) == load3232(src, lCandidate.Prev.offset-e.cur) {
l = e.matchlen(s+4, t+4, src) + 4
ml1 := e.matchlen(s+4, t2+4, src) + 4
if s-t2 < maxMatchOffset && uint32(cv) == load3232(src, t2) {
l = e.matchlen(int(s+4), int(t+4), src) + 4
ml1 := e.matchlen(int(s+4), int(t2+4), src) + 4
if ml1 > l {
t = t2
l = ml1
@ -136,7 +136,7 @@ func (e *fastEncL6) Encode(dst *tokens, src []byte) {
}
// Current value did not match, but check if previous long value does.
t = lCandidate.Prev.offset - e.cur
if s-t < maxMatchOffset && uint32(cv) == load3232(src, lCandidate.Prev.offset-e.cur) {
if s-t < maxMatchOffset && uint32(cv) == load3232(src, t) {
// Store the next match
e.table[nextHashS] = tableEntry{offset: nextS + e.cur}
eLong := &e.bTable[nextHashL]
@ -146,9 +146,9 @@ func (e *fastEncL6) Encode(dst *tokens, src []byte) {
}
t = sCandidate.offset - e.cur
if s-t < maxMatchOffset && uint32(cv) == load3232(src, sCandidate.offset-e.cur) {
if s-t < maxMatchOffset && uint32(cv) == load3232(src, t) {
// Found a 4 match...
l = e.matchlen(s+4, t+4, src) + 4
l = e.matchlen(int(s+4), int(t+4), src) + 4
// Look up next long candidate (at nextS)
lCandidate = e.bTable[nextHashL]
@ -162,7 +162,7 @@ func (e *fastEncL6) Encode(dst *tokens, src []byte) {
const repOff = 1
t2 := s - repeat + repOff
if load3232(src, t2) == uint32(cv>>(8*repOff)) {
ml := e.matchlen(s+4+repOff, t2+4, src) + 4
ml := e.matchlen(int(s+4+repOff), int(t2+4), src) + 4
if ml > l {
t = t2
l = ml
@ -175,8 +175,8 @@ func (e *fastEncL6) Encode(dst *tokens, src []byte) {
// If the next long is a candidate, use that...
t2 = lCandidate.Cur.offset - e.cur
if nextS-t2 < maxMatchOffset {
if load3232(src, lCandidate.Cur.offset-e.cur) == uint32(next) {
ml := e.matchlen(nextS+4, t2+4, src) + 4
if load3232(src, t2) == uint32(next) {
ml := e.matchlen(int(nextS+4), int(t2+4), src) + 4
if ml > l {
t = t2
s = nextS
@ -186,8 +186,8 @@ func (e *fastEncL6) Encode(dst *tokens, src []byte) {
}
// If the previous long is a candidate, use that...
t2 = lCandidate.Prev.offset - e.cur
if nextS-t2 < maxMatchOffset && load3232(src, lCandidate.Prev.offset-e.cur) == uint32(next) {
ml := e.matchlen(nextS+4, t2+4, src) + 4
if nextS-t2 < maxMatchOffset && load3232(src, t2) == uint32(next) {
ml := e.matchlen(int(nextS+4), int(t2+4), src) + 4
if ml > l {
t = t2
s = nextS
@ -207,9 +207,9 @@ func (e *fastEncL6) Encode(dst *tokens, src []byte) {
// Extend the 4-byte match as long as possible.
if l == 0 {
l = e.matchlenLong(s+4, t+4, src) + 4
l = e.matchlenLong(int(s+4), int(t+4), src) + 4
} else if l == maxMatchLength {
l += e.matchlenLong(s+l, t+l, src)
l += e.matchlenLong(int(s+l), int(t+l), src)
}
// Try to locate a better match by checking the end-of-match...
@ -227,7 +227,7 @@ func (e *fastEncL6) Encode(dst *tokens, src []byte) {
off := s2 - t2
if off < maxMatchOffset {
if off > 0 && t2 >= 0 {
if l2 := e.matchlenLong(s2, t2, src); l2 > l {
if l2 := e.matchlenLong(int(s2), int(t2), src); l2 > l {
t = t2
l = l2
s = s2
@ -237,7 +237,7 @@ func (e *fastEncL6) Encode(dst *tokens, src []byte) {
t2 = eLong.Prev.offset - e.cur - l + skipBeginning
off := s2 - t2
if off > 0 && off < maxMatchOffset && t2 >= 0 {
if l2 := e.matchlenLong(s2, t2, src); l2 > l {
if l2 := e.matchlenLong(int(s2), int(t2), src); l2 > l {
t = t2
l = l2
s = s2

View File

@ -1,16 +0,0 @@
//go:build amd64 && !appengine && !noasm && gc
// +build amd64,!appengine,!noasm,gc
// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
package flate
// matchLen returns how many bytes match in a and b
//
// It assumes that:
//
// len(a) <= len(b) and len(a) > 0
//
//go:noescape
func matchLen(a []byte, b []byte) int

View File

@ -1,66 +0,0 @@
// Copied from S2 implementation.
//go:build !appengine && !noasm && gc && !noasm
#include "textflag.h"
// func matchLen(a []byte, b []byte) int
TEXT ·matchLen(SB), NOSPLIT, $0-56
MOVQ a_base+0(FP), AX
MOVQ b_base+24(FP), CX
MOVQ a_len+8(FP), DX
// matchLen
XORL SI, SI
CMPL DX, $0x08
JB matchlen_match4_standalone
matchlen_loopback_standalone:
MOVQ (AX)(SI*1), BX
XORQ (CX)(SI*1), BX
JZ matchlen_loop_standalone
#ifdef GOAMD64_v3
TZCNTQ BX, BX
#else
BSFQ BX, BX
#endif
SHRL $0x03, BX
LEAL (SI)(BX*1), SI
JMP gen_match_len_end
matchlen_loop_standalone:
LEAL -8(DX), DX
LEAL 8(SI), SI
CMPL DX, $0x08
JAE matchlen_loopback_standalone
matchlen_match4_standalone:
CMPL DX, $0x04
JB matchlen_match2_standalone
MOVL (AX)(SI*1), BX
CMPL (CX)(SI*1), BX
JNE matchlen_match2_standalone
LEAL -4(DX), DX
LEAL 4(SI), SI
matchlen_match2_standalone:
CMPL DX, $0x02
JB matchlen_match1_standalone
MOVW (AX)(SI*1), BX
CMPW (CX)(SI*1), BX
JNE matchlen_match1_standalone
LEAL -2(DX), DX
LEAL 2(SI), SI
matchlen_match1_standalone:
CMPL DX, $0x01
JB gen_match_len_end
MOVB (AX)(SI*1), BL
CMPB (CX)(SI*1), BL
JNE gen_match_len_end
INCL SI
gen_match_len_end:
MOVQ SI, ret+48(FP)
RET

View File

@ -1,27 +1,29 @@
//go:build !amd64 || appengine || !gc || noasm
// +build !amd64 appengine !gc noasm
// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
package flate
import (
"encoding/binary"
"math/bits"
"github.com/klauspost/compress/internal/le"
)
// matchLen returns the maximum common prefix length of a and b.
// a must be the shortest of the two.
func matchLen(a, b []byte) (n int) {
for ; len(a) >= 8 && len(b) >= 8; a, b = a[8:], b[8:] {
diff := binary.LittleEndian.Uint64(a) ^ binary.LittleEndian.Uint64(b)
left := len(a)
for left >= 8 {
diff := le.Load64(a, n) ^ le.Load64(b, n)
if diff != 0 {
return n + bits.TrailingZeros64(diff)>>3
}
n += 8
left -= 8
}
a = a[n:]
b = b[n:]
for i := range a {
if a[i] != b[i] {
break
@ -29,5 +31,4 @@ func matchLen(a, b []byte) (n int) {
n++
}
return n
}

View File

@ -4,6 +4,8 @@ import (
"io"
"math"
"sync"
"github.com/klauspost/compress/internal/le"
)
const (
@ -152,18 +154,11 @@ func hashSL(u uint32) uint32 {
}
func load3216(b []byte, i int16) uint32 {
// Help the compiler eliminate bounds checks on the read so it can be done in a single read.
b = b[i:]
b = b[:4]
return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
return le.Load32(b, i)
}
func load6416(b []byte, i int16) uint64 {
// Help the compiler eliminate bounds checks on the read so it can be done in a single read.
b = b[i:]
b = b[:8]
return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 |
uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
return le.Load64(b, i)
}
func statelessEnc(dst *tokens, src []byte, startAt int16) {

View File

@ -6,10 +6,11 @@
package huff0
import (
"encoding/binary"
"errors"
"fmt"
"io"
"github.com/klauspost/compress/internal/le"
)
// bitReader reads a bitstream in reverse.
@ -46,7 +47,7 @@ func (b *bitReaderBytes) init(in []byte) error {
return nil
}
// peekBitsFast requires that at least one bit is requested every time.
// peekByteFast requires that at least one byte is requested every time.
// There are no checks if the buffer is filled.
func (b *bitReaderBytes) peekByteFast() uint8 {
got := uint8(b.value >> 56)
@ -66,8 +67,7 @@ func (b *bitReaderBytes) fillFast() {
}
// 2 bounds checks.
v := b.in[b.off-4 : b.off]
low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
low := le.Load32(b.in, b.off-4)
b.value |= uint64(low) << (b.bitsRead - 32)
b.bitsRead -= 32
b.off -= 4
@ -76,7 +76,7 @@ func (b *bitReaderBytes) fillFast() {
// fillFastStart() assumes the bitReaderBytes is empty and there is at least 8 bytes to read.
func (b *bitReaderBytes) fillFastStart() {
// Do single re-slice to avoid bounds checks.
b.value = binary.LittleEndian.Uint64(b.in[b.off-8:])
b.value = le.Load64(b.in, b.off-8)
b.bitsRead = 0
b.off -= 8
}
@ -86,9 +86,8 @@ func (b *bitReaderBytes) fill() {
if b.bitsRead < 32 {
return
}
if b.off > 4 {
v := b.in[b.off-4 : b.off]
low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
if b.off >= 4 {
low := le.Load32(b.in, b.off-4)
b.value |= uint64(low) << (b.bitsRead - 32)
b.bitsRead -= 32
b.off -= 4
@ -175,9 +174,7 @@ func (b *bitReaderShifted) fillFast() {
return
}
// 2 bounds checks.
v := b.in[b.off-4 : b.off]
low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
low := le.Load32(b.in, b.off-4)
b.value |= uint64(low) << ((b.bitsRead - 32) & 63)
b.bitsRead -= 32
b.off -= 4
@ -185,8 +182,7 @@ func (b *bitReaderShifted) fillFast() {
// fillFastStart() assumes the bitReaderShifted is empty and there is at least 8 bytes to read.
func (b *bitReaderShifted) fillFastStart() {
// Do single re-slice to avoid bounds checks.
b.value = binary.LittleEndian.Uint64(b.in[b.off-8:])
b.value = le.Load64(b.in, b.off-8)
b.bitsRead = 0
b.off -= 8
}
@ -197,8 +193,7 @@ func (b *bitReaderShifted) fill() {
return
}
if b.off > 4 {
v := b.in[b.off-4 : b.off]
low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
low := le.Load32(b.in, b.off-4)
b.value |= uint64(low) << ((b.bitsRead - 32) & 63)
b.bitsRead -= 32
b.off -= 4

View File

@ -0,0 +1,5 @@
package le
type Indexer interface {
int | int8 | int16 | int32 | int64 | uint | uint8 | uint16 | uint32 | uint64
}

View File

@ -0,0 +1,42 @@
//go:build !(amd64 || arm64 || ppc64le || riscv64) || nounsafe || purego || appengine
package le
import (
"encoding/binary"
)
// Load8 will load from b at index i.
func Load8[I Indexer](b []byte, i I) byte {
return b[i]
}
// Load16 will load from b at index i.
func Load16[I Indexer](b []byte, i I) uint16 {
return binary.LittleEndian.Uint16(b[i:])
}
// Load32 will load from b at index i.
func Load32[I Indexer](b []byte, i I) uint32 {
return binary.LittleEndian.Uint32(b[i:])
}
// Load64 will load from b at index i.
func Load64[I Indexer](b []byte, i I) uint64 {
return binary.LittleEndian.Uint64(b[i:])
}
// Store16 will store v at b.
func Store16(b []byte, v uint16) {
binary.LittleEndian.PutUint16(b, v)
}
// Store32 will store v at b.
func Store32(b []byte, v uint32) {
binary.LittleEndian.PutUint32(b, v)
}
// Store64 will store v at b.
func Store64(b []byte, v uint64) {
binary.LittleEndian.PutUint64(b, v)
}

View File

@ -0,0 +1,55 @@
// We enable 64 bit LE platforms:
//go:build (amd64 || arm64 || ppc64le || riscv64) && !nounsafe && !purego && !appengine
package le
import (
"unsafe"
)
// Load8 will load from b at index i.
func Load8[I Indexer](b []byte, i I) byte {
//return binary.LittleEndian.Uint16(b[i:])
//return *(*uint16)(unsafe.Pointer(&b[i]))
return *(*byte)(unsafe.Add(unsafe.Pointer(unsafe.SliceData(b)), i))
}
// Load16 will load from b at index i.
func Load16[I Indexer](b []byte, i I) uint16 {
//return binary.LittleEndian.Uint16(b[i:])
//return *(*uint16)(unsafe.Pointer(&b[i]))
return *(*uint16)(unsafe.Add(unsafe.Pointer(unsafe.SliceData(b)), i))
}
// Load32 will load from b at index i.
func Load32[I Indexer](b []byte, i I) uint32 {
//return binary.LittleEndian.Uint32(b[i:])
//return *(*uint32)(unsafe.Pointer(&b[i]))
return *(*uint32)(unsafe.Add(unsafe.Pointer(unsafe.SliceData(b)), i))
}
// Load64 will load from b at index i.
func Load64[I Indexer](b []byte, i I) uint64 {
//return binary.LittleEndian.Uint64(b[i:])
//return *(*uint64)(unsafe.Pointer(&b[i]))
return *(*uint64)(unsafe.Add(unsafe.Pointer(unsafe.SliceData(b)), i))
}
// Store16 will store v at b.
func Store16(b []byte, v uint16) {
//binary.LittleEndian.PutUint16(b, v)
*(*uint16)(unsafe.Pointer(unsafe.SliceData(b))) = v
}
// Store32 will store v at b.
func Store32(b []byte, v uint32) {
//binary.LittleEndian.PutUint32(b, v)
*(*uint32)(unsafe.Pointer(unsafe.SliceData(b))) = v
}
// Store64 will store v at b.
func Store64(b []byte, v uint64) {
//binary.LittleEndian.PutUint64(b, v)
*(*uint64)(unsafe.Pointer(unsafe.SliceData(b))) = v
}

View File

@ -79,7 +79,7 @@ This will take ownership of the buffer until the stream is closed.
func EncodeStream(src []byte, dst io.Writer) error {
enc := s2.NewWriter(dst)
// The encoder owns the buffer until Flush or Close is called.
err := enc.EncodeBuffer(buf)
err := enc.EncodeBuffer(src)
if err != nil {
enc.Close()
return err

View File

@ -11,6 +11,8 @@ package s2
import (
"fmt"
"strconv"
"github.com/klauspost/compress/internal/le"
)
// decode writes the decoding of src to dst. It assumes that the varint-encoded
@ -38,21 +40,18 @@ func s2Decode(dst, src []byte) int {
case x < 60:
s++
case x == 60:
x = uint32(src[s+1])
s += 2
x = uint32(src[s-1])
case x == 61:
in := src[s : s+3]
x = uint32(in[1]) | uint32(in[2])<<8
x = uint32(le.Load16(src, s+1))
s += 3
case x == 62:
in := src[s : s+4]
// Load as 32 bit and shift down.
x = uint32(in[0]) | uint32(in[1])<<8 | uint32(in[2])<<16 | uint32(in[3])<<24
x = le.Load32(src, s)
x >>= 8
s += 4
case x == 63:
in := src[s : s+5]
x = uint32(in[1]) | uint32(in[2])<<8 | uint32(in[3])<<16 | uint32(in[4])<<24
x = le.Load32(src, s+1)
s += 5
}
length = int(x) + 1
@ -85,8 +84,7 @@ func s2Decode(dst, src []byte) int {
length = int(src[s]) + 4
s += 1
case 6:
in := src[s : s+2]
length = int(uint32(in[0])|(uint32(in[1])<<8)) + (1 << 8)
length = int(le.Load16(src, s)) + 1<<8
s += 2
case 7:
in := src[s : s+3]
@ -99,15 +97,13 @@ func s2Decode(dst, src []byte) int {
}
length += 4
case tagCopy2:
in := src[s : s+3]
offset = int(uint32(in[1]) | uint32(in[2])<<8)
length = 1 + int(in[0])>>2
offset = int(le.Load16(src, s+1))
length = 1 + int(src[s])>>2
s += 3
case tagCopy4:
in := src[s : s+5]
offset = int(uint32(in[1]) | uint32(in[2])<<8 | uint32(in[3])<<16 | uint32(in[4])<<24)
length = 1 + int(in[0])>>2
offset = int(le.Load32(src, s+1))
length = 1 + int(src[s])>>2
s += 5
}

View File

@ -10,14 +10,16 @@ import (
"encoding/binary"
"fmt"
"math/bits"
"github.com/klauspost/compress/internal/le"
)
func load32(b []byte, i int) uint32 {
return binary.LittleEndian.Uint32(b[i:])
return le.Load32(b, i)
}
func load64(b []byte, i int) uint64 {
return binary.LittleEndian.Uint64(b[i:])
return le.Load64(b, i)
}
// hash6 returns the hash of the lowest 6 bytes of u to fit in a hash table with h bits.
@ -44,7 +46,12 @@ func encodeGo(dst, src []byte) []byte {
d += emitLiteral(dst[d:], src)
return dst[:d]
}
n := encodeBlockGo(dst[d:], src)
var n int
if len(src) < 64<<10 {
n = encodeBlockGo64K(dst[d:], src)
} else {
n = encodeBlockGo(dst[d:], src)
}
if n > 0 {
d += n
return dst[:d]
@ -70,7 +77,6 @@ func encodeBlockGo(dst, src []byte) (d int) {
debug = false
)
var table [maxTableSize]uint32
// sLimit is when to stop looking for offset/length copies. The inputMargin
@ -277,13 +283,228 @@ emitRemainder:
return d
}
// encodeBlockGo64K is a specialized version for compressing blocks <= 64KB
func encodeBlockGo64K(dst, src []byte) (d int) {
// Initialize the hash table.
const (
tableBits = 14
maxTableSize = 1 << tableBits
debug = false
)
var table [maxTableSize]uint16
// sLimit is when to stop looking for offset/length copies. The inputMargin
// lets us use a fast path for emitLiteral in the main loop, while we are
// looking for copies.
sLimit := len(src) - inputMargin
// Bail if we can't compress to at least this.
dstLimit := len(src) - len(src)>>5 - 5
// nextEmit is where in src the next emitLiteral should start from.
nextEmit := 0
// The encoded form must start with a literal, as there are no previous
// bytes to copy, so we start looking for hash matches at s == 1.
s := 1
cv := load64(src, s)
// We search for a repeat at -1, but don't output repeats when nextEmit == 0
repeat := 1
for {
candidate := 0
for {
// Next src position to check
nextS := s + (s-nextEmit)>>5 + 4
if nextS > sLimit {
goto emitRemainder
}
hash0 := hash6(cv, tableBits)
hash1 := hash6(cv>>8, tableBits)
candidate = int(table[hash0])
candidate2 := int(table[hash1])
table[hash0] = uint16(s)
table[hash1] = uint16(s + 1)
hash2 := hash6(cv>>16, tableBits)
// Check repeat at offset checkRep.
const checkRep = 1
if uint32(cv>>(checkRep*8)) == load32(src, s-repeat+checkRep) {
base := s + checkRep
// Extend back
for i := base - repeat; base > nextEmit && i > 0 && src[i-1] == src[base-1]; {
i--
base--
}
// Bail if we exceed the maximum size.
if d+(base-nextEmit) > dstLimit {
return 0
}
d += emitLiteral(dst[d:], src[nextEmit:base])
// Extend forward
candidate := s - repeat + 4 + checkRep
s += 4 + checkRep
for s <= sLimit {
if diff := load64(src, s) ^ load64(src, candidate); diff != 0 {
s += bits.TrailingZeros64(diff) >> 3
break
}
s += 8
candidate += 8
}
if debug {
// Validate match.
if s <= candidate {
panic("s <= candidate")
}
a := src[base:s]
b := src[base-repeat : base-repeat+(s-base)]
if !bytes.Equal(a, b) {
panic("mismatch")
}
}
if nextEmit > 0 {
// same as `add := emitCopy(dst[d:], repeat, s-base)` but skips storing offset.
d += emitRepeat(dst[d:], repeat, s-base)
} else {
// First match, cannot be repeat.
d += emitCopy(dst[d:], repeat, s-base)
}
nextEmit = s
if s >= sLimit {
goto emitRemainder
}
cv = load64(src, s)
continue
}
if uint32(cv) == load32(src, candidate) {
break
}
candidate = int(table[hash2])
if uint32(cv>>8) == load32(src, candidate2) {
table[hash2] = uint16(s + 2)
candidate = candidate2
s++
break
}
table[hash2] = uint16(s + 2)
if uint32(cv>>16) == load32(src, candidate) {
s += 2
break
}
cv = load64(src, nextS)
s = nextS
}
// Extend backwards.
// The top bytes will be rechecked to get the full match.
for candidate > 0 && s > nextEmit && src[candidate-1] == src[s-1] {
candidate--
s--
}
// Bail if we exceed the maximum size.
if d+(s-nextEmit) > dstLimit {
return 0
}
// A 4-byte match has been found. We'll later see if more than 4 bytes
// match. But, prior to the match, src[nextEmit:s] are unmatched. Emit
// them as literal bytes.
d += emitLiteral(dst[d:], src[nextEmit:s])
// Call emitCopy, and then see if another emitCopy could be our next
// move. Repeat until we find no match for the input immediately after
// what was consumed by the last emitCopy call.
//
// If we exit this loop normally then we need to call emitLiteral next,
// though we don't yet know how big the literal will be. We handle that
// by proceeding to the next iteration of the main loop. We also can
// exit this loop via goto if we get close to exhausting the input.
for {
// Invariant: we have a 4-byte match at s, and no need to emit any
// literal bytes prior to s.
base := s
repeat = base - candidate
// Extend the 4-byte match as long as possible.
s += 4
candidate += 4
for s <= len(src)-8 {
if diff := load64(src, s) ^ load64(src, candidate); diff != 0 {
s += bits.TrailingZeros64(diff) >> 3
break
}
s += 8
candidate += 8
}
d += emitCopy(dst[d:], repeat, s-base)
if debug {
// Validate match.
if s <= candidate {
panic("s <= candidate")
}
a := src[base:s]
b := src[base-repeat : base-repeat+(s-base)]
if !bytes.Equal(a, b) {
panic("mismatch")
}
}
nextEmit = s
if s >= sLimit {
goto emitRemainder
}
if d > dstLimit {
// Do we have space for more, if not bail.
return 0
}
// Check for an immediate match, otherwise start search at s+1
x := load64(src, s-2)
m2Hash := hash6(x, tableBits)
currHash := hash6(x>>16, tableBits)
candidate = int(table[currHash])
table[m2Hash] = uint16(s - 2)
table[currHash] = uint16(s)
if debug && s == candidate {
panic("s == candidate")
}
if uint32(x>>16) != load32(src, candidate) {
cv = load64(src, s+1)
s++
break
}
}
}
emitRemainder:
if nextEmit < len(src) {
// Bail if we exceed the maximum size.
if d+len(src)-nextEmit > dstLimit {
return 0
}
d += emitLiteral(dst[d:], src[nextEmit:])
}
return d
}
func encodeBlockSnappyGo(dst, src []byte) (d int) {
// Initialize the hash table.
const (
tableBits = 14
maxTableSize = 1 << tableBits
)
var table [maxTableSize]uint32
// sLimit is when to stop looking for offset/length copies. The inputMargin
@ -467,6 +688,197 @@ emitRemainder:
return d
}
// encodeBlockSnappyGo64K is a special version of encodeBlockSnappyGo for sizes <64KB
func encodeBlockSnappyGo64K(dst, src []byte) (d int) {
// Initialize the hash table.
const (
tableBits = 14
maxTableSize = 1 << tableBits
)
var table [maxTableSize]uint16
// sLimit is when to stop looking for offset/length copies. The inputMargin
// lets us use a fast path for emitLiteral in the main loop, while we are
// looking for copies.
sLimit := len(src) - inputMargin
// Bail if we can't compress to at least this.
dstLimit := len(src) - len(src)>>5 - 5
// nextEmit is where in src the next emitLiteral should start from.
nextEmit := 0
// The encoded form must start with a literal, as there are no previous
// bytes to copy, so we start looking for hash matches at s == 1.
s := 1
cv := load64(src, s)
// We search for a repeat at -1, but don't output repeats when nextEmit == 0
repeat := 1
for {
candidate := 0
for {
// Next src position to check
nextS := s + (s-nextEmit)>>5 + 4
if nextS > sLimit {
goto emitRemainder
}
hash0 := hash6(cv, tableBits)
hash1 := hash6(cv>>8, tableBits)
candidate = int(table[hash0])
candidate2 := int(table[hash1])
table[hash0] = uint16(s)
table[hash1] = uint16(s + 1)
hash2 := hash6(cv>>16, tableBits)
// Check repeat at offset checkRep.
const checkRep = 1
if uint32(cv>>(checkRep*8)) == load32(src, s-repeat+checkRep) {
base := s + checkRep
// Extend back
for i := base - repeat; base > nextEmit && i > 0 && src[i-1] == src[base-1]; {
i--
base--
}
// Bail if we exceed the maximum size.
if d+(base-nextEmit) > dstLimit {
return 0
}
d += emitLiteral(dst[d:], src[nextEmit:base])
// Extend forward
candidate := s - repeat + 4 + checkRep
s += 4 + checkRep
for s <= sLimit {
if diff := load64(src, s) ^ load64(src, candidate); diff != 0 {
s += bits.TrailingZeros64(diff) >> 3
break
}
s += 8
candidate += 8
}
d += emitCopyNoRepeat(dst[d:], repeat, s-base)
nextEmit = s
if s >= sLimit {
goto emitRemainder
}
cv = load64(src, s)
continue
}
if uint32(cv) == load32(src, candidate) {
break
}
candidate = int(table[hash2])
if uint32(cv>>8) == load32(src, candidate2) {
table[hash2] = uint16(s + 2)
candidate = candidate2
s++
break
}
table[hash2] = uint16(s + 2)
if uint32(cv>>16) == load32(src, candidate) {
s += 2
break
}
cv = load64(src, nextS)
s = nextS
}
// Extend backwards
for candidate > 0 && s > nextEmit && src[candidate-1] == src[s-1] {
candidate--
s--
}
// Bail if we exceed the maximum size.
if d+(s-nextEmit) > dstLimit {
return 0
}
// A 4-byte match has been found. We'll later see if more than 4 bytes
// match. But, prior to the match, src[nextEmit:s] are unmatched. Emit
// them as literal bytes.
d += emitLiteral(dst[d:], src[nextEmit:s])
// Call emitCopy, and then see if another emitCopy could be our next
// move. Repeat until we find no match for the input immediately after
// what was consumed by the last emitCopy call.
//
// If we exit this loop normally then we need to call emitLiteral next,
// though we don't yet know how big the literal will be. We handle that
// by proceeding to the next iteration of the main loop. We also can
// exit this loop via goto if we get close to exhausting the input.
for {
// Invariant: we have a 4-byte match at s, and no need to emit any
// literal bytes prior to s.
base := s
repeat = base - candidate
// Extend the 4-byte match as long as possible.
s += 4
candidate += 4
for s <= len(src)-8 {
if diff := load64(src, s) ^ load64(src, candidate); diff != 0 {
s += bits.TrailingZeros64(diff) >> 3
break
}
s += 8
candidate += 8
}
d += emitCopyNoRepeat(dst[d:], repeat, s-base)
if false {
// Validate match.
a := src[base:s]
b := src[base-repeat : base-repeat+(s-base)]
if !bytes.Equal(a, b) {
panic("mismatch")
}
}
nextEmit = s
if s >= sLimit {
goto emitRemainder
}
if d > dstLimit {
// Do we have space for more, if not bail.
return 0
}
// Check for an immediate match, otherwise start search at s+1
x := load64(src, s-2)
m2Hash := hash6(x, tableBits)
currHash := hash6(x>>16, tableBits)
candidate = int(table[currHash])
table[m2Hash] = uint16(s - 2)
table[currHash] = uint16(s)
if uint32(x>>16) != load32(src, candidate) {
cv = load64(src, s+1)
s++
break
}
}
}
emitRemainder:
if nextEmit < len(src) {
// Bail if we exceed the maximum size.
if d+len(src)-nextEmit > dstLimit {
return 0
}
d += emitLiteral(dst[d:], src[nextEmit:])
}
return d
}
// encodeBlockGo encodes a non-empty src to a guaranteed-large-enough dst. It
// assumes that the varint-encoded length of the decompressed bytes has already
// been written.

View File

@ -348,12 +348,7 @@ func encodeBlockBetterSnappyGo(dst, src []byte) (d int) {
nextS := 0
for {
// Next src position to check
nextS = (s-nextEmit)>>7 + 1
if nextS > maxSkip {
nextS = s + maxSkip
} else {
nextS += s
}
nextS = min(s+(s-nextEmit)>>7+1, s+maxSkip)
if nextS > sLimit {
goto emitRemainder
@ -483,6 +478,415 @@ emitRemainder:
return d
}
func encodeBlockBetterGo64K(dst, src []byte) (d int) {
// sLimit is when to stop looking for offset/length copies. The inputMargin
// lets us use a fast path for emitLiteral in the main loop, while we are
// looking for copies.
sLimit := len(src) - inputMargin
if len(src) < minNonLiteralBlockSize {
return 0
}
// Initialize the hash tables.
// Use smaller tables for smaller blocks
const (
// Long hash matches.
lTableBits = 16
maxLTableSize = 1 << lTableBits
// Short hash matches.
sTableBits = 13
maxSTableSize = 1 << sTableBits
)
var lTable [maxLTableSize]uint16
var sTable [maxSTableSize]uint16
// Bail if we can't compress to at least this.
dstLimit := len(src) - len(src)>>5 - 6
// nextEmit is where in src the next emitLiteral should start from.
nextEmit := 0
// The encoded form must start with a literal, as there are no previous
// bytes to copy, so we start looking for hash matches at s == 1.
s := 1
cv := load64(src, s)
// We initialize repeat to 0, so we never match on first attempt
repeat := 0
for {
candidateL := 0
nextS := 0
for {
// Next src position to check
nextS = s + (s-nextEmit)>>6 + 1
if nextS > sLimit {
goto emitRemainder
}
hashL := hash7(cv, lTableBits)
hashS := hash4(cv, sTableBits)
candidateL = int(lTable[hashL])
candidateS := int(sTable[hashS])
lTable[hashL] = uint16(s)
sTable[hashS] = uint16(s)
valLong := load64(src, candidateL)
valShort := load64(src, candidateS)
// If long matches at least 8 bytes, use that.
if cv == valLong {
break
}
if cv == valShort {
candidateL = candidateS
break
}
// Check repeat at offset checkRep.
const checkRep = 1
// Minimum length of a repeat. Tested with various values.
// While 4-5 offers improvements in some, 6 reduces
// regressions significantly.
const wantRepeatBytes = 6
const repeatMask = ((1 << (wantRepeatBytes * 8)) - 1) << (8 * checkRep)
if false && repeat > 0 && cv&repeatMask == load64(src, s-repeat)&repeatMask {
base := s + checkRep
// Extend back
for i := base - repeat; base > nextEmit && i > 0 && src[i-1] == src[base-1]; {
i--
base--
}
d += emitLiteral(dst[d:], src[nextEmit:base])
// Extend forward
candidate := s - repeat + wantRepeatBytes + checkRep
s += wantRepeatBytes + checkRep
for s < len(src) {
if len(src)-s < 8 {
if src[s] == src[candidate] {
s++
candidate++
continue
}
break
}
if diff := load64(src, s) ^ load64(src, candidate); diff != 0 {
s += bits.TrailingZeros64(diff) >> 3
break
}
s += 8
candidate += 8
}
// same as `add := emitCopy(dst[d:], repeat, s-base)` but skips storing offset.
d += emitRepeat(dst[d:], repeat, s-base)
nextEmit = s
if s >= sLimit {
goto emitRemainder
}
// Index in-between
index0 := base + 1
index1 := s - 2
for index0 < index1 {
cv0 := load64(src, index0)
cv1 := load64(src, index1)
lTable[hash7(cv0, lTableBits)] = uint16(index0)
sTable[hash4(cv0>>8, sTableBits)] = uint16(index0 + 1)
lTable[hash7(cv1, lTableBits)] = uint16(index1)
sTable[hash4(cv1>>8, sTableBits)] = uint16(index1 + 1)
index0 += 2
index1 -= 2
}
cv = load64(src, s)
continue
}
// Long likely matches 7, so take that.
if uint32(cv) == uint32(valLong) {
break
}
// Check our short candidate
if uint32(cv) == uint32(valShort) {
// Try a long candidate at s+1
hashL = hash7(cv>>8, lTableBits)
candidateL = int(lTable[hashL])
lTable[hashL] = uint16(s + 1)
if uint32(cv>>8) == load32(src, candidateL) {
s++
break
}
// Use our short candidate.
candidateL = candidateS
break
}
cv = load64(src, nextS)
s = nextS
}
// Extend backwards
for candidateL > 0 && s > nextEmit && src[candidateL-1] == src[s-1] {
candidateL--
s--
}
// Bail if we exceed the maximum size.
if d+(s-nextEmit) > dstLimit {
return 0
}
base := s
offset := base - candidateL
// Extend the 4-byte match as long as possible.
s += 4
candidateL += 4
for s < len(src) {
if len(src)-s < 8 {
if src[s] == src[candidateL] {
s++
candidateL++
continue
}
break
}
if diff := load64(src, s) ^ load64(src, candidateL); diff != 0 {
s += bits.TrailingZeros64(diff) >> 3
break
}
s += 8
candidateL += 8
}
d += emitLiteral(dst[d:], src[nextEmit:base])
if repeat == offset {
d += emitRepeat(dst[d:], offset, s-base)
} else {
d += emitCopy(dst[d:], offset, s-base)
repeat = offset
}
nextEmit = s
if s >= sLimit {
goto emitRemainder
}
if d > dstLimit {
// Do we have space for more, if not bail.
return 0
}
// Index short & long
index0 := base + 1
index1 := s - 2
cv0 := load64(src, index0)
cv1 := load64(src, index1)
lTable[hash7(cv0, lTableBits)] = uint16(index0)
sTable[hash4(cv0>>8, sTableBits)] = uint16(index0 + 1)
// lTable could be postponed, but very minor difference.
lTable[hash7(cv1, lTableBits)] = uint16(index1)
sTable[hash4(cv1>>8, sTableBits)] = uint16(index1 + 1)
index0 += 1
index1 -= 1
cv = load64(src, s)
// Index large values sparsely in between.
// We do two starting from different offsets for speed.
index2 := (index0 + index1 + 1) >> 1
for index2 < index1 {
lTable[hash7(load64(src, index0), lTableBits)] = uint16(index0)
lTable[hash7(load64(src, index2), lTableBits)] = uint16(index2)
index0 += 2
index2 += 2
}
}
emitRemainder:
if nextEmit < len(src) {
// Bail if we exceed the maximum size.
if d+len(src)-nextEmit > dstLimit {
return 0
}
d += emitLiteral(dst[d:], src[nextEmit:])
}
return d
}
// encodeBlockBetterSnappyGo encodes a non-empty src to a guaranteed-large-enough dst. It
// assumes that the varint-encoded length of the decompressed bytes has already
// been written.
//
// It also assumes that:
//
// len(dst) >= MaxEncodedLen(len(src)) &&
// minNonLiteralBlockSize <= len(src) && len(src) <= maxBlockSize
func encodeBlockBetterSnappyGo64K(dst, src []byte) (d int) {
// sLimit is when to stop looking for offset/length copies. The inputMargin
// lets us use a fast path for emitLiteral in the main loop, while we are
// looking for copies.
sLimit := len(src) - inputMargin
if len(src) < minNonLiteralBlockSize {
return 0
}
// Initialize the hash tables.
// Use smaller tables for smaller blocks
const (
// Long hash matches.
lTableBits = 15
maxLTableSize = 1 << lTableBits
// Short hash matches.
sTableBits = 13
maxSTableSize = 1 << sTableBits
)
var lTable [maxLTableSize]uint16
var sTable [maxSTableSize]uint16
// Bail if we can't compress to at least this.
dstLimit := len(src) - len(src)>>5 - 6
// nextEmit is where in src the next emitLiteral should start from.
nextEmit := 0
// The encoded form must start with a literal, as there are no previous
// bytes to copy, so we start looking for hash matches at s == 1.
s := 1
cv := load64(src, s)
const maxSkip = 100
for {
candidateL := 0
nextS := 0
for {
// Next src position to check
nextS = min(s+(s-nextEmit)>>6+1, s+maxSkip)
if nextS > sLimit {
goto emitRemainder
}
hashL := hash7(cv, lTableBits)
hashS := hash4(cv, sTableBits)
candidateL = int(lTable[hashL])
candidateS := int(sTable[hashS])
lTable[hashL] = uint16(s)
sTable[hashS] = uint16(s)
if uint32(cv) == load32(src, candidateL) {
break
}
// Check our short candidate
if uint32(cv) == load32(src, candidateS) {
// Try a long candidate at s+1
hashL = hash7(cv>>8, lTableBits)
candidateL = int(lTable[hashL])
lTable[hashL] = uint16(s + 1)
if uint32(cv>>8) == load32(src, candidateL) {
s++
break
}
// Use our short candidate.
candidateL = candidateS
break
}
cv = load64(src, nextS)
s = nextS
}
// Extend backwards
for candidateL > 0 && s > nextEmit && src[candidateL-1] == src[s-1] {
candidateL--
s--
}
// Bail if we exceed the maximum size.
if d+(s-nextEmit) > dstLimit {
return 0
}
base := s
offset := base - candidateL
// Extend the 4-byte match as long as possible.
s += 4
candidateL += 4
for s < len(src) {
if len(src)-s < 8 {
if src[s] == src[candidateL] {
s++
candidateL++
continue
}
break
}
if diff := load64(src, s) ^ load64(src, candidateL); diff != 0 {
s += bits.TrailingZeros64(diff) >> 3
break
}
s += 8
candidateL += 8
}
d += emitLiteral(dst[d:], src[nextEmit:base])
d += emitCopyNoRepeat(dst[d:], offset, s-base)
nextEmit = s
if s >= sLimit {
goto emitRemainder
}
if d > dstLimit {
// Do we have space for more, if not bail.
return 0
}
// Index short & long
index0 := base + 1
index1 := s - 2
cv0 := load64(src, index0)
cv1 := load64(src, index1)
lTable[hash7(cv0, lTableBits)] = uint16(index0)
sTable[hash4(cv0>>8, sTableBits)] = uint16(index0 + 1)
lTable[hash7(cv1, lTableBits)] = uint16(index1)
sTable[hash4(cv1>>8, sTableBits)] = uint16(index1 + 1)
index0 += 1
index1 -= 1
cv = load64(src, s)
// Index large values sparsely in between.
// We do two starting from different offsets for speed.
index2 := (index0 + index1 + 1) >> 1
for index2 < index1 {
lTable[hash7(load64(src, index0), lTableBits)] = uint16(index0)
lTable[hash7(load64(src, index2), lTableBits)] = uint16(index2)
index0 += 2
index2 += 2
}
}
emitRemainder:
if nextEmit < len(src) {
// Bail if we exceed the maximum size.
if d+len(src)-nextEmit > dstLimit {
return 0
}
d += emitLiteral(dst[d:], src[nextEmit:])
}
return d
}
// encodeBlockBetterDict encodes a non-empty src to a guaranteed-large-enough dst. It
// assumes that the varint-encoded length of the decompressed bytes has already
// been written.

View File

@ -21,6 +21,9 @@ func encodeBlock(dst, src []byte) (d int) {
if len(src) < minNonLiteralBlockSize {
return 0
}
if len(src) <= 64<<10 {
return encodeBlockGo64K(dst, src)
}
return encodeBlockGo(dst, src)
}
@ -32,6 +35,9 @@ func encodeBlock(dst, src []byte) (d int) {
//
// len(dst) >= MaxEncodedLen(len(src))
func encodeBlockBetter(dst, src []byte) (d int) {
if len(src) <= 64<<10 {
return encodeBlockBetterGo64K(dst, src)
}
return encodeBlockBetterGo(dst, src)
}
@ -43,6 +49,9 @@ func encodeBlockBetter(dst, src []byte) (d int) {
//
// len(dst) >= MaxEncodedLen(len(src))
func encodeBlockBetterSnappy(dst, src []byte) (d int) {
if len(src) <= 64<<10 {
return encodeBlockBetterSnappyGo64K(dst, src)
}
return encodeBlockBetterSnappyGo(dst, src)
}
@ -57,6 +66,9 @@ func encodeBlockSnappy(dst, src []byte) (d int) {
if len(src) < minNonLiteralBlockSize {
return 0
}
if len(src) <= 64<<10 {
return encodeBlockSnappyGo64K(dst, src)
}
return encodeBlockSnappyGo(dst, src)
}

View File

@ -1,4 +1,3 @@
module github.com/klauspost/compress
go 1.19
go 1.22

View File

@ -6,7 +6,7 @@ A high performance compression algorithm is implemented. For now focused on spee
This package provides [compression](#Compressor) to and [decompression](#Decompressor) of Zstandard content.
This package is pure Go and without use of "unsafe".
This package is pure Go. Use `noasm` and `nounsafe` to disable relevant features.
The `zstd` package is provided as open source software using a Go standard license.

View File

@ -5,11 +5,12 @@
package zstd
import (
"encoding/binary"
"errors"
"fmt"
"io"
"math/bits"
"github.com/klauspost/compress/internal/le"
)
// bitReader reads a bitstream in reverse.
@ -18,6 +19,7 @@ import (
type bitReader struct {
in []byte
value uint64 // Maybe use [16]byte, but shifting is awkward.
cursor int // offset where next read should end
bitsRead uint8
}
@ -32,6 +34,7 @@ func (b *bitReader) init(in []byte) error {
if v == 0 {
return errors.New("corrupt stream, did not find end of stream")
}
b.cursor = len(in)
b.bitsRead = 64
b.value = 0
if len(in) >= 8 {
@ -67,18 +70,15 @@ func (b *bitReader) fillFast() {
if b.bitsRead < 32 {
return
}
v := b.in[len(b.in)-4:]
b.in = b.in[:len(b.in)-4]
low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
b.value = (b.value << 32) | uint64(low)
b.cursor -= 4
b.value = (b.value << 32) | uint64(le.Load32(b.in, b.cursor))
b.bitsRead -= 32
}
// fillFastStart() assumes the bitreader is empty and there is at least 8 bytes to read.
func (b *bitReader) fillFastStart() {
v := b.in[len(b.in)-8:]
b.in = b.in[:len(b.in)-8]
b.value = binary.LittleEndian.Uint64(v)
b.cursor -= 8
b.value = le.Load64(b.in, b.cursor)
b.bitsRead = 0
}
@ -87,25 +87,23 @@ func (b *bitReader) fill() {
if b.bitsRead < 32 {
return
}
if len(b.in) >= 4 {
v := b.in[len(b.in)-4:]
b.in = b.in[:len(b.in)-4]
low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
b.value = (b.value << 32) | uint64(low)
if b.cursor >= 4 {
b.cursor -= 4
b.value = (b.value << 32) | uint64(le.Load32(b.in, b.cursor))
b.bitsRead -= 32
return
}
b.bitsRead -= uint8(8 * len(b.in))
for len(b.in) > 0 {
b.value = (b.value << 8) | uint64(b.in[len(b.in)-1])
b.in = b.in[:len(b.in)-1]
b.bitsRead -= uint8(8 * b.cursor)
for b.cursor > 0 {
b.cursor -= 1
b.value = (b.value << 8) | uint64(b.in[b.cursor])
}
}
// finished returns true if all bits have been read from the bit stream.
func (b *bitReader) finished() bool {
return len(b.in) == 0 && b.bitsRead >= 64
return b.cursor == 0 && b.bitsRead >= 64
}
// overread returns true if more bits have been requested than is on the stream.
@ -115,13 +113,14 @@ func (b *bitReader) overread() bool {
// remain returns the number of bits remaining.
func (b *bitReader) remain() uint {
return 8*uint(len(b.in)) + 64 - uint(b.bitsRead)
return 8*uint(b.cursor) + 64 - uint(b.bitsRead)
}
// close the bitstream and returns an error if out-of-buffer reads occurred.
func (b *bitReader) close() error {
// Release reference.
b.in = nil
b.cursor = 0
if !b.finished() {
return fmt.Errorf("%d extra bits on block, should be 0", b.remain())
}

View File

@ -5,14 +5,10 @@
package zstd
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"hash/crc32"
"io"
"os"
"path/filepath"
"sync"
"github.com/klauspost/compress/huff0"
@ -648,21 +644,6 @@ func (b *blockDec) prepareSequences(in []byte, hist *history) (err error) {
println("initializing sequences:", err)
return err
}
// Extract blocks...
if false && hist.dict == nil {
fatalErr := func(err error) {
if err != nil {
panic(err)
}
}
fn := fmt.Sprintf("n-%d-lits-%d-prev-%d-%d-%d-win-%d.blk", hist.decoders.nSeqs, len(hist.decoders.literals), hist.recentOffsets[0], hist.recentOffsets[1], hist.recentOffsets[2], hist.windowSize)
var buf bytes.Buffer
fatalErr(binary.Write(&buf, binary.LittleEndian, hist.decoders.litLengths.fse))
fatalErr(binary.Write(&buf, binary.LittleEndian, hist.decoders.matchLengths.fse))
fatalErr(binary.Write(&buf, binary.LittleEndian, hist.decoders.offsets.fse))
buf.Write(in)
os.WriteFile(filepath.Join("testdata", "seqs", fn), buf.Bytes(), os.ModePerm)
}
return nil
}

View File

@ -9,6 +9,7 @@ import (
"fmt"
"math"
"math/bits"
"slices"
"github.com/klauspost/compress/huff0"
)
@ -457,16 +458,7 @@ func fuzzFseEncoder(data []byte) int {
// All 0
return 0
}
maxCount := func(a []uint32) int {
var max uint32
for _, v := range a {
if v > max {
max = v
}
}
return int(max)
}
cnt := maxCount(hist[:maxSym])
cnt := int(slices.Max(hist[:maxSym]))
if cnt == len(data) {
// RLE
return 0
@ -884,15 +876,6 @@ func (b *blockEnc) genCodes() {
}
}
}
maxCount := func(a []uint32) int {
var max uint32
for _, v := range a {
if v > max {
max = v
}
}
return int(max)
}
if debugAsserts && mlMax > maxMatchLengthSymbol {
panic(fmt.Errorf("mlMax > maxMatchLengthSymbol (%d)", mlMax))
}
@ -903,7 +886,7 @@ func (b *blockEnc) genCodes() {
panic(fmt.Errorf("llMax > maxLiteralLengthSymbol (%d)", llMax))
}
b.coders.mlEnc.HistogramFinished(mlMax, maxCount(mlH[:mlMax+1]))
b.coders.ofEnc.HistogramFinished(ofMax, maxCount(ofH[:ofMax+1]))
b.coders.llEnc.HistogramFinished(llMax, maxCount(llH[:llMax+1]))
b.coders.mlEnc.HistogramFinished(mlMax, int(slices.Max(mlH[:mlMax+1])))
b.coders.ofEnc.HistogramFinished(ofMax, int(slices.Max(ofH[:ofMax+1])))
b.coders.llEnc.HistogramFinished(llMax, int(slices.Max(llH[:llMax+1])))
}

View File

@ -123,7 +123,7 @@ func NewReader(r io.Reader, opts ...DOption) (*Decoder, error) {
}
// Read bytes from the decompressed stream into p.
// Returns the number of bytes written and any error that occurred.
// Returns the number of bytes read and any error that occurred.
// When the stream is done, io.EOF will be returned.
func (d *Decoder) Read(p []byte) (int, error) {
var n int
@ -323,6 +323,7 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
frame.bBuf = nil
if frame.history.decoders.br != nil {
frame.history.decoders.br.in = nil
frame.history.decoders.br.cursor = 0
}
d.decoders <- block
}()

View File

@ -116,7 +116,7 @@ func (e *fastBase) matchlen(s, t int32, src []byte) int32 {
panic(err)
}
if t < 0 {
err := fmt.Sprintf("s (%d) < 0", s)
err := fmt.Sprintf("t (%d) < 0", t)
panic(err)
}
if s-t > e.maxMatchOff {

View File

@ -7,20 +7,25 @@
package zstd
import (
"encoding/binary"
"math/bits"
"github.com/klauspost/compress/internal/le"
)
// matchLen returns the maximum common prefix length of a and b.
// a must be the shortest of the two.
func matchLen(a, b []byte) (n int) {
for ; len(a) >= 8 && len(b) >= 8; a, b = a[8:], b[8:] {
diff := binary.LittleEndian.Uint64(a) ^ binary.LittleEndian.Uint64(b)
left := len(a)
for left >= 8 {
diff := le.Load64(a, n) ^ le.Load64(b, n)
if diff != 0 {
return n + bits.TrailingZeros64(diff)>>3
}
n += 8
left -= 8
}
a = a[n:]
b = b[n:]
for i := range a {
if a[i] != b[i] {

View File

@ -245,7 +245,7 @@ func (s *sequenceDecs) decodeSync(hist []byte) error {
return io.ErrUnexpectedEOF
}
var ll, mo, ml int
if len(br.in) > 4+((maxOffsetBits+16+16)>>3) {
if br.cursor > 4+((maxOffsetBits+16+16)>>3) {
// inlined function:
// ll, mo, ml = s.nextFast(br, llState, mlState, ofState)

View File

@ -7,9 +7,9 @@
TEXT ·sequenceDecs_decode_amd64(SB), $8-32
MOVQ br+8(FP), CX
MOVQ 24(CX), DX
MOVBQZX 32(CX), BX
MOVBQZX 40(CX), BX
MOVQ (CX), AX
MOVQ 8(CX), SI
MOVQ 32(CX), SI
ADDQ SI, AX
MOVQ AX, (SP)
MOVQ ctx+16(FP), AX
@ -299,8 +299,8 @@ sequenceDecs_decode_amd64_match_len_ofs_ok:
MOVQ R13, 160(AX)
MOVQ br+8(FP), AX
MOVQ DX, 24(AX)
MOVB BL, 32(AX)
MOVQ SI, 8(AX)
MOVB BL, 40(AX)
MOVQ SI, 32(AX)
// Return success
MOVQ $0x00000000, ret+24(FP)
@ -335,9 +335,9 @@ error_overread:
TEXT ·sequenceDecs_decode_56_amd64(SB), $8-32
MOVQ br+8(FP), CX
MOVQ 24(CX), DX
MOVBQZX 32(CX), BX
MOVBQZX 40(CX), BX
MOVQ (CX), AX
MOVQ 8(CX), SI
MOVQ 32(CX), SI
ADDQ SI, AX
MOVQ AX, (SP)
MOVQ ctx+16(FP), AX
@ -598,8 +598,8 @@ sequenceDecs_decode_56_amd64_match_len_ofs_ok:
MOVQ R13, 160(AX)
MOVQ br+8(FP), AX
MOVQ DX, 24(AX)
MOVB BL, 32(AX)
MOVQ SI, 8(AX)
MOVB BL, 40(AX)
MOVQ SI, 32(AX)
// Return success
MOVQ $0x00000000, ret+24(FP)
@ -634,9 +634,9 @@ error_overread:
TEXT ·sequenceDecs_decode_bmi2(SB), $8-32
MOVQ br+8(FP), BX
MOVQ 24(BX), AX
MOVBQZX 32(BX), DX
MOVBQZX 40(BX), DX
MOVQ (BX), CX
MOVQ 8(BX), BX
MOVQ 32(BX), BX
ADDQ BX, CX
MOVQ CX, (SP)
MOVQ ctx+16(FP), CX
@ -884,8 +884,8 @@ sequenceDecs_decode_bmi2_match_len_ofs_ok:
MOVQ R12, 160(CX)
MOVQ br+8(FP), CX
MOVQ AX, 24(CX)
MOVB DL, 32(CX)
MOVQ BX, 8(CX)
MOVB DL, 40(CX)
MOVQ BX, 32(CX)
// Return success
MOVQ $0x00000000, ret+24(FP)
@ -920,9 +920,9 @@ error_overread:
TEXT ·sequenceDecs_decode_56_bmi2(SB), $8-32
MOVQ br+8(FP), BX
MOVQ 24(BX), AX
MOVBQZX 32(BX), DX
MOVBQZX 40(BX), DX
MOVQ (BX), CX
MOVQ 8(BX), BX
MOVQ 32(BX), BX
ADDQ BX, CX
MOVQ CX, (SP)
MOVQ ctx+16(FP), CX
@ -1141,8 +1141,8 @@ sequenceDecs_decode_56_bmi2_match_len_ofs_ok:
MOVQ R12, 160(CX)
MOVQ br+8(FP), CX
MOVQ AX, 24(CX)
MOVB DL, 32(CX)
MOVQ BX, 8(CX)
MOVB DL, 40(CX)
MOVQ BX, 32(CX)
// Return success
MOVQ $0x00000000, ret+24(FP)
@ -1787,9 +1787,9 @@ empty_seqs:
TEXT ·sequenceDecs_decodeSync_amd64(SB), $64-32
MOVQ br+8(FP), CX
MOVQ 24(CX), DX
MOVBQZX 32(CX), BX
MOVBQZX 40(CX), BX
MOVQ (CX), AX
MOVQ 8(CX), SI
MOVQ 32(CX), SI
ADDQ SI, AX
MOVQ AX, (SP)
MOVQ ctx+16(FP), AX
@ -2281,8 +2281,8 @@ handle_loop:
loop_finished:
MOVQ br+8(FP), AX
MOVQ DX, 24(AX)
MOVB BL, 32(AX)
MOVQ SI, 8(AX)
MOVB BL, 40(AX)
MOVQ SI, 32(AX)
// Update the context
MOVQ ctx+16(FP), AX
@ -2349,9 +2349,9 @@ error_not_enough_space:
TEXT ·sequenceDecs_decodeSync_bmi2(SB), $64-32
MOVQ br+8(FP), BX
MOVQ 24(BX), AX
MOVBQZX 32(BX), DX
MOVBQZX 40(BX), DX
MOVQ (BX), CX
MOVQ 8(BX), BX
MOVQ 32(BX), BX
ADDQ BX, CX
MOVQ CX, (SP)
MOVQ ctx+16(FP), CX
@ -2801,8 +2801,8 @@ handle_loop:
loop_finished:
MOVQ br+8(FP), CX
MOVQ AX, 24(CX)
MOVB DL, 32(CX)
MOVQ BX, 8(CX)
MOVB DL, 40(CX)
MOVQ BX, 32(CX)
// Update the context
MOVQ ctx+16(FP), AX
@ -2869,9 +2869,9 @@ error_not_enough_space:
TEXT ·sequenceDecs_decodeSync_safe_amd64(SB), $64-32
MOVQ br+8(FP), CX
MOVQ 24(CX), DX
MOVBQZX 32(CX), BX
MOVBQZX 40(CX), BX
MOVQ (CX), AX
MOVQ 8(CX), SI
MOVQ 32(CX), SI
ADDQ SI, AX
MOVQ AX, (SP)
MOVQ ctx+16(FP), AX
@ -3465,8 +3465,8 @@ handle_loop:
loop_finished:
MOVQ br+8(FP), AX
MOVQ DX, 24(AX)
MOVB BL, 32(AX)
MOVQ SI, 8(AX)
MOVB BL, 40(AX)
MOVQ SI, 32(AX)
// Update the context
MOVQ ctx+16(FP), AX
@ -3533,9 +3533,9 @@ error_not_enough_space:
TEXT ·sequenceDecs_decodeSync_safe_bmi2(SB), $64-32
MOVQ br+8(FP), BX
MOVQ 24(BX), AX
MOVBQZX 32(BX), DX
MOVBQZX 40(BX), DX
MOVQ (BX), CX
MOVQ 8(BX), BX
MOVQ 32(BX), BX
ADDQ BX, CX
MOVQ CX, (SP)
MOVQ ctx+16(FP), CX
@ -4087,8 +4087,8 @@ handle_loop:
loop_finished:
MOVQ br+8(FP), CX
MOVQ AX, 24(CX)
MOVB DL, 32(CX)
MOVQ BX, 8(CX)
MOVB DL, 40(CX)
MOVQ BX, 32(CX)
// Update the context
MOVQ ctx+16(FP), AX

View File

@ -29,7 +29,7 @@ func (s *sequenceDecs) decode(seqs []seqVals) error {
}
for i := range seqs {
var ll, mo, ml int
if len(br.in) > 4+((maxOffsetBits+16+16)>>3) {
if br.cursor > 4+((maxOffsetBits+16+16)>>3) {
// inlined function:
// ll, mo, ml = s.nextFast(br, llState, mlState, ofState)

View File

@ -69,7 +69,6 @@ var llBitsTable = [maxLLCode + 1]byte{
func llCode(litLength uint32) uint8 {
const llDeltaCode = 19
if litLength <= 63 {
// Compiler insists on bounds check (Go 1.12)
return llCodeTable[litLength&63]
}
return uint8(highBit(litLength)) + llDeltaCode
@ -102,7 +101,6 @@ var mlBitsTable = [maxMLCode + 1]byte{
func mlCode(mlBase uint32) uint8 {
const mlDeltaCode = 36
if mlBase <= 127 {
// Compiler insists on bounds check (Go 1.12)
return mlCodeTable[mlBase&127]
}
return uint8(highBit(mlBase)) + mlDeltaCode

View File

@ -197,7 +197,7 @@ func (r *SnappyConverter) Convert(in io.Reader, w io.Writer) (int64, error) {
n, r.err = w.Write(r.block.output)
if r.err != nil {
return written, err
return written, r.err
}
written += int64(n)
continue
@ -239,7 +239,7 @@ func (r *SnappyConverter) Convert(in io.Reader, w io.Writer) (int64, error) {
}
n, r.err = w.Write(r.block.output)
if r.err != nil {
return written, err
return written, r.err
}
written += int64(n)
continue

View File

@ -5,10 +5,11 @@ package zstd
import (
"bytes"
"encoding/binary"
"errors"
"log"
"math"
"github.com/klauspost/compress/internal/le"
)
// enable debug printing
@ -110,11 +111,11 @@ func printf(format string, a ...interface{}) {
}
func load3232(b []byte, i int32) uint32 {
return binary.LittleEndian.Uint32(b[:len(b):len(b)][i:])
return le.Load32(b, i)
}
func load6432(b []byte, i int32) uint64 {
return binary.LittleEndian.Uint64(b[:len(b):len(b)][i:])
return le.Load64(b, i)
}
type byter interface {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2018-2023 The NATS Authors
* Copyright 2018-2024 The NATS Authors
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
@ -133,7 +133,7 @@ func (o *OperatorLimits) Validate(vr *ValidationResults) {
}
}
// Mapping for publishes
// WeightedMapping for publishes
type WeightedMapping struct {
Subject Subject `json:"subject"`
Weight uint8 `json:"weight,omitempty"`
@ -177,13 +177,13 @@ func (a *Account) AddMapping(sub Subject, to ...WeightedMapping) {
a.Mappings[sub] = to
}
// Enable external authorization for account users.
// ExternalAuthorization enables external authorization for account users.
// AuthUsers are those users specified to bypass the authorization callout and should be used for the authorization service itself.
// AllowedAccounts specifies which accounts, if any, that the authorization service can bind an authorized user to.
// The authorization response, a user JWT, will still need to be signed by the correct account.
// If optional XKey is specified, that is the public xkey (x25519) and the server will encrypt the request such that only the
// holder of the private key can decrypt. The auth service can also optionally encrypt the response back to the server using it's
// publick xkey which will be in the authorization request.
// public xkey which will be in the authorization request.
type ExternalAuthorization struct {
AuthUsers StringList `json:"auth_users,omitempty"`
AllowedAccounts StringList `json:"allowed_accounts,omitempty"`
@ -194,12 +194,12 @@ func (ac *ExternalAuthorization) IsEnabled() bool {
return len(ac.AuthUsers) > 0
}
// Helper function to determine if external authorization is enabled.
// HasExternalAuthorization helper function to determine if external authorization is enabled.
func (a *Account) HasExternalAuthorization() bool {
return a.Authorization.IsEnabled()
}
// Helper function to setup external authorization.
// EnableExternalAuthorization helper function to setup external authorization.
func (a *Account) EnableExternalAuthorization(users ...string) {
a.Authorization.AuthUsers.Add(users...)
}
@ -230,6 +230,20 @@ func (ac *ExternalAuthorization) Validate(vr *ValidationResults) {
}
}
const (
ClusterTrafficSystem = "system"
ClusterTrafficOwner = "owner"
)
type ClusterTraffic string
func (ct ClusterTraffic) Valid() error {
if ct == "" || ct == ClusterTrafficSystem || ct == ClusterTrafficOwner {
return nil
}
return fmt.Errorf("unknown cluster traffic option: %q", ct)
}
// Account holds account specific claims data
type Account struct {
Imports Imports `json:"imports,omitempty"`
@ -241,6 +255,7 @@ type Account struct {
Mappings Mapping `json:"mappings,omitempty"`
Authorization ExternalAuthorization `json:"authorization,omitempty"`
Trace *MsgTrace `json:"trace,omitempty"`
ClusterTraffic ClusterTraffic `json:"cluster_traffic,omitempty"`
Info
GenericFields
}
@ -308,6 +323,10 @@ func (a *Account) Validate(acct *AccountClaims, vr *ValidationResults) {
}
a.SigningKeys.Validate(vr)
a.Info.Validate(vr)
if err := a.ClusterTraffic.Valid(); err != nil {
vr.AddError(err.Error())
}
}
// AccountClaims defines the body of an account JWT
@ -338,13 +357,17 @@ func NewAccountClaims(subject string) *AccountClaims {
// Encode converts account claims into a JWT string
func (a *AccountClaims) Encode(pair nkeys.KeyPair) (string, error) {
return a.EncodeWithSigner(pair, nil)
}
func (a *AccountClaims) EncodeWithSigner(pair nkeys.KeyPair, fn SignFn) (string, error) {
if !nkeys.IsValidPublicAccountKey(a.Subject) {
return "", errors.New("expected subject to be account public key")
}
sort.Sort(a.Exports)
sort.Sort(a.Imports)
a.Type = AccountClaim
return a.ClaimsData.encode(pair, a)
return a.ClaimsData.encode(pair, a, fn)
}
// DecodeAccountClaims decodes account claims from a JWT string

View File

@ -1,5 +1,5 @@
/*
* Copyright 2018 The NATS Authors
* Copyright 2018-2024 The NATS Authors
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
@ -72,11 +72,15 @@ func NewActivationClaims(subject string) *ActivationClaims {
// Encode turns an activation claim into a JWT strimg
func (a *ActivationClaims) Encode(pair nkeys.KeyPair) (string, error) {
return a.EncodeWithSigner(pair, nil)
}
func (a *ActivationClaims) EncodeWithSigner(pair nkeys.KeyPair, fn SignFn) (string, error) {
if !nkeys.IsValidPublicAccountKey(a.ClaimsData.Subject) {
return "", errors.New("expected subject to be an account")
}
a.Type = ActivationClaim
return a.ClaimsData.encode(pair, a)
return a.ClaimsData.encode(pair, a, fn)
}
// DecodeActivationClaims tries to create an activation claim from a JWT string

View File

@ -1,5 +1,5 @@
/*
* Copyright 2022 The NATS Authors
* Copyright 2022-2024 The NATS Authors
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
@ -113,8 +113,12 @@ func (ac *AuthorizationRequestClaims) Validate(vr *ValidationResults) {
// Encode tries to turn the auth request claims into a JWT string.
func (ac *AuthorizationRequestClaims) Encode(pair nkeys.KeyPair) (string, error) {
return ac.EncodeWithSigner(pair, nil)
}
func (ac *AuthorizationRequestClaims) EncodeWithSigner(pair nkeys.KeyPair, fn SignFn) (string, error) {
ac.Type = AuthorizationRequestClaim
return ac.ClaimsData.encode(pair, ac)
return ac.ClaimsData.encode(pair, ac, fn)
}
// DecodeAuthorizationRequestClaims tries to parse an auth request claims from a JWT string
@ -242,6 +246,10 @@ func (ar *AuthorizationResponseClaims) Validate(vr *ValidationResults) {
// Encode tries to turn the auth request claims into a JWT string.
func (ar *AuthorizationResponseClaims) Encode(pair nkeys.KeyPair) (string, error) {
ar.Type = AuthorizationResponseClaim
return ar.ClaimsData.encode(pair, ar)
return ar.EncodeWithSigner(pair, nil)
}
func (ar *AuthorizationResponseClaims) EncodeWithSigner(pair nkeys.KeyPair, fn SignFn) (string, error) {
ar.Type = AuthorizationResponseClaim
return ar.ClaimsData.encode(pair, ar, fn)
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2018-2022 The NATS Authors
* Copyright 2018-2024 The NATS Authors
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
@ -68,10 +68,16 @@ func IsGenericClaimType(s string) bool {
}
}
// SignFn is used in an external sign environment. The function should be
// able to locate the private key for the specified pub key specified and sign the
// specified data returning the signature as generated.
type SignFn func(pub string, data []byte) ([]byte, error)
// Claims is a JWT claims
type Claims interface {
Claims() *ClaimsData
Encode(kp nkeys.KeyPair) (string, error)
EncodeWithSigner(pair nkeys.KeyPair, fn SignFn) (string, error)
ExpectedPrefixes() []nkeys.PrefixByte
Payload() interface{}
String() string
@ -121,7 +127,7 @@ func serialize(v interface{}) (string, error) {
return encodeToString(j), nil
}
func (c *ClaimsData) doEncode(header *Header, kp nkeys.KeyPair, claim Claims) (string, error) {
func (c *ClaimsData) doEncode(header *Header, kp nkeys.KeyPair, claim Claims, fn SignFn) (string, error) {
if header == nil {
return "", errors.New("header is required")
}
@ -200,9 +206,21 @@ func (c *ClaimsData) doEncode(header *Header, kp nkeys.KeyPair, claim Claims) (s
if header.Algorithm == AlgorithmNkeyOld {
return "", errors.New(AlgorithmNkeyOld + " not supported to write jwtV2")
} else if header.Algorithm == AlgorithmNkey {
sig, err := kp.Sign([]byte(toSign))
if err != nil {
return "", err
var sig []byte
if fn != nil {
pk, err := kp.PublicKey()
if err != nil {
return "", err
}
sig, err = fn(pk, []byte(toSign))
if err != nil {
return "", err
}
} else {
sig, err = kp.Sign([]byte(toSign))
if err != nil {
return "", err
}
}
eSig = encodeToString(sig)
} else {
@ -224,8 +242,8 @@ func (c *ClaimsData) hash() (string, error) {
// Encode encodes a claim into a JWT token. The claim is signed with the
// provided nkey's private key
func (c *ClaimsData) encode(kp nkeys.KeyPair, payload Claims) (string, error) {
return c.doEncode(&Header{TokenTypeJwt, AlgorithmNkey}, kp, payload)
func (c *ClaimsData) encode(kp nkeys.KeyPair, payload Claims, fn SignFn) (string, error) {
return c.doEncode(&Header{TokenTypeJwt, AlgorithmNkey}, kp, payload, fn)
}
// Returns a JSON representation of the claim

View File

@ -273,7 +273,7 @@ func isContainedIn(kind ExportType, subjects []Subject, vr *ValidationResults) {
}
// Validate calls validate on all of the exports
func (e *Exports) Validate(vr *ValidationResults) error {
func (e *Exports) Validate(vr *ValidationResults) {
var serviceSubjects []Subject
var streamSubjects []Subject
@ -292,8 +292,6 @@ func (e *Exports) Validate(vr *ValidationResults) error {
isContainedIn(Service, serviceSubjects, vr)
isContainedIn(Stream, streamSubjects, vr)
return nil
}
// HasExportContainingSubject checks if the export list has an export with the provided subject

View File

@ -1,5 +1,5 @@
/*
* Copyright 2018-2020 The NATS Authors
* Copyright 2018-2024 The NATS Authors
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
@ -107,7 +107,11 @@ func (gc *GenericClaims) Payload() interface{} {
// Encode takes a generic claims and creates a JWT string
func (gc *GenericClaims) Encode(pair nkeys.KeyPair) (string, error) {
return gc.ClaimsData.encode(pair, gc)
return gc.ClaimsData.encode(pair, gc, nil)
}
func (gc *GenericClaims) EncodeWithSigner(pair nkeys.KeyPair, fn SignFn) (string, error) {
return gc.ClaimsData.encode(pair, gc, fn)
}
// Validate checks the generic part of the claims data

View File

@ -1,5 +1,5 @@
/*
* Copyright 2018 The NATS Authors
* Copyright 2018-2024 The NATS Authors
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
@ -191,6 +191,10 @@ func (oc *OperatorClaims) DidSign(op Claims) bool {
// Encode the claims into a JWT string
func (oc *OperatorClaims) Encode(pair nkeys.KeyPair) (string, error) {
return oc.EncodeWithSigner(pair, nil)
}
func (oc *OperatorClaims) EncodeWithSigner(pair nkeys.KeyPair, fn SignFn) (string, error) {
if !nkeys.IsValidPublicOperatorKey(oc.Subject) {
return "", errors.New("expected subject to be an operator public key")
}
@ -199,7 +203,7 @@ func (oc *OperatorClaims) Encode(pair nkeys.KeyPair) (string, error) {
return "", err
}
oc.Type = OperatorClaim
return oc.ClaimsData.encode(pair, oc)
return oc.ClaimsData.encode(pair, oc, fn)
}
func (oc *OperatorClaims) ClaimType() ClaimType {

View File

@ -309,7 +309,7 @@ func (l *Limits) Validate(vr *ValidationResults) {
}
}
if l.Times != nil && len(l.Times) > 0 {
if len(l.Times) > 0 {
for _, t := range l.Times {
t.Validate(vr)
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2018-2019 The NATS Authors
* Copyright 2018-2024 The NATS Authors
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
@ -92,11 +92,15 @@ func (u *UserClaims) HasEmptyPermissions() bool {
// Encode tries to turn the user claims into a JWT string
func (u *UserClaims) Encode(pair nkeys.KeyPair) (string, error) {
return u.EncodeWithSigner(pair, nil)
}
func (u *UserClaims) EncodeWithSigner(pair nkeys.KeyPair, fn SignFn) (string, error) {
if !nkeys.IsValidPublicUserKey(u.Subject) {
return "", errors.New("expected subject to be user public key")
}
u.Type = UserClaim
return u.ClaimsData.encode(pair, u)
return u.ClaimsData.encode(pair, u, fn)
}
// DecodeUserClaims tries to parse a user claims from a JWT string

View File

@ -1,4 +1,4 @@
// Copyright 2020 The NATS Authors
// Copyright 2020-2021 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@ -12,7 +12,6 @@
// limitations under the License.
//go:build gofuzz
// +build gofuzz
package conf

View File

@ -1,4 +1,4 @@
// Copyright 2013-2018 The NATS Authors
// Copyright 2013-2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@ -26,6 +26,8 @@ package conf
// see parse_test.go for more examples.
import (
"crypto/sha256"
"encoding/json"
"fmt"
"os"
"path/filepath"
@ -35,6 +37,8 @@ import (
"unicode"
)
const _EMPTY_ = ""
type parser struct {
mapping map[string]any
lx *lexer
@ -69,6 +73,15 @@ func Parse(data string) (map[string]any, error) {
return p.mapping, nil
}
// ParseWithChecks is equivalent to Parse but runs in pedantic mode.
func ParseWithChecks(data string) (map[string]any, error) {
p, err := parse(data, "", true)
if err != nil {
return nil, err
}
return p.mapping, nil
}
// ParseFile is a helper to open file, etc. and parse the contents.
func ParseFile(fp string) (map[string]any, error) {
data, err := os.ReadFile(fp)
@ -98,6 +111,44 @@ func ParseFileWithChecks(fp string) (map[string]any, error) {
return p.mapping, nil
}
// cleanupUsedEnvVars will recursively remove all already used
// environment variables which might be in the parsed tree.
func cleanupUsedEnvVars(m map[string]any) {
for k, v := range m {
t := v.(*token)
if t.usedVariable {
delete(m, k)
continue
}
// Cleanup any other env var that is still in the map.
if tm, ok := t.value.(map[string]any); ok {
cleanupUsedEnvVars(tm)
}
}
}
// ParseFileWithChecksDigest returns the processed config and a digest
// that represents the configuration.
func ParseFileWithChecksDigest(fp string) (map[string]any, string, error) {
data, err := os.ReadFile(fp)
if err != nil {
return nil, _EMPTY_, err
}
p, err := parse(string(data), fp, true)
if err != nil {
return nil, _EMPTY_, err
}
// Filter out any environment variables before taking the digest.
cleanupUsedEnvVars(p.mapping)
digest := sha256.New()
e := json.NewEncoder(digest)
err = e.Encode(p.mapping)
if err != nil {
return nil, _EMPTY_, err
}
return p.mapping, fmt.Sprintf("sha256:%x", digest.Sum(nil)), nil
}
type token struct {
item item
value any
@ -105,6 +156,10 @@ type token struct {
sourceFile string
}
func (t *token) MarshalJSON() ([]byte, error) {
return json.Marshal(t.value)
}
func (t *token) Value() any {
return t.value
}

View File

@ -1,4 +1,4 @@
// Copyright 2012-2019 The NATS Authors
// Copyright 2012-2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

View File

@ -1,4 +1,4 @@
// Copyright 2012-2019 The NATS Authors
// Copyright 2012-2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@ -12,7 +12,6 @@
// limitations under the License.
//go:build !windows
// +build !windows
package logger

View File

@ -1,4 +1,4 @@
// Copyright 2012-2018 The NATS Authors
// Copyright 2012-2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

View File

@ -10,7 +10,7 @@ The script `runTestsOnTravis.sh` will run a given job based on the definition fo
As for the naming convention:
- All JetStream tests name should start with `TestJetStream`
- All JetStream test name should start with `TestJetStream`
- Cluster tests should go into `jetstream_cluster_test.go` and start with `TestJetStreamCluster`
- Super-cluster tests should go into `jetstream_super_cluster_test.go` and start with `TestJetStreamSuperCluster`

View File

@ -1,4 +1,4 @@
// Copyright 2018-2023 The NATS Authors
// Copyright 2018-2025 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@ -61,6 +61,7 @@ type Account struct {
sqmu sync.Mutex
sl *Sublist
ic *client
sq *sendq
isid uint64
etmr *time.Timer
ctmr *time.Timer
@ -97,6 +98,12 @@ type Account struct {
nameTag string
lastLimErr int64
routePoolIdx int
// If the trace destination is specified and a message with a traceParentHdr
// is received, and has the least significant bit of the last token set to 1,
// then if traceDestSampling is > 0 and < 100, a random value will be selected
// and if it falls between 0 and that value, message tracing will be triggered.
traceDest string
traceDestSampling int
// Guarantee that only one goroutine can be running either checkJetStreamMigrate
// or clearObserverState at a given time for this account to prevent interleaving.
jscmMu sync.Mutex
@ -132,6 +139,10 @@ type streamImport struct {
claim *jwt.Import
usePub bool
invalid bool
// This is `allow_trace` and when true and message tracing is happening,
// we will trace egresses past the account boundary, if `false`, we stop
// at the account boundary.
atrc bool
}
const ClientInfoHdr = "Nats-Request-Info"
@ -156,6 +167,7 @@ type serviceImport struct {
share bool
tracking bool
didDeliver bool
atrc bool // allow trace (got from service export)
trackingHdr http.Header // header from request
}
@ -213,6 +225,11 @@ type serviceExport struct {
latency *serviceLatency
rtmr *time.Timer
respThresh time.Duration
// This is `allow_trace` and when true and message tracing is happening,
// when processing a service import we will go through account boundary
// and trace egresses on that other account. If `false`, we stop at the
// account boundary.
atrc bool
}
// Used to track service latency.
@ -250,11 +267,29 @@ func (a *Account) String() string {
return a.Name
}
func (a *Account) setTraceDest(dest string) {
a.mu.Lock()
a.traceDest = dest
a.mu.Unlock()
}
func (a *Account) getTraceDestAndSampling() (string, int) {
a.mu.RLock()
dest := a.traceDest
sampling := a.traceDestSampling
a.mu.RUnlock()
return dest, sampling
}
// Used to create shallow copies of accounts for transfer
// from opts to real accounts in server struct.
// Account `na` write lock is expected to be held on entry
// while account `a` is the one from the Options struct
// being loaded/reloaded and do not need locking.
func (a *Account) shallowCopy(na *Account) {
na.Nkey = a.Nkey
na.Issuer = a.Issuer
na.traceDest, na.traceDestSampling = a.traceDest, a.traceDestSampling
if a.imports.streams != nil {
na.imports.streams = make([]*streamImport, 0, len(a.imports.streams))
@ -425,6 +460,29 @@ func (a *Account) GetName() string {
return name
}
// getNameTag will return the name tag or the account name if not set.
func (a *Account) getNameTag() string {
if a == nil {
return _EMPTY_
}
a.mu.RLock()
defer a.mu.RUnlock()
return a.getNameTagLocked()
}
// getNameTagLocked will return the name tag or the account name if not set.
// Lock should be held.
func (a *Account) getNameTagLocked() string {
if a == nil {
return _EMPTY_
}
nameTag := a.nameTag
if nameTag == _EMPTY_ {
nameTag = a.Name
}
return nameTag
}
// NumConnections returns active number of clients for this account for
// all known servers.
func (a *Account) NumConnections() int {
@ -623,7 +681,7 @@ func (a *Account) AddWeightedMappings(src string, dests ...*MapDest) error {
if tw[d.Cluster] > 100 {
return fmt.Errorf("total weight needs to be <= 100")
}
err := ValidateMappingDestination(d.Subject)
err := ValidateMapping(src, d.Subject)
if err != nil {
return err
}
@ -858,9 +916,14 @@ func (a *Account) Interest(subject string) int {
func (a *Account) addClient(c *client) int {
a.mu.Lock()
n := len(a.clients)
if a.clients != nil {
a.clients[c] = struct{}{}
// Could come here earlier than the account is registered with the server.
// Make sure we can still track clients.
if a.clients == nil {
a.clients = make(map[*client]struct{})
}
a.clients[c] = struct{}{}
// If we did not add it, we are done
if n == len(a.clients) {
a.mu.Unlock()
@ -1900,11 +1963,13 @@ func (a *Account) addServiceImport(dest *Account, from, to string, claim *jwt.Im
return nil, ErrMissingAccount
}
var atrc bool
dest.mu.RLock()
se := dest.getServiceExport(to)
if se != nil {
rt = se.respType
lat = se.latency
atrc = se.atrc
}
dest.mu.RUnlock()
@ -1949,7 +2014,7 @@ func (a *Account) addServiceImport(dest *Account, from, to string, claim *jwt.Im
if claim != nil {
share = claim.Share
}
si := &serviceImport{dest, claim, se, nil, from, to, tr, 0, rt, lat, nil, nil, usePub, false, false, share, false, false, nil}
si := &serviceImport{dest, claim, se, nil, from, to, tr, 0, rt, lat, nil, nil, usePub, false, false, share, false, false, atrc, nil}
a.imports.services[from] = si
a.mu.Unlock()
@ -2021,7 +2086,7 @@ func (a *Account) addServiceImportSub(si *serviceImport) error {
a.mu.Unlock()
cb := func(sub *subscription, c *client, acc *Account, subject, reply string, msg []byte) {
c.processServiceImport(si, acc, msg)
c.pa.delivered = c.processServiceImport(si, acc, msg)
}
sub, err := c.processSubEx([]byte(subject), nil, []byte(sid), cb, true, true, false)
if err != nil {
@ -2173,9 +2238,15 @@ func shouldSample(l *serviceLatency, c *client) (bool, http.Header) {
}
return true, http.Header{trcB3: b3} // sampling allowed or left to recipient of header
} else if tId := h[trcCtx]; len(tId) != 0 {
var sample bool
// sample 00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01
tk := strings.Split(tId[0], "-")
if len(tk) == 4 && len([]byte(tk[3])) == 2 && tk[3] == "01" {
if len(tk) == 4 && len([]byte(tk[3])) == 2 {
if hexVal, err := strconv.ParseInt(tk[3], 16, 8); err == nil {
sample = hexVal&0x1 == 0x1
}
}
if sample {
return true, newTraceCtxHeader(h, tId)
} else {
return false, nil
@ -2387,6 +2458,18 @@ func (a *Account) SetServiceExportResponseThreshold(export string, maxTime time.
return nil
}
func (a *Account) SetServiceExportAllowTrace(export string, allowTrace bool) error {
a.mu.Lock()
se := a.getServiceExport(export)
if se == nil {
a.mu.Unlock()
return fmt.Errorf("no export defined for %q", export)
}
se.atrc = allowTrace
a.mu.Unlock()
return nil
}
// This is for internal service import responses.
func (a *Account) addRespServiceImport(dest *Account, to string, osi *serviceImport, tracking bool, header http.Header) *serviceImport {
nrr := string(osi.acc.newServiceReply(tracking))
@ -2396,7 +2479,7 @@ func (a *Account) addRespServiceImport(dest *Account, to string, osi *serviceImp
// dest is the requestor's account. a is the service responder with the export.
// Marked as internal here, that is how we distinguish.
si := &serviceImport{dest, nil, osi.se, nil, nrr, to, nil, 0, rt, nil, nil, nil, false, true, false, osi.share, false, false, nil}
si := &serviceImport{dest, nil, osi.se, nil, nrr, to, nil, 0, rt, nil, nil, nil, false, true, false, osi.share, false, false, false, nil}
if a.exports.responses == nil {
a.exports.responses = make(map[string]*serviceImport)
@ -2425,6 +2508,10 @@ func (a *Account) addRespServiceImport(dest *Account, to string, osi *serviceImp
// AddStreamImportWithClaim will add in the stream import from a specific account with optional token.
func (a *Account) AddStreamImportWithClaim(account *Account, from, prefix string, imClaim *jwt.Import) error {
return a.addStreamImportWithClaim(account, from, prefix, false, imClaim)
}
func (a *Account) addStreamImportWithClaim(account *Account, from, prefix string, allowTrace bool, imClaim *jwt.Import) error {
if account == nil {
return ErrMissingAccount
}
@ -2447,7 +2534,7 @@ func (a *Account) AddStreamImportWithClaim(account *Account, from, prefix string
}
}
return a.AddMappedStreamImportWithClaim(account, from, prefix+from, imClaim)
return a.addMappedStreamImportWithClaim(account, from, prefix+from, allowTrace, imClaim)
}
// AddMappedStreamImport helper for AddMappedStreamImportWithClaim
@ -2457,6 +2544,10 @@ func (a *Account) AddMappedStreamImport(account *Account, from, to string) error
// AddMappedStreamImportWithClaim will add in the stream import from a specific account with optional token.
func (a *Account) AddMappedStreamImportWithClaim(account *Account, from, to string, imClaim *jwt.Import) error {
return a.addMappedStreamImportWithClaim(account, from, to, false, imClaim)
}
func (a *Account) addMappedStreamImportWithClaim(account *Account, from, to string, allowTrace bool, imClaim *jwt.Import) error {
if account == nil {
return ErrMissingAccount
}
@ -2502,7 +2593,10 @@ func (a *Account) AddMappedStreamImportWithClaim(account *Account, from, to stri
a.mu.Unlock()
return ErrStreamImportDuplicate
}
a.imports.streams = append(a.imports.streams, &streamImport{account, from, to, tr, nil, imClaim, usePub, false})
if imClaim != nil {
allowTrace = imClaim.AllowTrace
}
a.imports.streams = append(a.imports.streams, &streamImport{account, from, to, tr, nil, imClaim, usePub, false, allowTrace})
a.mu.Unlock()
return nil
}
@ -2520,7 +2614,7 @@ func (a *Account) isStreamImportDuplicate(acc *Account, from string) bool {
// AddStreamImport will add in the stream import from a specific account.
func (a *Account) AddStreamImport(account *Account, from, prefix string) error {
return a.AddStreamImportWithClaim(account, from, prefix, nil)
return a.addStreamImportWithClaim(account, from, prefix, false, nil)
}
// IsPublicExport is a placeholder to denote a public export.
@ -2839,7 +2933,9 @@ func (a *Account) checkStreamImportsEqual(b *Account) bool {
bm[bim.acc.Name+bim.from+bim.to] = bim
}
for _, aim := range a.imports.streams {
if _, ok := bm[aim.acc.Name+aim.from+aim.to]; !ok {
if bim, ok := bm[aim.acc.Name+aim.from+aim.to]; !ok {
return false
} else if aim.atrc != bim.atrc {
return false
}
}
@ -2925,6 +3021,9 @@ func isServiceExportEqual(a, b *serviceExport) bool {
return false
}
}
if a.atrc != b.atrc {
return false
}
return true
}
@ -3200,6 +3299,19 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
// Grab trace label under lock.
tl := a.traceLabel()
var td string
var tds int
if ac.Trace != nil {
// Update trace destination and sampling
td, tds = string(ac.Trace.Destination), ac.Trace.Sampling
if !IsValidPublishSubject(td) {
td, tds = _EMPTY_, 0
} else if tds <= 0 || tds > 100 {
tds = 100
}
}
a.traceDest, a.traceDestSampling = td, tds
// Check for external authorization.
if ac.HasExternalAuthorization() {
a.extAuth = &jwt.ExternalAuthorization{}
@ -3328,6 +3440,9 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
s.Debugf("Error adding service export response threshold for [%s]: %v", tl, err)
}
}
if err := a.SetServiceExportAllowTrace(sub, e.AllowTrace); err != nil {
s.Debugf("Error adding allow_trace for %q: %v", sub, err)
}
}
var revocationChanged *bool
@ -3465,10 +3580,15 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
if si != nil && si.acc.Name == a.Name {
// Check for if we are still authorized for an import.
si.invalid = !a.checkServiceImportAuthorized(acc, si.to, si.claim)
if si.latency != nil && !si.response {
// Make sure we should still be tracking latency.
// Make sure we should still be tracking latency and if we
// are allowed to trace.
if !si.response {
if se := a.getServiceExport(si.to); se != nil {
si.latency = se.latency
if si.latency != nil {
si.latency = se.latency
}
// Update allow trace.
si.atrc = se.atrc
}
}
}
@ -3562,6 +3682,7 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
a.updated = time.Now()
clients := a.getClientsLocked()
ajs := a.js
a.mu.Unlock()
// Sort if we are over the limit.
@ -3586,6 +3707,26 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
a.enableAllJetStreamServiceImportsAndMappings()
}
if ajs != nil {
// Check whether the account NRG status changed. If it has then we need to notify the
// Raft groups running on the system so that they can move their subs if needed.
a.mu.Lock()
previous := ajs.nrgAccount
switch ac.ClusterTraffic {
case "system", _EMPTY_:
ajs.nrgAccount = _EMPTY_
case "owner":
ajs.nrgAccount = a.Name
default:
s.Errorf("Account claim for %q has invalid value %q for cluster traffic account", a.Name, ac.ClusterTraffic)
}
changed := ajs.nrgAccount != previous
a.mu.Unlock()
if changed {
s.updateNRGAccountStatus()
}
}
for i, c := range clients {
a.mu.RLock()
exceeded := a.mconns != jwt.NoLimit && i >= int(a.mconns)
@ -3901,6 +4042,25 @@ func (dr *DirAccResolver) Reload() error {
return dr.DirJWTStore.Reload()
}
// ServerAPIClaimUpdateResponse is the response to $SYS.REQ.ACCOUNT.<id>.CLAIMS.UPDATE and $SYS.REQ.CLAIMS.UPDATE
type ServerAPIClaimUpdateResponse struct {
Server *ServerInfo `json:"server"`
Data *ClaimUpdateStatus `json:"data,omitempty"`
Error *ClaimUpdateError `json:"error,omitempty"`
}
type ClaimUpdateError struct {
Account string `json:"account,omitempty"`
Code int `json:"code"`
Description string `json:"description,omitempty"`
}
type ClaimUpdateStatus struct {
Account string `json:"account,omitempty"`
Code int `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}
func respondToUpdate(s *Server, respSubj string, acc string, message string, err error) {
if err == nil {
if acc == _EMPTY_ {
@ -3918,22 +4078,26 @@ func respondToUpdate(s *Server, respSubj string, acc string, message string, err
if respSubj == _EMPTY_ {
return
}
server := &ServerInfo{}
response := map[string]interface{}{"server": server}
m := map[string]interface{}{}
if acc != _EMPTY_ {
m["account"] = acc
response := ServerAPIClaimUpdateResponse{
Server: &ServerInfo{},
}
if err == nil {
m["code"] = http.StatusOK
m["message"] = message
response["data"] = m
response.Data = &ClaimUpdateStatus{
Account: acc,
Code: http.StatusOK,
Message: message,
}
} else {
m["code"] = http.StatusInternalServerError
m["description"] = fmt.Sprintf("%s - %v", message, err)
response["error"] = m
response.Error = &ClaimUpdateError{
Account: acc,
Code: http.StatusInternalServerError,
Description: fmt.Sprintf("%s - %v", message, err),
}
}
s.sendInternalMsgLocked(respSubj, _EMPTY_, server, response)
s.sendInternalMsgLocked(respSubj, _EMPTY_, response.Server, response)
}
func handleListRequest(store *DirJWTStore, s *Server, reply string) {

View File

@ -417,6 +417,10 @@ func (c *client) matchesPinnedCert(tlsPinnedCerts PinnedCertSet) bool {
return true
}
var (
mustacheRE = regexp.MustCompile(`{{2}([^}]+)}{2}`)
)
func processUserPermissionsTemplate(lim jwt.UserPermissionLimits, ujwt *jwt.UserClaims, acc *Account) (jwt.UserPermissionLimits, error) {
nArrayCartesianProduct := func(a ...[]string) [][]string {
c := 1
@ -448,16 +452,26 @@ func processUserPermissionsTemplate(lim jwt.UserPermissionLimits, ujwt *jwt.User
}
return p
}
isTag := func(op string) []string {
if strings.EqualFold("tag(", op[:4]) && strings.HasSuffix(op, ")") {
v := strings.TrimPrefix(op, "tag(")
v = strings.TrimSuffix(v, ")")
return []string{"tag", v}
} else if strings.EqualFold("account-tag(", op[:12]) && strings.HasSuffix(op, ")") {
v := strings.TrimPrefix(op, "account-tag(")
v = strings.TrimSuffix(v, ")")
return []string{"account-tag", v}
}
return nil
}
applyTemplate := func(list jwt.StringList, failOnBadSubject bool) (jwt.StringList, error) {
found := false
FOR_FIND:
for i := 0; i < len(list); i++ {
// check if templates are present
for _, tk := range strings.Split(list[i], tsep) {
if strings.HasPrefix(tk, "{{") && strings.HasSuffix(tk, "}}") {
found = true
break FOR_FIND
}
if mustacheRE.MatchString(list[i]) {
found = true
break FOR_FIND
}
}
if !found {
@ -466,94 +480,78 @@ func processUserPermissionsTemplate(lim jwt.UserPermissionLimits, ujwt *jwt.User
// process the templates
emittedList := make([]string, 0, len(list))
for i := 0; i < len(list); i++ {
tokens := strings.Split(list[i], tsep)
newTokens := make([]string, len(tokens))
tagValues := [][]string{}
// find all the templates {{}} in this acl
tokens := mustacheRE.FindAllString(list[i], -1)
srcs := make([]string, len(tokens))
values := make([][]string, len(tokens))
hasTags := false
for tokenNum, tk := range tokens {
if strings.HasPrefix(tk, "{{") && strings.HasSuffix(tk, "}}") {
op := strings.ToLower(strings.TrimSuffix(strings.TrimPrefix(tk, "{{"), "}}"))
switch {
case op == "name()":
tk = ujwt.Name
case op == "subject()":
tk = ujwt.Subject
case op == "account-name()":
srcs[tokenNum] = tk
op := strings.TrimSpace(strings.TrimSuffix(strings.TrimPrefix(tk, "{{"), "}}"))
if strings.EqualFold("name()", op) {
values[tokenNum] = []string{ujwt.Name}
} else if strings.EqualFold("subject()", op) {
values[tokenNum] = []string{ujwt.Subject}
} else if strings.EqualFold("account-name()", op) {
acc.mu.RLock()
values[tokenNum] = []string{acc.nameTag}
acc.mu.RUnlock()
} else if strings.EqualFold("account-subject()", op) {
// this always has an issuer account since this is a scoped signer
values[tokenNum] = []string{ujwt.IssuerAccount}
} else if isTag(op) != nil {
hasTags = true
match := isTag(op)
var tags jwt.TagList
if match[0] == "account-tag" {
acc.mu.RLock()
name := acc.nameTag
tags = acc.tags
acc.mu.RUnlock()
tk = name
case op == "account-subject()":
tk = ujwt.IssuerAccount
case (strings.HasPrefix(op, "tag(") || strings.HasPrefix(op, "account-tag(")) &&
strings.HasSuffix(op, ")"):
// insert dummy tav value that will throw of subject validation (in case nothing is found)
tk = _EMPTY_
// collect list of matching tag values
var tags jwt.TagList
var tagPrefix string
if strings.HasPrefix(op, "account-tag(") {
acc.mu.RLock()
tags = acc.tags
acc.mu.RUnlock()
tagPrefix = fmt.Sprintf("%s:", strings.ToLower(
strings.TrimSuffix(strings.TrimPrefix(op, "account-tag("), ")")))
} else {
tags = ujwt.Tags
tagPrefix = fmt.Sprintf("%s:", strings.ToLower(
strings.TrimSuffix(strings.TrimPrefix(op, "tag("), ")")))
}
valueList := []string{}
for _, tag := range tags {
if strings.HasPrefix(tag, tagPrefix) {
tagValue := strings.TrimPrefix(tag, tagPrefix)
valueList = append(valueList, tagValue)
}
}
if len(valueList) != 0 {
tagValues = append(tagValues, valueList)
}
default:
// if macro is not recognized, throw off subject check on purpose
tk = " "
} else {
tags = ujwt.Tags
}
tagPrefix := fmt.Sprintf("%s:", strings.ToLower(match[1]))
var valueList []string
for _, tag := range tags {
if strings.HasPrefix(tag, tagPrefix) {
tagValue := strings.TrimPrefix(tag, tagPrefix)
valueList = append(valueList, tagValue)
}
}
if len(valueList) != 0 {
values[tokenNum] = valueList
} else if failOnBadSubject {
return nil, fmt.Errorf("generated invalid subject %q: %q is not defined", list[i], match[1])
} else {
// generate an invalid subject?
values[tokenNum] = []string{" "}
}
} else if failOnBadSubject {
return nil, fmt.Errorf("template operation in %q: %q is not defined", list[i], op)
}
newTokens[tokenNum] = tk
}
// fill in tag value placeholders
if len(tagValues) == 0 {
emitSubj := strings.Join(newTokens, tsep)
if IsValidSubject(emitSubj) {
emittedList = append(emittedList, emitSubj)
if !hasTags {
subj := list[i]
for idx, m := range srcs {
subj = strings.Replace(subj, m, values[idx][0], -1)
}
if IsValidSubject(subj) {
emittedList = append(emittedList, subj)
} else if failOnBadSubject {
return nil, fmt.Errorf("generated invalid subject")
}
// else skip emitting
} else {
// compute the cartesian product and compute subject to emit for each combination
for _, valueList := range nArrayCartesianProduct(tagValues...) {
b := strings.Builder{}
for i, token := range newTokens {
if token == _EMPTY_ && len(valueList) > 0 {
b.WriteString(valueList[0])
valueList = valueList[1:]
} else {
b.WriteString(token)
}
if i != len(newTokens)-1 {
b.WriteString(tsep)
}
a := nArrayCartesianProduct(values...)
for _, aa := range a {
subj := list[i]
for j := 0; j < len(srcs); j++ {
subj = strings.Replace(subj, srcs[j], aa[j], -1)
}
emitSubj := b.String()
if IsValidSubject(emitSubj) {
emittedList = append(emittedList, emitSubj)
if IsValidSubject(subj) {
emittedList = append(emittedList, subj)
} else if failOnBadSubject {
return nil, fmt.Errorf("generated invalid subject")
}
// else skip emitting
}
}
}
@ -606,13 +604,39 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) (au
}
return
}
// We have a juc defined here, check account.
// We have a juc, check if externally managed, i.e. should be delegated
// to the auth callout service.
if juc != nil && !acc.hasExternalAuth() {
if !authorized {
s.sendAccountAuthErrorEvent(c, c.acc, reason)
}
return
}
// Check config-mode. The global account is a condition since users that
// are not found in the config are implicitly bound to the global account.
// This means those users should be implicitly delegated to auth callout
// if configured. Exclude LEAF connections from this check.
if c.kind != LEAF && juc == nil && opts.AuthCallout != nil && c.acc.Name != globalAccountName {
// If no allowed accounts are defined, then all accounts are in scope.
// Otherwise see if the account is in the list.
delegated := len(opts.AuthCallout.AllowedAccounts) == 0
if !delegated {
for _, n := range opts.AuthCallout.AllowedAccounts {
if n == c.acc.Name {
delegated = true
break
}
}
}
// Not delegated, so return with previous authorized result.
if !delegated {
if !authorized {
s.sendAccountAuthErrorEvent(c, c.acc, reason)
}
return
}
}
// We have auth callout set here.
var skip bool
@ -1471,7 +1495,8 @@ func validateAllowedConnectionTypes(m map[string]struct{}) error {
switch ctuc {
case jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket,
jwt.ConnectionTypeLeafnode, jwt.ConnectionTypeLeafnodeWS,
jwt.ConnectionTypeMqtt, jwt.ConnectionTypeMqttWS:
jwt.ConnectionTypeMqtt, jwt.ConnectionTypeMqttWS,
jwt.ConnectionTypeInProcess:
default:
return fmt.Errorf("unknown connection type %q", ct)
}

View File

@ -1,4 +1,4 @@
// Copyright 2022-2023 The NATS Authors
// Copyright 2022-2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

View File

@ -1,4 +1,4 @@
// Copyright 2023 The NATS Authors
// Copyright 2023-2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

View File

@ -1,4 +1,4 @@
// Copyright 2022-2023 The NATS Authors
// Copyright 2022-2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@ -46,11 +46,13 @@ type MatchByType int
const (
matchByIssuer MatchByType = iota + 1
matchBySubject
matchByThumbprint
)
var MatchByMap = map[string]MatchByType{
"issuer": matchByIssuer,
"subject": matchBySubject,
"issuer": matchByIssuer,
"subject": matchBySubject,
"thumbprint": matchByThumbprint,
}
var Usage = `

View File

@ -26,8 +26,7 @@ var _ = MATCHBYEMPTY
// otherKey implements crypto.Signer and crypto.Decrypter to satisfy linter on platforms that don't implement certstore
type otherKey struct{}
func TLSConfig(certStore StoreType, certMatchBy MatchByType, certMatch string, config *tls.Config) error {
_, _, _, _ = certStore, certMatchBy, certMatch, config
func TLSConfig(_ StoreType, _ MatchByType, _ string, _ []string, _ bool, _ *tls.Config) error {
return ErrOSNotCompatCertStore
}

View File

@ -1,4 +1,4 @@
// Copyright 2022-2023 The NATS Authors
// Copyright 2022-2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@ -41,26 +41,26 @@ import (
const (
// wincrypt.h constants
winAcquireCached = 0x1 // CRYPT_ACQUIRE_CACHE_FLAG
winAcquireSilent = 0x40 // CRYPT_ACQUIRE_SILENT_FLAG
winAcquireOnlyNCryptKey = 0x40000 // CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG
winEncodingX509ASN = 1 // X509_ASN_ENCODING
winEncodingPKCS7 = 65536 // PKCS_7_ASN_ENCODING
winCertStoreProvSystem = 10 // CERT_STORE_PROV_SYSTEM
winCertStoreCurrentUser = uint32(winCertStoreCurrentUserID << winCompareShift) // CERT_SYSTEM_STORE_CURRENT_USER
winCertStoreLocalMachine = uint32(winCertStoreLocalMachineID << winCompareShift) // CERT_SYSTEM_STORE_LOCAL_MACHINE
winCertStoreCurrentUserID = 1 // CERT_SYSTEM_STORE_CURRENT_USER_ID
winCertStoreLocalMachineID = 2 // CERT_SYSTEM_STORE_LOCAL_MACHINE_ID
winInfoIssuerFlag = 4 // CERT_INFO_ISSUER_FLAG
winInfoSubjectFlag = 7 // CERT_INFO_SUBJECT_FLAG
winCompareNameStrW = 8 // CERT_COMPARE_NAME_STR_A
winCompareShift = 16 // CERT_COMPARE_SHIFT
winAcquireCached = windows.CRYPT_ACQUIRE_CACHE_FLAG
winAcquireSilent = windows.CRYPT_ACQUIRE_SILENT_FLAG
winAcquireOnlyNCryptKey = windows.CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG
winEncodingX509ASN = windows.X509_ASN_ENCODING
winEncodingPKCS7 = windows.PKCS_7_ASN_ENCODING
winCertStoreProvSystem = windows.CERT_STORE_PROV_SYSTEM
winCertStoreCurrentUser = windows.CERT_SYSTEM_STORE_CURRENT_USER
winCertStoreLocalMachine = windows.CERT_SYSTEM_STORE_LOCAL_MACHINE
winCertStoreReadOnly = windows.CERT_STORE_READONLY_FLAG
winInfoIssuerFlag = windows.CERT_INFO_ISSUER_FLAG
winInfoSubjectFlag = windows.CERT_INFO_SUBJECT_FLAG
winCompareNameStrW = windows.CERT_COMPARE_NAME_STR_W
winCompareShift = windows.CERT_COMPARE_SHIFT
// Reference https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-certfindcertificateinstore
winFindIssuerStr = winCompareNameStrW<<winCompareShift | winInfoIssuerFlag // CERT_FIND_ISSUER_STR_W
winFindSubjectStr = winCompareNameStrW<<winCompareShift | winInfoSubjectFlag // CERT_FIND_SUBJECT_STR_W
winFindIssuerStr = windows.CERT_FIND_ISSUER_STR_W
winFindSubjectStr = windows.CERT_FIND_SUBJECT_STR_W
winFindHashStr = windows.CERT_FIND_HASH_STR
winNcryptKeySpec = 0xFFFFFFFF // CERT_NCRYPT_KEY_SPEC
winNcryptKeySpec = windows.CERT_NCRYPT_KEY_SPEC
winBCryptPadPKCS1 uintptr = 0x2
winBCryptPadPSS uintptr = 0x8 // Modern TLS 1.2+
@ -76,7 +76,7 @@ const (
winECK3Magic = 0x334B4345 // "ECK3" BCRYPT_ECDH_PUBLIC_P384_MAGIC
winECK5Magic = 0x354B4345 // "ECK5" BCRYPT_ECDH_PUBLIC_P521_MAGIC
winCryptENotFound = 0x80092004 // CRYPT_E_NOT_FOUND
winCryptENotFound = windows.CRYPT_E_NOT_FOUND
providerMSSoftware = "Microsoft Software Key Storage Provider"
)
@ -111,14 +111,24 @@ var (
crypto.SHA512: winWide("SHA512"), // BCRYPT_SHA512_ALGORITHM
}
// MY is well-known system store on Windows that holds personal certificates
winMyStore = winWide("MY")
// MY is well-known system store on Windows that holds personal certificates. Read
// More about the CA locations here:
// https://learn.microsoft.com/en-us/dotnet/framework/configure-apps/file-schema/wcf/certificate-of-clientcertificate-element?redirectedfrom=MSDN
// https://superuser.com/questions/217719/what-are-the-windows-system-certificate-stores
// https://docs.microsoft.com/en-us/windows/win32/seccrypto/certificate-stores
// https://learn.microsoft.com/en-us/windows/win32/seccrypto/system-store-locations
// https://stackoverflow.com/questions/63286085/which-x509-storename-refers-to-the-certificates-stored-beneath-trusted-root-cert#:~:text=4-,StoreName.,is%20%22Intermediate%20Certification%20Authorities%22.
winMyStore = winWide("MY")
winIntermediateCAStore = winWide("CA")
winRootStore = winWide("Root")
winAuthRootStore = winWide("AuthRoot")
// These DLLs must be available on all Windows hosts
winCrypt32 = windows.NewLazySystemDLL("crypt32.dll")
winNCrypt = windows.NewLazySystemDLL("ncrypt.dll")
winCertFindCertificateInStore = winCrypt32.NewProc("CertFindCertificateInStore")
winCertVerifyTimeValidity = winCrypt32.NewProc("CertVerifyTimeValidity")
winCryptAcquireCertificatePrivateKey = winCrypt32.NewProc("CryptAcquireCertificatePrivateKey")
winNCryptExportKey = winNCrypt.NewProc("NCryptExportKey")
winNCryptOpenStorageProvider = winNCrypt.NewProc("NCryptOpenStorageProvider")
@ -156,9 +166,40 @@ type winPSSPaddingInfo struct {
cbSalt uint32
}
// TLSConfig fulfills the same function as reading cert and key pair from pem files but
// sources the Windows certificate store instead
func TLSConfig(certStore StoreType, certMatchBy MatchByType, certMatch string, config *tls.Config) error {
// createCACertsPool generates a CertPool from the Windows certificate store,
// adding all matching certificates from the caCertsMatch array to the pool.
// All matching certificates (vs first) are added to the pool based on a user
// request. If no certificates are found an error is returned.
func createCACertsPool(cs *winCertStore, storeType uint32, caCertsMatch []string, skipInvalid bool) (*x509.CertPool, error) {
var errs []error
caPool := x509.NewCertPool()
for _, s := range caCertsMatch {
lfs, err := cs.caCertsBySubjectMatch(s, storeType, skipInvalid)
if err != nil {
errs = append(errs, err)
} else {
for _, lf := range lfs {
caPool.AddCert(lf)
}
}
}
// If every lookup failed return the errors.
if len(errs) == len(caCertsMatch) {
return nil, fmt.Errorf("unable to match any CA certificate: %v", errs)
}
return caPool, nil
}
// TLSConfig fulfills the same function as reading cert and key pair from
// pem files but sources the Windows certificate store instead. The
// certMatchBy and certMatch fields search the "MY" certificate location
// for the first certificate that matches the certMatch field. The
// caCertsMatch field is used to search the Trusted Root, Third Party Root,
// and Intermediate Certificate Authority locations for certificates with
// Subjects matching the provided strings. If a match is found, the
// certificate is added to the pool that is used to verify the certificate
// chain.
func TLSConfig(certStore StoreType, certMatchBy MatchByType, certMatch string, caCertsMatch []string, skipInvalid bool, config *tls.Config) error {
var (
leaf *x509.Certificate
leafCtx *windows.CertContext
@ -185,9 +226,11 @@ func TLSConfig(certStore StoreType, certMatchBy MatchByType, certMatch string, c
// certByIssuer or certBySubject
if certMatchBy == matchBySubject || certMatchBy == MATCHBYEMPTY {
leaf, leafCtx, err = cs.certBySubject(certMatch, scope)
leaf, leafCtx, err = cs.certBySubject(certMatch, scope, skipInvalid)
} else if certMatchBy == matchByIssuer {
leaf, leafCtx, err = cs.certByIssuer(certMatch, scope)
leaf, leafCtx, err = cs.certByIssuer(certMatch, scope, skipInvalid)
} else if certMatchBy == matchByThumbprint {
leaf, leafCtx, err = cs.certByThumbprint(certMatch, scope, skipInvalid)
} else {
return ErrBadMatchByType
}
@ -205,6 +248,14 @@ func TLSConfig(certStore StoreType, certMatchBy MatchByType, certMatch string, c
if pk == nil {
return ErrNoPrivateKeyStoreRef
}
// Look for CA Certificates
if len(caCertsMatch) != 0 {
caPool, err := createCACertsPool(cs, scope, caCertsMatch, skipInvalid)
if err != nil {
return err
}
config.ClientCAs = caPool
}
} else {
return ErrBadCertStore
}
@ -278,7 +329,7 @@ func winFindCert(store windows.Handle, enc, findFlags, findType uint32, para *ui
)
if h == 0 {
// Actual error, or simply not found?
if errno, ok := err.(syscall.Errno); ok && errno == winCryptENotFound {
if errno, ok := err.(syscall.Errno); ok && errno == syscall.Errno(winCryptENotFound) {
return nil, ErrFailedCertSearch
}
return nil, ErrFailedCertSearch
@ -287,6 +338,16 @@ func winFindCert(store windows.Handle, enc, findFlags, findType uint32, para *ui
return (*windows.CertContext)(unsafe.Pointer(h)), nil
}
// winVerifyCertValid wraps the CertVerifyTimeValidity and simply returns true if the certificate is valid
func winVerifyCertValid(timeToVerify *windows.Filetime, certInfo *windows.CertInfo) bool {
// this function does not document returning errors / setting lasterror
r, _, _ := winCertVerifyTimeValidity.Call(
uintptr(unsafe.Pointer(timeToVerify)),
uintptr(unsafe.Pointer(certInfo)),
)
return r == 0
}
// winCertStore is a store implementation for the Windows Certificate Store
type winCertStore struct {
Prov uintptr
@ -326,21 +387,70 @@ func winCertContextToX509(ctx *windows.CertContext) (*x509.Certificate, error) {
// CertContext pointer returned allows subsequent key operations like Sign. Caller specifies
// current user's personal certs or local machine's personal certs using storeType.
// See CERT_FIND_ISSUER_STR description at https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-certfindcertificateinstore
func (w *winCertStore) certByIssuer(issuer string, storeType uint32) (*x509.Certificate, *windows.CertContext, error) {
return w.certSearch(winFindIssuerStr, issuer, winMyStore, storeType)
func (w *winCertStore) certByIssuer(issuer string, storeType uint32, skipInvalid bool) (*x509.Certificate, *windows.CertContext, error) {
return w.certSearch(winFindIssuerStr, issuer, winMyStore, storeType, skipInvalid)
}
// certBySubject matches and returns the first certificate found by passed subject field.
// CertContext pointer returned allows subsequent key operations like Sign. Caller specifies
// current user's personal certs or local machine's personal certs using storeType.
// See CERT_FIND_SUBJECT_STR description at https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-certfindcertificateinstore
func (w *winCertStore) certBySubject(subject string, storeType uint32) (*x509.Certificate, *windows.CertContext, error) {
return w.certSearch(winFindSubjectStr, subject, winMyStore, storeType)
func (w *winCertStore) certBySubject(subject string, storeType uint32, skipInvalid bool) (*x509.Certificate, *windows.CertContext, error) {
return w.certSearch(winFindSubjectStr, subject, winMyStore, storeType, skipInvalid)
}
// certByThumbprint matches and returns the first certificate found by passed SHA1 thumbprint.
// CertContext pointer returned allows subsequent key operations like Sign. Caller specifies
// current user's personal certs or local machine's personal certs using storeType.
// See CERT_FIND_SUBJECT_STR description at https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-certfindcertificateinstore
func (w *winCertStore) certByThumbprint(hash string, storeType uint32, skipInvalid bool) (*x509.Certificate, *windows.CertContext, error) {
return w.certSearch(winFindHashStr, hash, winMyStore, storeType, skipInvalid)
}
// caCertsBySubjectMatch matches and returns all matching certificates of the subject field.
//
// The following locations are searched:
// 1) Root (Trusted Root Certification Authorities)
// 2) AuthRoot (Third-Party Root Certification Authorities)
// 3) CA (Intermediate Certification Authorities)
//
// Caller specifies current user's personal certs or local machine's personal certs using storeType.
// See CERT_FIND_SUBJECT_STR description at https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-certfindcertificateinstore
func (w *winCertStore) caCertsBySubjectMatch(subject string, storeType uint32, skipInvalid bool) ([]*x509.Certificate, error) {
var (
leaf *x509.Certificate
searchLocations = [3]*uint16{winRootStore, winAuthRootStore, winIntermediateCAStore}
rv []*x509.Certificate
)
// surprisingly, an empty string returns a result. We'll treat this as an error.
if subject == "" {
return nil, ErrBadCaCertMatchField
}
for _, sr := range searchLocations {
var err error
if leaf, _, err = w.certSearch(winFindSubjectStr, subject, sr, storeType, skipInvalid); err == nil {
rv = append(rv, leaf)
} else {
// Ignore the failed search from a single location. Errors we catch include
// ErrFailedX509Extract (resulting from a malformed certificate) and errors
// around invalid attributes, unsupported algorithms, etc. These are corner
// cases as certificates with these errors shouldn't have been allowed
// to be added to the store in the first place.
if err != ErrFailedCertSearch {
return nil, err
}
}
}
// Not found anywhere
if len(rv) == 0 {
return nil, ErrFailedCertSearch
}
return rv, nil
}
// certSearch is a helper function to lookup certificates based on search type and match value.
// store is used to specify which store to perform the lookup in (system or user).
func (w *winCertStore) certSearch(searchType uint32, matchValue string, searchRoot *uint16, store uint32) (*x509.Certificate, *windows.CertContext, error) {
func (w *winCertStore) certSearch(searchType uint32, matchValue string, searchRoot *uint16, store uint32, skipInvalid bool) (*x509.Certificate, *windows.CertContext, error) {
// store handle to "MY" store
h, err := w.storeHandle(store, searchRoot)
if err != nil {
@ -357,23 +467,32 @@ func (w *winCertStore) certSearch(searchType uint32, matchValue string, searchRo
// pass 0 as the third parameter because it is not used
// https://msdn.microsoft.com/en-us/library/windows/desktop/aa376064(v=vs.85).aspx
nc, err := winFindCert(h, winEncodingX509ASN|winEncodingPKCS7, 0, searchType, i, prev)
if err != nil {
return nil, nil, err
}
if nc != nil {
// certificate found
prev = nc
// Extract the DER-encoded certificate from the cert context
xc, err := winCertContextToX509(nc)
if err == nil {
cert = xc
} else {
return nil, nil, ErrFailedX509Extract
for {
nc, err := winFindCert(h, winEncodingX509ASN|winEncodingPKCS7, 0, searchType, i, prev)
if err != nil {
return nil, nil, err
}
if nc != nil {
// certificate found
prev = nc
var now *windows.Filetime
if skipInvalid && !winVerifyCertValid(now, nc.CertInfo) {
continue
}
// Extract the DER-encoded certificate from the cert context
xc, err := winCertContextToX509(nc)
if err == nil {
cert = xc
break
} else {
return nil, nil, ErrFailedX509Extract
}
} else {
return nil, nil, ErrFailedCertSearch
}
} else {
return nil, nil, ErrFailedCertSearch
}
if cert == nil {
@ -396,7 +515,7 @@ func winNewStoreHandle(provider uint32, store *uint16) (*winStoreHandle, error)
winCertStoreProvSystem,
0,
0,
provider,
provider|winCertStoreReadOnly,
uintptr(unsafe.Pointer(store)))
if err != nil {
return nil, ErrBadCryptoStoreProvider

View File

@ -68,6 +68,12 @@ var (
// ErrBadCertMatchField represents malformed cert_match option
ErrBadCertMatchField = errors.New("expected 'cert_match' to be a valid non-empty string")
// ErrBadCaCertMatchField represents malformed cert_match option
ErrBadCaCertMatchField = errors.New("expected 'ca_certs_match' to be a valid non-empty string array")
// ErrBadCertMatchSkipInvalidField represents malformed cert_match_skip_invalid option
ErrBadCertMatchSkipInvalidField = errors.New("expected 'cert_match_skip_invalid' to be a boolean")
// ErrOSNotCompatCertStore represents cert_store passed that exists but is not valid on current OS
ErrOSNotCompatCertStore = errors.New("cert_store not compatible with current operating system")
)

View File

@ -1,4 +1,4 @@
// Copyright 2016-2018 The NATS Authors
// Copyright 2016-2020 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,4 @@
// Copyright 2012-2023 The NATS Authors
// Copyright 2012-2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@ -14,6 +14,7 @@
package server
import (
"regexp"
"runtime/debug"
"time"
)
@ -38,6 +39,8 @@ var (
gitCommit, serverVersion string
// trustedKeys is a whitespace separated array of trusted operator's public nkeys.
trustedKeys string
// SemVer regexp to validate the VERSION.
semVerRe = regexp.MustCompile(`^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$`)
)
func init() {
@ -55,7 +58,7 @@ func init() {
const (
// VERSION is the current version for the server.
VERSION = "2.10.22"
VERSION = "2.11.1"
// PROTO is the currently supported protocol.
// 0 was the original
@ -171,6 +174,9 @@ const (
// MAX_HPUB_ARGS Maximum possible number of arguments from HPUB proto.
MAX_HPUB_ARGS = 4
// MAX_RSUB_ARGS Maximum possible number of arguments from a RS+/LS+ proto.
MAX_RSUB_ARGS = 6
// DEFAULT_MAX_CLOSED_CLIENTS is the maximum number of closed connections we hold onto.
DEFAULT_MAX_CLOSED_CLIENTS = 10000

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,4 @@
// Copyright 2012-2021 The NATS Authors
// Copyright 2012-2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

View File

@ -12,7 +12,6 @@
// limitations under the License.
//go:build !windows && !openbsd && !netbsd && !wasm
// +build !windows,!openbsd,!netbsd,!wasm
package server

View File

@ -12,7 +12,6 @@
// limitations under the License.
//go:build netbsd
// +build netbsd
package server

View File

@ -12,7 +12,6 @@
// limitations under the License.
//go:build openbsd
// +build openbsd
package server

View File

@ -1,4 +1,4 @@
// Copyright 2022 The NATS Authors
// Copyright 2022-2021 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@ -12,7 +12,6 @@
// limitations under the License.
//go:build wasm
// +build wasm
package server

View File

@ -1,4 +1,4 @@
// Copyright 2020 The NATS Authors
// Copyright 2020-2021 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@ -12,7 +12,6 @@
// limitations under the License.
//go:build windows
// +build windows
package server

View File

@ -1,4 +1,4 @@
// Copyright 2012-2021 The NATS Authors
// Copyright 2012-2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@ -153,6 +153,9 @@ var (
// Gateway's name.
ErrWrongGateway = errors.New("wrong gateway")
// ErrGatewayNameHasSpaces signals that the gateway name contains spaces, which is not allowed.
ErrGatewayNameHasSpaces = errors.New("gateway name cannot contain spaces")
// ErrNoSysAccount is returned when an attempt to publish or subscribe is made
// when there is no internal system account defined.
ErrNoSysAccount = errors.New("system account not setup")
@ -163,6 +166,9 @@ var (
// ErrServerNotRunning is used to signal an error that a server is not running.
ErrServerNotRunning = errors.New("server is not running")
// ErrServerNameHasSpaces signals that the server name contains spaces, which is not allowed.
ErrServerNameHasSpaces = errors.New("server name cannot contain spaces")
// ErrBadMsgHeader signals the parser detected a bad message header
ErrBadMsgHeader = errors.New("bad message header detected")
@ -181,7 +187,7 @@ var (
ErrClusterNameRemoteConflict = errors.New("cluster name from remote server conflicts")
// ErrClusterNameHasSpaces signals that the cluster name contains spaces, which is not allowed.
ErrClusterNameHasSpaces = errors.New("cluster name cannot contain spaces or new lines")
ErrClusterNameHasSpaces = errors.New("cluster name cannot contain spaces")
// ErrMalformedSubject is returned when a subscription is made with a subject that does not conform to subject rules.
ErrMalformedSubject = errors.New("malformed subject")
@ -206,7 +212,7 @@ var (
ErrInvalidMappingDestination = errors.New("invalid mapping destination")
// ErrInvalidMappingDestinationSubject is used to error on a bad transform destination mapping
ErrInvalidMappingDestinationSubject = fmt.Errorf("%w: invalid subject", ErrInvalidMappingDestination)
ErrInvalidMappingDestinationSubject = fmt.Errorf("%w: invalid transform", ErrInvalidMappingDestination)
// ErrMappingDestinationNotUsingAllWildcards is used to error on a transform destination not using all of the token wildcards
ErrMappingDestinationNotUsingAllWildcards = fmt.Errorf("%w: not using all of the token wildcard(s)", ErrInvalidMappingDestination)

View File

@ -203,7 +203,7 @@
"constant": "JSInvalidJSONErr",
"code": 400,
"error_code": 10025,
"description": "invalid JSON",
"description": "invalid JSON: {err}",
"comment": "",
"help": "",
"url": "",
@ -833,7 +833,7 @@
"constant": "JSConsumerPullRequiresAckErr",
"code": 400,
"error_code": 10084,
"description": "consumer in pull mode requires ack policy",
"description": "consumer in pull mode requires explicit ack policy on workqueue stream",
"comment": "",
"help": "",
"url": "",
@ -1433,7 +1433,7 @@
"constant": "JSSourceInvalidSubjectFilter",
"code": 400,
"error_code": 10145,
"description": "source subject filter is invalid",
"description": "source transform source: {err}",
"comment": "",
"help": "",
"url": "",
@ -1443,7 +1443,7 @@
"constant": "JSSourceInvalidTransformDestination",
"code": 400,
"error_code": 10146,
"description": "source transform destination is invalid",
"description": "source transform: {err}",
"comment": "",
"help": "",
"url": "",
@ -1493,7 +1493,7 @@
"constant": "JSMirrorInvalidSubjectFilter",
"code": 400,
"error_code": 10151,
"description": "mirror subject filter is invalid",
"description": "mirror transform source: {err}",
"comment": "",
"help": "",
"url": "",
@ -1518,5 +1518,145 @@
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSMirrorInvalidTransformDestination",
"code": 400,
"error_code": 10154,
"description": "mirror transform: {err}",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSStreamTransformInvalidSource",
"code": 400,
"error_code": 10155,
"description": "stream transform source: {err}",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSStreamTransformInvalidDestination",
"code": 400,
"error_code": 10156,
"description": "stream transform: {err}",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSPedanticErrF",
"code": 400,
"error_code": 10157,
"description": "pedantic mode: {err}",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSStreamDuplicateMessageConflict",
"code": 409,
"error_code": 10158,
"description": "duplicate message id is in process",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSConsumerPriorityPolicyWithoutGroup",
"code": 400,
"error_code": 10159,
"description": "Setting PriorityPolicy requires at least one PriorityGroup to be set",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSConsumerInvalidPriorityGroupErr",
"code": 400,
"error_code": 10160,
"description": "Provided priority group does not exist for this consumer",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSConsumerEmptyGroupName",
"code": 400,
"error_code": 10161,
"description": "Group name cannot be an empty string",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSConsumerInvalidGroupNameErr",
"code": 400,
"error_code": 10162,
"description": "Valid priority group name must match A-Z, a-z, 0-9, -_/=)+ and may not exceed 16 characters",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSStreamExpectedLastSeqPerSubjectNotReady",
"code": 503,
"error_code": 10163,
"description": "expected last sequence per subject temporarily unavailable",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSStreamWrongLastSequenceConstantErr",
"code": 400,
"error_code": 10164,
"description": "wrong last sequence",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSMessageTTLInvalidErr",
"code": 400,
"error_code": 10165,
"description": "invalid per-message TTL",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSMessageTTLDisabledErr",
"code": 400,
"error_code": 10166,
"description": "per-message TTL is disabled",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSStreamTooManyRequests",
"code": 429,
"error_code": 10167,
"description": "too many requests",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
}
]

View File

@ -1,4 +1,4 @@
// Copyright 2018-2023 The NATS Authors
// Copyright 2018-2025 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@ -31,7 +31,6 @@ import (
"time"
"github.com/klauspost/compress/s2"
"github.com/nats-io/jwt/v2"
"github.com/nats-io/nats-server/v2/server/certidp"
"github.com/nats-io/nats-server/v2/server/pse"
@ -215,6 +214,7 @@ type AccountNumConns struct {
// AccountStat contains the data common between AccountNumConns and AccountStatz
type AccountStat struct {
Account string `json:"acc"`
Name string `json:"name"`
Conns int `json:"conns"`
LeafNodes int `json:"leafnodes"`
TotalConns int `json:"total_conns"`
@ -264,6 +264,7 @@ type ServerInfo struct {
const (
JetStreamEnabled ServerCapability = 1 << iota // Server had JetStream enabled.
BinaryStreamSnapshot // New stream snapshot capability.
AccountNRG // Move NRG traffic out of system account.
)
// Set JetStream capability.
@ -289,6 +290,17 @@ func (si *ServerInfo) BinaryStreamSnapshot() bool {
return si.Flags&BinaryStreamSnapshot != 0
}
// Set account NRG capability.
func (si *ServerInfo) SetAccountNRG() {
si.Flags |= AccountNRG
}
// AccountNRG indicates whether or not we support moving the NRG traffic out of the
// system account and into the asset account.
func (si *ServerInfo) AccountNRG() bool {
return si.Flags&AccountNRG != 0
}
// ClientInfo is detailed information about the client forming a connection.
type ClientInfo struct {
Start *time.Time `json:"start,omitempty"`
@ -315,23 +327,55 @@ type ClientInfo struct {
Nonce string `json:"nonce,omitempty"`
}
// forAssignmentSnap returns the minimum amount of ClientInfo we need for assignment snapshots.
func (ci *ClientInfo) forAssignmentSnap() *ClientInfo {
return &ClientInfo{
Account: ci.Account,
Service: ci.Service,
Cluster: ci.Cluster,
}
}
// forProposal returns the minimum amount of ClientInfo we need for assignment proposals.
func (ci *ClientInfo) forProposal() *ClientInfo {
if ci == nil {
return nil
}
cci := *ci
cci.Jwt = _EMPTY_
cci.IssuerKey = _EMPTY_
return &cci
}
// forAdvisory returns the minimum amount of ClientInfo we need for JS advisory events.
func (ci *ClientInfo) forAdvisory() *ClientInfo {
if ci == nil {
return nil
}
cci := *ci
cci.Jwt = _EMPTY_
cci.Alternates = nil
return &cci
}
// ServerStats hold various statistics that we will periodically send out.
type ServerStats struct {
Start time.Time `json:"start"`
Mem int64 `json:"mem"`
Cores int `json:"cores"`
CPU float64 `json:"cpu"`
Connections int `json:"connections"`
TotalConnections uint64 `json:"total_connections"`
ActiveAccounts int `json:"active_accounts"`
NumSubs uint32 `json:"subscriptions"`
Sent DataStats `json:"sent"`
Received DataStats `json:"received"`
SlowConsumers int64 `json:"slow_consumers"`
Routes []*RouteStat `json:"routes,omitempty"`
Gateways []*GatewayStat `json:"gateways,omitempty"`
ActiveServers int `json:"active_servers,omitempty"`
JetStream *JetStreamVarz `json:"jetstream,omitempty"`
Start time.Time `json:"start"`
Mem int64 `json:"mem"`
Cores int `json:"cores"`
CPU float64 `json:"cpu"`
Connections int `json:"connections"`
TotalConnections uint64 `json:"total_connections"`
ActiveAccounts int `json:"active_accounts"`
NumSubs uint32 `json:"subscriptions"`
Sent DataStats `json:"sent"`
Received DataStats `json:"received"`
SlowConsumers int64 `json:"slow_consumers"`
SlowConsumersStats *SlowConsumersStats `json:"slow_consumer_stats,omitempty"`
Routes []*RouteStat `json:"routes,omitempty"`
Gateways []*GatewayStat `json:"gateways,omitempty"`
ActiveServers int `json:"active_servers,omitempty"`
JetStream *JetStreamVarz `json:"jetstream,omitempty"`
}
// RouteStat holds route statistics.
@ -475,10 +519,14 @@ RESET:
si.Version = VERSION
si.Time = time.Now().UTC()
si.Tags = tags
si.Flags = 0
if js {
// New capability based flags.
si.SetJetStreamEnabled()
si.SetBinaryStreamSnapshot()
if s.accountNRGAllowed.Load() {
si.SetAccountNRG()
}
}
}
var b []byte
@ -653,7 +701,7 @@ func (s *Server) sendInternalAccountMsgWithReply(a *Account, subject, reply stri
}
// Send system style message to an account scope.
func (s *Server) sendInternalAccountSysMsg(a *Account, subj string, si *ServerInfo, msg interface{}) {
func (s *Server) sendInternalAccountSysMsg(a *Account, subj string, si *ServerInfo, msg any, ct compressionType) {
s.mu.RLock()
if s.sys == nil || s.sys.sendq == nil || a == nil {
s.mu.RUnlock()
@ -666,7 +714,7 @@ func (s *Server) sendInternalAccountSysMsg(a *Account, subj string, si *ServerIn
c := a.internalClient()
a.mu.Unlock()
sendq.push(newPubMsg(c, subj, _EMPTY_, si, nil, msg, noCompression, false, false))
sendq.push(newPubMsg(c, subj, _EMPTY_, si, nil, msg, ct, false, false))
}
// This will queue up a message to be sent.
@ -864,6 +912,16 @@ func (s *Server) sendStatsz(subj string) {
m.Stats.Sent.Msgs = atomic.LoadInt64(&s.outMsgs)
m.Stats.Sent.Bytes = atomic.LoadInt64(&s.outBytes)
m.Stats.SlowConsumers = atomic.LoadInt64(&s.slowConsumers)
// Evaluate the slow consumer stats, but set it only if one of the value is not 0.
scs := &SlowConsumersStats{
Clients: s.NumSlowConsumersClients(),
Routes: s.NumSlowConsumersRoutes(),
Gateways: s.NumSlowConsumersGateways(),
Leafs: s.NumSlowConsumersLeafs(),
}
if scs.Clients != 0 || scs.Routes != 0 || scs.Gateways != 0 || scs.Leafs != 0 {
m.Stats.SlowConsumersStats = scs
}
m.Stats.NumSubs = s.numSubscriptions()
// Routes
s.forEachRoute(func(r *client) {
@ -949,6 +1007,7 @@ func (s *Server) sendStatsz(subj string) {
jStat.Meta.Pending = ipq.len()
}
}
jStat.Limits = &s.getOpts().JetStreamLimits
m.Stats.JetStream = jStat
s.mu.RLock()
}
@ -1184,6 +1243,14 @@ func (s *Server) initEventTracking() {
optz := &ExpvarzEventOptions{}
s.zReq(c, reply, hdr, msg, &optz.EventFilterOptions, optz, func() (any, error) { return s.expvarz(optz), nil })
},
"IPQUEUESZ": func(sub *subscription, c *client, _ *Account, subject, reply string, hdr, msg []byte) {
optz := &IpqueueszEventOptions{}
s.zReq(c, reply, hdr, msg, &optz.EventFilterOptions, optz, func() (any, error) { return s.Ipqueuesz(&optz.IpqueueszOptions), nil })
},
"RAFTZ": func(sub *subscription, c *client, _ *Account, subject, reply string, hdr, msg []byte) {
optz := &RaftzEventOptions{}
s.zReq(c, reply, hdr, msg, &optz.EventFilterOptions, optz, func() (any, error) { return s.Raftz(&optz.RaftzOptions), nil })
},
}
profilez := func(_ *subscription, c *client, _ *Account, _, rply string, rmsg []byte) {
hdr, msg := c.msgParts(rmsg)
@ -1618,7 +1685,8 @@ func (s *Server) remoteServerUpdate(sub *subscription, c *client, _ *Account, su
}
node := getHash(si.Name)
s.nodeToInfo.Store(node, nodeInfo{
accountNRG := si.AccountNRG()
oldInfo, _ := s.nodeToInfo.Swap(node, nodeInfo{
si.Name,
si.Version,
si.Cluster,
@ -1630,7 +1698,14 @@ func (s *Server) remoteServerUpdate(sub *subscription, c *client, _ *Account, su
false,
si.JetStreamEnabled(),
si.BinaryStreamSnapshot(),
accountNRG,
})
if oldInfo == nil || accountNRG != oldInfo.(nodeInfo).accountNRG {
// One of the servers we received statsz from changed its mind about
// whether or not it supports in-account NRG, so update the groups
// with this information.
s.updateNRGAccountStatus()
}
}
// updateRemoteServer is called when we have an update from a remote server.
@ -1677,14 +1752,35 @@ func (s *Server) processNewServer(si *ServerInfo) {
false,
si.JetStreamEnabled(),
si.BinaryStreamSnapshot(),
si.AccountNRG(),
})
}
}
go s.updateNRGAccountStatus()
// Announce ourselves..
// Do this in a separate Go routine.
go s.sendStatszUpdate()
}
// Works out whether all nodes support moving the NRG traffic into
// the account and moves it appropriately.
// Server lock MUST NOT be held on entry.
func (s *Server) updateNRGAccountStatus() {
s.rnMu.RLock()
raftNodes := make([]RaftNode, 0, len(s.raftNodes))
for _, n := range s.raftNodes {
raftNodes = append(raftNodes, n)
}
s.rnMu.RUnlock()
for _, n := range raftNodes {
// In the event that the node is happy that all nodes that
// it cares about haven't changed, this will be a no-op.
if err := n.RecreateInternalSubs(); err != nil {
n.Stop()
}
}
}
// If GW is enabled on this server and there are any leaf node connections,
// this function will send a LeafNode connect system event to the super cluster
// to ensure that the GWs are in interest-only mode for this account.
@ -1890,6 +1986,18 @@ type ExpvarzEventOptions struct {
EventFilterOptions
}
// In the context of system events, IpqueueszEventOptions are options passed to Ipqueuesz
type IpqueueszEventOptions struct {
EventFilterOptions
IpqueueszOptions
}
// In the context of system events, RaftzEventOptions are options passed to Raftz
type RaftzEventOptions struct {
EventFilterOptions
RaftzOptions
}
// returns true if the request does NOT apply to this server and can be ignored.
// DO NOT hold the server lock when
func (s *Server) filterRequest(fOpts *EventFilterOptions) bool {
@ -1938,7 +2046,9 @@ type ServerAPIResponse struct {
compress compressionType
}
// Specialized response types for unmarshalling.
// Specialized response types for unmarshalling. These structures are not
// used in the server code and only there for users of the Z endpoints to
// unmarshal the data without having to create these structs in their code
// ServerAPIConnzResponse is the response type connz
type ServerAPIConnzResponse struct {
@ -1947,6 +2057,83 @@ type ServerAPIConnzResponse struct {
Error *ApiError `json:"error,omitempty"`
}
// ServerAPIRoutezResponse is the response type for routez
type ServerAPIRoutezResponse struct {
Server *ServerInfo `json:"server"`
Data *Routez `json:"data,omitempty"`
Error *ApiError `json:"error,omitempty"`
}
// ServerAPIGatewayzResponse is the response type for gatewayz
type ServerAPIGatewayzResponse struct {
Server *ServerInfo `json:"server"`
Data *Gatewayz `json:"data,omitempty"`
Error *ApiError `json:"error,omitempty"`
}
// ServerAPIJszResponse is the response type for jsz
type ServerAPIJszResponse struct {
Server *ServerInfo `json:"server"`
Data *JSInfo `json:"data,omitempty"`
Error *ApiError `json:"error,omitempty"`
}
// ServerAPIHealthzResponse is the response type for healthz
type ServerAPIHealthzResponse struct {
Server *ServerInfo `json:"server"`
Data *HealthStatus `json:"data,omitempty"`
Error *ApiError `json:"error,omitempty"`
}
// ServerAPIVarzResponse is the response type for varz
type ServerAPIVarzResponse struct {
Server *ServerInfo `json:"server"`
Data *Varz `json:"data,omitempty"`
Error *ApiError `json:"error,omitempty"`
}
// ServerAPISubszResponse is the response type for subsz
type ServerAPISubszResponse struct {
Server *ServerInfo `json:"server"`
Data *Subsz `json:"data,omitempty"`
Error *ApiError `json:"error,omitempty"`
}
// ServerAPILeafzResponse is the response type for leafz
type ServerAPILeafzResponse struct {
Server *ServerInfo `json:"server"`
Data *Leafz `json:"data,omitempty"`
Error *ApiError `json:"error,omitempty"`
}
// ServerAPIAccountzResponse is the response type for accountz
type ServerAPIAccountzResponse struct {
Server *ServerInfo `json:"server"`
Data *Accountz `json:"data,omitempty"`
Error *ApiError `json:"error,omitempty"`
}
// ServerAPIExpvarzResponse is the response type for expvarz
type ServerAPIExpvarzResponse struct {
Server *ServerInfo `json:"server"`
Data *ExpvarzStatus `json:"data,omitempty"`
Error *ApiError `json:"error,omitempty"`
}
// ServerAPIpqueueszResponse is the response type for ipqueuesz
type ServerAPIpqueueszResponse struct {
Server *ServerInfo `json:"server"`
Data *IpqueueszStatus `json:"data,omitempty"`
Error *ApiError `json:"error,omitempty"`
}
// ServerAPIRaftzResponse is the response type for raftz
type ServerAPIRaftzResponse struct {
Server *ServerInfo `json:"server"`
Data *RaftzStatus `json:"data,omitempty"`
Error *ApiError `json:"error,omitempty"`
}
// statszReq is a request for us to respond with current statsz.
func (s *Server) statszReq(sub *subscription, c *client, _ *Account, subject, reply string, hdr, msg []byte) {
if !s.EventsEnabled() {
@ -2208,6 +2395,7 @@ func (a *Account) statz() *AccountStat {
leafConns := a.numLocalLeafNodes()
return &AccountStat{
Account: a.Name,
Name: a.getNameTagLocked(),
Conns: localConns,
LeafNodes: leafConns,
TotalConns: localConns + leafConns,
@ -2278,7 +2466,7 @@ func (s *Server) accountConnectEvent(c *client) {
Jwt: c.opts.JWT,
IssuerKey: issuerForClient(c),
Tags: c.tags,
NameTag: c.nameTag,
NameTag: c.acc.getNameTag(),
Kind: c.kindString(),
ClientType: c.clientTypeString(),
MQTTClient: c.getMQTTClientID(),
@ -2330,7 +2518,7 @@ func (s *Server) accountDisconnectEvent(c *client, now time.Time, reason string)
Jwt: c.opts.JWT,
IssuerKey: issuerForClient(c),
Tags: c.tags,
NameTag: c.nameTag,
NameTag: c.acc.getNameTag(),
Kind: c.kindString(),
ClientType: c.clientTypeString(),
MQTTClient: c.getMQTTClientID(),
@ -2384,7 +2572,7 @@ func (s *Server) sendAuthErrorEvent(c *client) {
Jwt: c.opts.JWT,
IssuerKey: issuerForClient(c),
Tags: c.tags,
NameTag: c.nameTag,
NameTag: c.acc.getNameTag(),
Kind: c.kindString(),
ClientType: c.clientTypeString(),
MQTTClient: c.getMQTTClientID(),
@ -2442,7 +2630,7 @@ func (s *Server) sendAccountAuthErrorEvent(c *client, acc *Account, reason strin
Jwt: c.opts.JWT,
IssuerKey: issuerForClient(c),
Tags: c.tags,
NameTag: c.nameTag,
NameTag: c.acc.getNameTag(),
Kind: c.kindString(),
ClientType: c.clientTypeString(),
MQTTClient: c.getMQTTClientID(),
@ -2459,7 +2647,7 @@ func (s *Server) sendAccountAuthErrorEvent(c *client, acc *Account, reason strin
}
c.mu.Unlock()
s.sendInternalAccountSysMsg(acc, authErrorAccountEventSubj, &m.Server, &m)
s.sendInternalAccountSysMsg(acc, authErrorAccountEventSubj, &m.Server, &m, noCompression)
}
// Internal message callback.

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,4 @@
// Copyright 2020 The NATS Authors
// Copyright 2020-2022 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@ -12,7 +12,6 @@
// limitations under the License.
//go:build gofuzz
// +build gofuzz
package server

View File

@ -1,4 +1,4 @@
// Copyright 2018-2023 The NATS Authors
// Copyright 2018-2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@ -19,12 +19,14 @@ import (
"crypto/sha256"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"math/rand"
"net"
"net/url"
"slices"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
@ -217,6 +219,8 @@ type gateway struct {
// interest-only mode "immediately", so the outbound should disregard
// the optimistic mode when checking for interest.
interestOnlyMode bool
// Name of the remote server
remoteName string
}
// Outbound subject interest entry.
@ -298,17 +302,20 @@ func (r *RemoteGatewayOpts) clone() *RemoteGatewayOpts {
// Ensure that gateway is properly configured.
func validateGatewayOptions(o *Options) error {
if o.Gateway.Name == "" && o.Gateway.Port == 0 {
if o.Gateway.Name == _EMPTY_ && o.Gateway.Port == 0 {
return nil
}
if o.Gateway.Name == "" {
return fmt.Errorf("gateway has no name")
if o.Gateway.Name == _EMPTY_ {
return errors.New("gateway has no name")
}
if strings.Contains(o.Gateway.Name, " ") {
return ErrGatewayNameHasSpaces
}
if o.Gateway.Port == 0 {
return fmt.Errorf("gateway %q has no port specified (select -1 for random port)", o.Gateway.Name)
}
for i, g := range o.Gateway.Gateways {
if g.Name == "" {
if g.Name == _EMPTY_ {
return fmt.Errorf("gateway in the list %d has no name", i)
}
if len(g.URLs) == 0 {
@ -528,6 +535,7 @@ func (s *Server) startGatewayAcceptLoop() {
Gateway: opts.Gateway.Name,
GatewayNRP: true,
Headers: s.supportsHeaders(),
Proto: s.getServerProto(),
}
// Unless in some tests we want to keep the old behavior, we are now
// (since v2.9.0) indicate that this server will switch all accounts
@ -1035,6 +1043,10 @@ func (c *client) processGatewayInfo(info *Info) {
}
if isFirstINFO {
c.opts.Name = info.ID
// Get the protocol version from the INFO protocol. This will be checked
// to see if this connection supports message tracing for instance.
c.opts.Protocol = info.Proto
c.gw.remoteName = info.Name
}
c.mu.Unlock()
@ -1900,7 +1912,7 @@ func (c *client) processGatewayAccountSub(accName string) error {
// the sublist if present.
// <Invoked from outbound connection's readLoop>
func (c *client) processGatewayRUnsub(arg []byte) error {
accName, subject, queue, err := c.parseUnsubProto(arg)
_, accName, subject, queue, err := c.parseUnsubProto(arg, true, false)
if err != nil {
return fmt.Errorf("processGatewaySubjectUnsub %s", err.Error())
}
@ -2400,7 +2412,7 @@ func (s *Server) gatewayUpdateSubInterest(accName string, sub *subscription, cha
if change < 0 {
return
}
entry = &sitally{n: 1, q: sub.queue != nil}
entry = &sitally{n: change, q: sub.queue != nil}
st[string(key)] = entry
first = true
} else {
@ -2499,8 +2511,13 @@ var subPool = &sync.Pool{
// that the message is not sent to a given gateway if for instance
// it is known that this gateway has no interest in the account or
// subject, etc..
// When invoked from a LEAF connection, `checkLeafQF` should be passed as `true`
// so that we skip any queue subscription interest that is not part of the
// `c.pa.queues` filter (similar to what we do in `processMsgResults`). However,
// when processing service imports, then this boolean should be passes as `false`,
// regardless if it is a LEAF connection or not.
// <Invoked from any client connection's readLoop>
func (c *client) sendMsgToGateways(acc *Account, msg, subject, reply []byte, qgroups [][]byte) bool {
func (c *client) sendMsgToGateways(acc *Account, msg, subject, reply []byte, qgroups [][]byte, checkLeafQF bool) bool {
// We had some times when we were sending across a GW with no subject, and the other side would break
// due to parser error. These need to be fixed upstream but also double check here.
if len(subject) == 0 {
@ -2523,6 +2540,14 @@ func (c *client) sendMsgToGateways(acc *Account, msg, subject, reply []byte, qgr
if len(gws) == 0 {
return false
}
mt, _ := c.isMsgTraceEnabled()
if mt != nil {
pa := c.pa
msg = mt.setOriginAccountHeaderIfNeeded(c, acc, msg)
defer func() { c.pa = pa }()
}
var (
queuesa = [512]byte{}
queues = queuesa[:0]
@ -2577,6 +2602,21 @@ func (c *client) sendMsgToGateways(acc *Account, msg, subject, reply []byte, qgr
qsubs := qr.qsubs[i]
if len(qsubs) > 0 {
queue := qsubs[0].queue
if checkLeafQF {
// Skip any queue that is not in the leaf's queue filter.
skip := true
for _, qn := range c.pa.queues {
if bytes.Equal(queue, qn) {
skip = false
break
}
}
if skip {
continue
}
// Now we still need to check that it was not delivered
// locally by checking the given `qgroups`.
}
add := true
for _, qn := range qgroups {
if bytes.Equal(queue, qn) {
@ -2615,6 +2655,11 @@ func (c *client) sendMsgToGateways(acc *Account, msg, subject, reply []byte, qgr
mreply = append(mreply, reply...)
}
}
if mt != nil {
msg = mt.setHopHeader(c, msg)
}
// Setup the message header.
// Make sure we are an 'R' proto by default
c.msgb[0] = 'R'
@ -2969,7 +3014,7 @@ func (c *client) handleGatewayReply(msg []byte) (processed bool) {
// we now need to send the message with the real subject to
// gateways in case they have interest on that reply subject.
if !isServiceReply {
c.sendMsgToGateways(acc, msg, c.pa.subject, c.pa.reply, queues)
c.sendMsgToGateways(acc, msg, c.pa.subject, c.pa.reply, queues, false)
}
} else if c.kind == GATEWAY {
// Only if we are a gateway connection should we try to route

View File

@ -0,0 +1,532 @@
// Copyright 2025 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package gsl
import (
"errors"
"sync"
"github.com/nats-io/nats-server/v2/server/stree"
)
// Sublist is a routing mechanism to handle subject distribution and
// provides a facility to match subjects from published messages to
// interested subscribers. Subscribers can have wildcard subjects to
// match multiple published subjects.
// Common byte variables for wildcards and token separator.
const (
pwc = '*'
pwcs = "*"
fwc = '>'
fwcs = ">"
tsep = "."
btsep = '.'
_EMPTY_ = ""
)
// Sublist related errors
var (
ErrInvalidSubject = errors.New("gsl: invalid subject")
ErrNotFound = errors.New("gsl: no matches found")
ErrNilChan = errors.New("gsl: nil channel")
ErrAlreadyRegistered = errors.New("gsl: notification already registered")
)
// A GenericSublist stores and efficiently retrieves subscriptions.
type GenericSublist[T comparable] struct {
sync.RWMutex
root *level[T]
count uint32
}
// A node contains subscriptions and a pointer to the next level.
type node[T comparable] struct {
next *level[T]
subs map[T]string // value -> subject
}
// A level represents a group of nodes and special pointers to
// wildcard nodes.
type level[T comparable] struct {
nodes map[string]*node[T]
pwc, fwc *node[T]
}
// Create a new default node.
func newNode[T comparable]() *node[T] {
return &node[T]{subs: make(map[T]string)}
}
// Create a new default level.
func newLevel[T comparable]() *level[T] {
return &level[T]{nodes: make(map[string]*node[T])}
}
// NewSublist will create a default sublist with caching enabled per the flag.
func NewSublist[T comparable]() *GenericSublist[T] {
return &GenericSublist[T]{root: newLevel[T]()}
}
// Insert adds a subscription into the sublist
func (s *GenericSublist[T]) Insert(subject string, value T) error {
tsa := [32]string{}
tokens := tsa[:0]
start := 0
for i := 0; i < len(subject); i++ {
if subject[i] == btsep {
tokens = append(tokens, subject[start:i])
start = i + 1
}
}
tokens = append(tokens, subject[start:])
s.Lock()
var sfwc bool
var n *node[T]
l := s.root
for _, t := range tokens {
lt := len(t)
if lt == 0 || sfwc {
s.Unlock()
return ErrInvalidSubject
}
if lt > 1 {
n = l.nodes[t]
} else {
switch t[0] {
case pwc:
n = l.pwc
case fwc:
n = l.fwc
sfwc = true
default:
n = l.nodes[t]
}
}
if n == nil {
n = newNode[T]()
if lt > 1 {
l.nodes[t] = n
} else {
switch t[0] {
case pwc:
l.pwc = n
case fwc:
l.fwc = n
default:
l.nodes[t] = n
}
}
}
if n.next == nil {
n.next = newLevel[T]()
}
l = n.next
}
n.subs[value] = subject
s.count++
s.Unlock()
return nil
}
// Match will match all entries to the literal subject.
// It will return a set of results for both normal and queue subscribers.
func (s *GenericSublist[T]) Match(subject string, cb func(T)) {
s.match(subject, cb, true)
}
// MatchBytes will match all entries to the literal subject.
// It will return a set of results for both normal and queue subscribers.
func (s *GenericSublist[T]) MatchBytes(subject []byte, cb func(T)) {
s.match(string(subject), cb, true)
}
// HasInterest will return whether or not there is any interest in the subject.
// In cases where more detail is not required, this may be faster than Match.
func (s *GenericSublist[T]) HasInterest(subject string) bool {
return s.hasInterest(subject, true, nil)
}
// NumInterest will return the number of subs interested in the subject.
// In cases where more detail is not required, this may be faster than Match.
func (s *GenericSublist[T]) NumInterest(subject string) (np int) {
s.hasInterest(subject, true, &np)
return
}
func (s *GenericSublist[T]) match(subject string, cb func(T), doLock bool) {
tsa := [32]string{}
tokens := tsa[:0]
start := 0
for i := 0; i < len(subject); i++ {
if subject[i] == btsep {
if i-start == 0 {
return
}
tokens = append(tokens, subject[start:i])
start = i + 1
}
}
if start >= len(subject) {
return
}
tokens = append(tokens, subject[start:])
if doLock {
s.RLock()
defer s.RUnlock()
}
matchLevel(s.root, tokens, cb)
}
func (s *GenericSublist[T]) hasInterest(subject string, doLock bool, np *int) bool {
tsa := [32]string{}
tokens := tsa[:0]
start := 0
for i := 0; i < len(subject); i++ {
if subject[i] == btsep {
if i-start == 0 {
return false
}
tokens = append(tokens, subject[start:i])
start = i + 1
}
}
if start >= len(subject) {
return false
}
tokens = append(tokens, subject[start:])
if doLock {
s.RLock()
defer s.RUnlock()
}
return matchLevelForAny(s.root, tokens, np)
}
func matchLevelForAny[T comparable](l *level[T], toks []string, np *int) bool {
var pwc, n *node[T]
for i, t := range toks {
if l == nil {
return false
}
if l.fwc != nil {
if np != nil {
*np += len(l.fwc.subs)
}
return true
}
if pwc = l.pwc; pwc != nil {
if match := matchLevelForAny(pwc.next, toks[i+1:], np); match {
return true
}
}
n = l.nodes[t]
if n != nil {
l = n.next
} else {
l = nil
}
}
if n != nil {
if np != nil {
*np += len(n.subs)
}
return len(n.subs) > 0
}
if pwc != nil {
if np != nil {
*np += len(pwc.subs)
}
return len(pwc.subs) > 0
}
return false
}
// callbacksForResults will make the necessary callbacks for each
// result in this node.
func callbacksForResults[T comparable](n *node[T], cb func(T)) {
for sub := range n.subs {
cb(sub)
}
}
// matchLevel is used to recursively descend into the trie.
func matchLevel[T comparable](l *level[T], toks []string, cb func(T)) {
var pwc, n *node[T]
for i, t := range toks {
if l == nil {
return
}
if l.fwc != nil {
callbacksForResults(l.fwc, cb)
}
if pwc = l.pwc; pwc != nil {
matchLevel(pwc.next, toks[i+1:], cb)
}
n = l.nodes[t]
if n != nil {
l = n.next
} else {
l = nil
}
}
if n != nil {
callbacksForResults(n, cb)
}
if pwc != nil {
callbacksForResults(pwc, cb)
}
}
// lnt is used to track descent into levels for a removal for pruning.
type lnt[T comparable] struct {
l *level[T]
n *node[T]
t string
}
// Raw low level remove, can do batches with lock held outside.
func (s *GenericSublist[T]) remove(subject string, value T, shouldLock bool) error {
tsa := [32]string{}
tokens := tsa[:0]
start := 0
for i := 0; i < len(subject); i++ {
if subject[i] == btsep {
tokens = append(tokens, subject[start:i])
start = i + 1
}
}
tokens = append(tokens, subject[start:])
if shouldLock {
s.Lock()
defer s.Unlock()
}
var sfwc bool
var n *node[T]
l := s.root
// Track levels for pruning
var lnts [32]lnt[T]
levels := lnts[:0]
for _, t := range tokens {
lt := len(t)
if lt == 0 || sfwc {
return ErrInvalidSubject
}
if l == nil {
return ErrNotFound
}
if lt > 1 {
n = l.nodes[t]
} else {
switch t[0] {
case pwc:
n = l.pwc
case fwc:
n = l.fwc
sfwc = true
default:
n = l.nodes[t]
}
}
if n != nil {
levels = append(levels, lnt[T]{l, n, t})
l = n.next
} else {
l = nil
}
}
if !s.removeFromNode(n, value) {
return ErrNotFound
}
s.count--
for i := len(levels) - 1; i >= 0; i-- {
l, n, t := levels[i].l, levels[i].n, levels[i].t
if n.isEmpty() {
l.pruneNode(n, t)
}
}
return nil
}
// Remove will remove a subscription.
func (s *GenericSublist[T]) Remove(subject string, value T) error {
return s.remove(subject, value, true)
}
// pruneNode is used to prune an empty node from the tree.
func (l *level[T]) pruneNode(n *node[T], t string) {
if n == nil {
return
}
if n == l.fwc {
l.fwc = nil
} else if n == l.pwc {
l.pwc = nil
} else {
delete(l.nodes, t)
}
}
// isEmpty will test if the node has any entries. Used
// in pruning.
func (n *node[T]) isEmpty() bool {
return len(n.subs) == 0 && (n.next == nil || n.next.numNodes() == 0)
}
// Return the number of nodes for the given level.
func (l *level[T]) numNodes() int {
num := len(l.nodes)
if l.pwc != nil {
num++
}
if l.fwc != nil {
num++
}
return num
}
// Remove the sub for the given node.
func (s *GenericSublist[T]) removeFromNode(n *node[T], value T) (found bool) {
if n == nil {
return false
}
if _, found = n.subs[value]; found {
delete(n.subs, value)
}
return found
}
// Count returns the number of subscriptions.
func (s *GenericSublist[T]) Count() uint32 {
s.RLock()
defer s.RUnlock()
return s.count
}
// numLevels will return the maximum number of levels
// contained in the Sublist tree.
func (s *GenericSublist[T]) numLevels() int {
return visitLevel(s.root, 0)
}
// visitLevel is used to descend the Sublist tree structure
// recursively.
func visitLevel[T comparable](l *level[T], depth int) int {
if l == nil || l.numNodes() == 0 {
return depth
}
depth++
maxDepth := depth
for _, n := range l.nodes {
if n == nil {
continue
}
newDepth := visitLevel(n.next, depth)
if newDepth > maxDepth {
maxDepth = newDepth
}
}
if l.pwc != nil {
pwcDepth := visitLevel(l.pwc.next, depth)
if pwcDepth > maxDepth {
maxDepth = pwcDepth
}
}
if l.fwc != nil {
fwcDepth := visitLevel(l.fwc.next, depth)
if fwcDepth > maxDepth {
maxDepth = fwcDepth
}
}
return maxDepth
}
// IntersectStree will match all items in the given subject tree that
// have interest expressed in the given sublist. The callback will only be called
// once for each subject, regardless of overlapping subscriptions in the sublist.
func IntersectStree[T1 any, T2 comparable](st *stree.SubjectTree[T1], sl *GenericSublist[T2], cb func(subj []byte, entry *T1)) {
var _subj [255]byte
intersectStree(st, sl.root, _subj[:0], cb)
}
func intersectStree[T1 any, T2 comparable](st *stree.SubjectTree[T1], r *level[T2], subj []byte, cb func(subj []byte, entry *T1)) {
if r.numNodes() == 0 {
// For wildcards we can't avoid Match, but if it's a literal subject at
// this point, using Find is considerably cheaper.
if subjectHasWildcard(string(subj)) {
st.Match(subj, cb)
} else if e, ok := st.Find(subj); ok {
cb(subj, e)
}
return
}
nsubj := subj
if len(nsubj) > 0 {
nsubj = append(subj, '.')
}
switch {
case r.fwc != nil:
// We've reached a full wildcard, do a FWC match on the stree at this point
// and don't keep iterating downward.
nsubj := append(nsubj, '>')
st.Match(nsubj, cb)
case r.pwc != nil:
// We've found a partial wildcard. We'll keep iterating downwards, but first
// check whether there's interest at this level (without triggering dupes) and
// match if so.
nsubj := append(nsubj, '*')
if len(r.pwc.subs) > 0 && r.pwc.next != nil && r.pwc.next.numNodes() > 0 {
st.Match(nsubj, cb)
}
intersectStree(st, r.pwc.next, nsubj, cb)
case r.numNodes() > 0:
// Normal node with subject literals, keep iterating.
for t, n := range r.nodes {
nsubj := append(nsubj, t...)
intersectStree(st, n.next, nsubj, cb)
}
}
}
// Determine if a subject has any wildcard tokens.
func subjectHasWildcard(subject string) bool {
// This one exits earlier then !subjectIsLiteral(subject)
for i, c := range subject {
if c == pwc || c == fwc {
if (i == 0 || subject[i-1] == btsep) &&
(i+1 == len(subject) || subject[i+1] == btsep) {
return true
}
}
}
return false
}

View File

@ -1,4 +1,4 @@
// Copyright 2021-2023 The NATS Authors
// Copyright 2021-2025 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@ -14,6 +14,7 @@
package server
import (
"errors"
"sync"
"sync/atomic"
)
@ -28,36 +29,79 @@ type ipQueue[T any] struct {
elts []T
pos int
pool *sync.Pool
mrs int
sz uint64 // Calculated size (only if calc != nil)
name string
m *sync.Map
ipQueueOpts[T]
}
type ipQueueOpts struct {
maxRecycleSize int
type ipQueueOpts[T any] struct {
mrs int // Max recycle size
calc func(e T) uint64 // Calc function for tracking size
msz uint64 // Limit by total calculated size
mlen int // Limit by number of entries
}
type ipQueueOpt func(*ipQueueOpts)
type ipQueueOpt[T any] func(*ipQueueOpts[T])
// This option allows to set the maximum recycle size when attempting
// to put back a slice to the pool.
func ipQueue_MaxRecycleSize(max int) ipQueueOpt {
return func(o *ipQueueOpts) {
o.maxRecycleSize = max
func ipqMaxRecycleSize[T any](max int) ipQueueOpt[T] {
return func(o *ipQueueOpts[T]) {
o.mrs = max
}
}
func newIPQueue[T any](s *Server, name string, opts ...ipQueueOpt) *ipQueue[T] {
qo := ipQueueOpts{maxRecycleSize: ipQueueDefaultMaxRecycleSize}
for _, o := range opts {
o(&qo)
// This option enables total queue size counting by passing in a function
// that evaluates the size of each entry as it is pushed/popped. This option
// enables the size() function.
func ipqSizeCalculation[T any](calc func(e T) uint64) ipQueueOpt[T] {
return func(o *ipQueueOpts[T]) {
o.calc = calc
}
}
// This option allows setting the maximum queue size. Once the limit is
// reached, then push() will stop returning true and no more entries will
// be stored until some more are popped. The ipQueue_SizeCalculation must
// be provided for this to work.
func ipqLimitBySize[T any](max uint64) ipQueueOpt[T] {
return func(o *ipQueueOpts[T]) {
o.msz = max
}
}
// This option allows setting the maximum queue length. Once the limit is
// reached, then push() will stop returning true and no more entries will
// be stored until some more are popped.
func ipqLimitByLen[T any](max int) ipQueueOpt[T] {
return func(o *ipQueueOpts[T]) {
o.mlen = max
}
}
var errIPQLenLimitReached = errors.New("IPQ len limit reached")
var errIPQSizeLimitReached = errors.New("IPQ size limit reached")
func newIPQueue[T any](s *Server, name string, opts ...ipQueueOpt[T]) *ipQueue[T] {
q := &ipQueue[T]{
ch: make(chan struct{}, 1),
mrs: qo.maxRecycleSize,
pool: &sync.Pool{},
ch: make(chan struct{}, 1),
pool: &sync.Pool{
New: func() any {
// Reason we use pointer to slice instead of slice is explained
// here: https://staticcheck.io/docs/checks#SA6002
res := make([]T, 0, 32)
return &res
},
},
name: name,
m: &s.ipQueues,
ipQueueOpts: ipQueueOpts[T]{
mrs: ipQueueDefaultMaxRecycleSize,
},
}
for _, o := range opts {
o(&q.ipQueueOpts)
}
s.ipQueues.Store(name, q)
return q
@ -66,32 +110,34 @@ func newIPQueue[T any](s *Server, name string, opts ...ipQueueOpt) *ipQueue[T] {
// Add the element `e` to the queue, notifying the queue channel's `ch` if the
// entry is the first to be added, and returns the length of the queue after
// this element is added.
func (q *ipQueue[T]) push(e T) int {
var signal bool
func (q *ipQueue[T]) push(e T) (int, error) {
q.Lock()
l := len(q.elts) - q.pos
if l == 0 {
signal = true
eltsi := q.pool.Get()
if eltsi != nil {
// Reason we use pointer to slice instead of slice is explained
// here: https://staticcheck.io/docs/checks#SA6002
q.elts = (*(eltsi.(*[]T)))[:0]
}
if cap(q.elts) == 0 {
q.elts = make([]T, 0, 32)
if q.mlen > 0 && l == q.mlen {
q.Unlock()
return l, errIPQLenLimitReached
}
if q.calc != nil {
sz := q.calc(e)
if q.msz > 0 && q.sz+sz > q.msz {
q.Unlock()
return l, errIPQSizeLimitReached
}
q.sz += sz
}
if q.elts == nil {
// What comes out of the pool is already of size 0, so no need for [:0].
q.elts = *(q.pool.Get().(*[]T))
}
q.elts = append(q.elts, e)
l++
q.Unlock()
if signal {
if l == 0 {
select {
case q.ch <- struct{}{}:
default:
}
}
return l
return l + 1, nil
}
// Returns the whole list of elements currently present in the queue,
@ -107,24 +153,23 @@ func (q *ipQueue[T]) pop() []T {
if q == nil {
return nil
}
var elts []T
q.Lock()
if len(q.elts)-q.pos == 0 {
q.Unlock()
return nil
}
var elts []T
if q.pos == 0 {
elts = q.elts
} else {
elts = q.elts[q.pos:]
}
q.elts, q.pos = nil, 0
q.elts, q.pos, q.sz = nil, 0, 0
atomic.AddInt64(&q.inprogress, int64(len(elts)))
q.Unlock()
return elts
}
func (q *ipQueue[T]) resetAndReturnToPool(elts *[]T) {
(*elts) = (*elts)[:0]
q.pool.Put(elts)
}
// Returns the first element from the queue, if any. See comment above
// regarding calling after being notified that there is something and
// the use of drain(). In short, the caller should always check the
@ -133,24 +178,30 @@ func (q *ipQueue[T]) resetAndReturnToPool(elts *[]T) {
func (q *ipQueue[T]) popOne() (T, bool) {
q.Lock()
l := len(q.elts) - q.pos
if l < 1 {
if l == 0 {
q.Unlock()
var empty T
return empty, false
}
e := q.elts[q.pos]
q.pos++
l--
if l > 0 {
if l--; l > 0 {
q.pos++
if q.calc != nil {
q.sz -= q.calc(e)
}
// We need to re-signal
select {
case q.ch <- struct{}{}:
default:
}
} else {
// We have just emptied the queue, so we can recycle now.
q.resetAndReturnToPool(&q.elts)
q.elts, q.pos = nil, 0
// We have just emptied the queue, so we can reuse unless it is too big.
if cap(q.elts) <= q.mrs {
q.elts = q.elts[:0]
} else {
q.elts = nil
}
q.pos, q.sz = 0, 0
}
q.Unlock()
return e, true
@ -160,8 +211,7 @@ func (q *ipQueue[T]) popOne() (T, bool) {
// a first element is added to the queue.
// This will also decrement the "in progress" count with the length
// of the slice.
// Reason we use pointer to slice instead of slice is explained
// here: https://staticcheck.io/docs/checks#SA6002
// WARNING: The caller MUST never reuse `elts`.
func (q *ipQueue[T]) recycle(elts *[]T) {
// If invoked with a nil list, nothing to do.
if elts == nil || *elts == nil {
@ -169,39 +219,44 @@ func (q *ipQueue[T]) recycle(elts *[]T) {
}
// Update the in progress count.
if len(*elts) > 0 {
if atomic.AddInt64(&q.inprogress, int64(-(len(*elts)))) < 0 {
atomic.StoreInt64(&q.inprogress, 0)
}
atomic.AddInt64(&q.inprogress, int64(-(len(*elts))))
}
// We also don't want to recycle huge slices, so check against the max.
// q.mrs is normally immutable but can be changed, in a safe way, in some tests.
if cap(*elts) > q.mrs {
return
}
q.resetAndReturnToPool(elts)
(*elts) = (*elts)[:0]
q.pool.Put(elts)
}
// Returns the current length of the queue.
func (q *ipQueue[T]) len() int {
q.Lock()
l := len(q.elts) - q.pos
q.Unlock()
return l
defer q.Unlock()
return len(q.elts) - q.pos
}
// Returns the calculated size of the queue (if ipQueue_SizeCalculation has been
// passed in), otherwise returns zero.
func (q *ipQueue[T]) size() uint64 {
q.Lock()
defer q.Unlock()
return q.sz
}
// Empty the queue and consumes the notification signal if present.
// Returns the number of items that were drained from the queue.
// Note that this could cause a reader go routine that has been
// notified that there is something in the queue (reading from queue's `ch`)
// may then get nothing if `drain()` is invoked before the `pop()` or `popOne()`.
func (q *ipQueue[T]) drain() {
func (q *ipQueue[T]) drain() int {
if q == nil {
return
return 0
}
q.Lock()
if q.elts != nil {
q.resetAndReturnToPool(&q.elts)
q.elts, q.pos = nil, 0
}
olen := len(q.elts) - q.pos
q.elts, q.pos, q.sz = nil, 0, 0
// Consume the signal if it was present to reduce the chance of a reader
// routine to be think that there is something in the queue...
select {
@ -209,6 +264,7 @@ func (q *ipQueue[T]) drain() {
default:
}
q.Unlock()
return olen
}
// Since the length of the queue goes to 0 after a pop(), it is good to

View File

@ -1,4 +1,4 @@
// Copyright 2019-2024 The NATS Authors
// Copyright 2019-2025 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@ -32,6 +32,7 @@ import (
"github.com/minio/highwayhash"
"github.com/nats-io/nats-server/v2/server/sysmem"
"github.com/nats-io/nats-server/v2/server/tpm"
"github.com/nats-io/nkeys"
"github.com/nats-io/nuid"
)
@ -47,6 +48,7 @@ type JetStreamConfig struct {
Domain string `json:"domain,omitempty"`
CompressOK bool `json:"compress_ok,omitempty"`
UniqueTag string `json:"unique_tag,omitempty"`
Strict bool `json:"strict,omitempty"`
}
// Statistics about JetStream for this server.
@ -90,6 +92,7 @@ type JetStreamAccountStats struct {
}
type JetStreamAPIStats struct {
Level int `json:"level"`
Total uint64 `json:"total"`
Errors uint64 `json:"errors"`
Inflight uint64 `json:"inflight,omitempty"`
@ -173,6 +176,9 @@ type jsAccount struct {
updatesSub *subscription
lupdate time.Time
utimer *time.Timer
// Which account to send NRG traffic into. Empty string is system account.
nrgAccount string
}
// Track general usage for this account.
@ -370,6 +376,40 @@ func (s *Server) checkStoreDir(cfg *JetStreamConfig) error {
return nil
}
// This function sets/updates the jetstream encryption key and cipher based
// on options. If the TPM options have been specified, a key is generated
// and sealed by the TPM.
func (s *Server) initJetStreamEncryption() (err error) {
opts := s.getOpts()
// The TPM settings and other encryption settings are mutually exclusive.
if opts.JetStreamKey != _EMPTY_ && opts.JetStreamTpm.KeysFile != _EMPTY_ {
return fmt.Errorf("JetStream encryption key may not be used with TPM options")
}
// if we are using the standard method to set the encryption key just return and carry on.
if opts.JetStreamKey != _EMPTY_ {
return nil
}
// if the tpm options are not used then no encryption has been configured and return.
if opts.JetStreamTpm.KeysFile == _EMPTY_ {
return nil
}
if opts.JetStreamTpm.Pcr == 0 {
// Default PCR to use in the TPM. Values can be 0-23, and most platforms
// reserve values 0-12 for the OS, boot locker, disc encryption, etc.
// 16 used for debugging. In sticking to NATS tradition, we'll use 22
// as the default with the option being configurable.
opts.JetStreamTpm.Pcr = 22
}
// Using the TPM to generate or get the encryption key and update the encryption options.
opts.JetStreamKey, err = tpm.LoadJetStreamEncryptionKeyFromTPM(opts.JetStreamTpm.SrkPassword,
opts.JetStreamTpm.KeysFile, opts.JetStreamTpm.KeyPassword, opts.JetStreamTpm.Pcr)
return err
}
// enableJetStream will start up the JetStream subsystem.
func (s *Server) enableJetStream(cfg JetStreamConfig) error {
js := &jetStream{srv: s, config: cfg, accounts: make(map[string]*jsAccount), apiSubs: NewSublistNoCache()}
@ -402,6 +442,10 @@ func (s *Server) enableJetStream(cfg JetStreamConfig) error {
os.Remove(tmpfile.Name())
}
if err := s.initJetStreamEncryption(); err != nil {
return err
}
// JetStream is an internal service so we need to make sure we have a system account.
// This system account will export the JetStream service endpoints.
if s.SystemAccount() == nil {
@ -419,6 +463,11 @@ func (s *Server) enableJetStream(cfg JetStreamConfig) error {
s.Noticef("")
}
s.Noticef("---------------- JETSTREAM ----------------")
if cfg.Strict {
s.Noticef(" Strict: %t", cfg.Strict)
}
s.Noticef(" Max Memory: %s", friendlyBytes(cfg.MaxMemory))
s.Noticef(" Max Storage: %s", friendlyBytes(cfg.MaxStore))
s.Noticef(" Store Directory: \"%s\"", cfg.StoreDir)
@ -429,6 +478,11 @@ func (s *Server) enableJetStream(cfg JetStreamConfig) error {
if ek := opts.JetStreamKey; ek != _EMPTY_ {
s.Noticef(" Encryption: %s", opts.JetStreamCipher)
}
if opts.JetStreamTpm.KeysFile != _EMPTY_ {
s.Noticef(" TPM File: %q, Pcr: %d", opts.JetStreamTpm.KeysFile,
opts.JetStreamTpm.Pcr)
}
s.Noticef(" API Level: %d", JSApiLevel)
s.Noticef("-------------------------------------------")
// Setup our internal subscriptions.
@ -461,6 +515,8 @@ func (s *Server) enableJetStream(cfg JetStreamConfig) error {
if err := s.enableJetStreamClustering(); err != nil {
return err
}
// Set our atomic bool to clustered.
s.jsClustered.Store(true)
}
// Mark when we are up and running.
@ -506,6 +562,7 @@ func (s *Server) restartJetStream() error {
MaxMemory: opts.JetStreamMaxMemory,
MaxStore: opts.JetStreamMaxStore,
Domain: opts.JetStreamDomain,
Strict: opts.JetStreamStrict,
}
s.Noticef("Restarting JetStream")
err := s.EnableJetStream(&cfg)
@ -965,6 +1022,8 @@ func (s *Server) shutdownJetStream() {
cc.c = nil
}
cc.meta = nil
// Set our atomic bool to false.
s.jsClustered.Store(false)
}
js.mu.Unlock()
@ -1404,7 +1463,7 @@ func (a *Account) EnableJetStream(limits map[string]JetStreamAccountLimits) erro
// the consumer can reconnect. We will create it as a durable and switch it.
cfg.ConsumerConfig.Durable = ofi.Name()
}
obs, err := e.mset.addConsumerWithAssignment(&cfg.ConsumerConfig, _EMPTY_, nil, true, ActionCreateOrUpdate)
obs, err := e.mset.addConsumerWithAssignment(&cfg.ConsumerConfig, _EMPTY_, nil, true, ActionCreateOrUpdate, false)
if err != nil {
s.Warnf(" Error adding consumer %q: %v", cfg.Name, err)
continue
@ -1497,12 +1556,14 @@ func (a *Account) filteredStreams(filter string) []*stream {
var msets []*stream
for _, mset := range jsa.streams {
if filter != _EMPTY_ {
mset.cfgMu.RLock()
for _, subj := range mset.cfg.Subjects {
if SubjectsCollide(filter, subj) {
msets = append(msets, mset)
break
}
}
mset.cfgMu.RUnlock()
} else {
msets = append(msets, mset)
}
@ -1646,6 +1707,7 @@ func (a *Account) JetStreamUsage() JetStreamAccountStats {
stats.Memory, stats.Store = jsa.storageTotals()
stats.Domain = js.config.Domain
stats.API = JetStreamAPIStats{
Level: JSApiLevel,
Total: jsa.apiTotal,
Errors: jsa.apiErrors,
}
@ -2103,7 +2165,7 @@ func (js *jetStream) wouldExceedLimits(storeType StorageType, sz int) bool {
} else {
total, max = &js.storeUsed, js.config.MaxStore
}
return atomic.LoadInt64(total) > (max + int64(sz))
return (atomic.LoadInt64(total) + int64(sz)) > max
}
func (js *jetStream) limitsExceeded(storeType StorageType) bool {
@ -2143,14 +2205,11 @@ func (jsa *jsAccount) selectLimits(replicas int) (JetStreamAccountLimits, string
}
// Lock should be held.
func (jsa *jsAccount) countStreams(tier string, cfg *StreamConfig) int {
streams := len(jsa.streams)
if tier != _EMPTY_ {
streams = 0
for _, sa := range jsa.streams {
if isSameTier(&sa.cfg, cfg) {
streams++
}
func (jsa *jsAccount) countStreams(tier string, cfg *StreamConfig) (streams int) {
for _, sa := range jsa.streams {
// Don't count the stream toward the limit if it already exists.
if (tier == _EMPTY_ || isSameTier(&sa.cfg, cfg)) && sa.cfg.Name != cfg.Name {
streams++
}
}
return streams
@ -2256,7 +2315,7 @@ func (js *jetStream) checkBytesLimits(selectedLimits *JetStreamAccountLimits, ad
return NewJSMemoryResourcesExceededError()
}
// Check if this server can handle request.
if checkServer && js.memReserved+addBytes > js.config.MaxMemory {
if checkServer && js.memReserved+totalBytes > js.config.MaxMemory {
return NewJSMemoryResourcesExceededError()
}
case FileStorage:
@ -2265,7 +2324,7 @@ func (js *jetStream) checkBytesLimits(selectedLimits *JetStreamAccountLimits, ad
return NewJSStorageResourcesExceededError()
}
// Check if this server can handle request.
if checkServer && js.storeReserved+addBytes > js.config.MaxStore {
if checkServer && js.storeReserved+totalBytes > js.config.MaxStore {
return NewJSStorageResourcesExceededError()
}
}
@ -2335,6 +2394,7 @@ func (js *jetStream) usageStats() *JetStreamStats {
stats.ReservedStore = uint64(js.storeReserved)
s := js.srv
js.mu.RUnlock()
stats.API.Level = JSApiLevel
stats.API.Total = uint64(atomic.LoadInt64(&js.apiTotal))
stats.API.Errors = uint64(atomic.LoadInt64(&js.apiErrors))
stats.API.Inflight = uint64(atomic.LoadInt64(&js.apiInflight))
@ -2477,6 +2537,9 @@ func (s *Server) dynJetStreamConfig(storeDir string, maxStore, maxMem int64) *Je
opts := s.getOpts()
// Strict mode.
jsc.Strict = opts.JetStreamStrict
// Sync options.
jsc.SyncInterval = opts.SyncInterval
jsc.SyncAlways = opts.SyncAlways
@ -2566,7 +2629,7 @@ func (a *Account) addStreamTemplate(tc *StreamTemplateConfig) (*streamTemplate,
// FIXME(dlc) - Hacky
tcopy := tc.deepCopy()
tcopy.Config.Name = "_"
cfg, apiErr := s.checkStreamCfg(tcopy.Config, a)
cfg, apiErr := s.checkStreamCfg(tcopy.Config, a, false)
if apiErr != nil {
return nil, apiErr
}
@ -2868,11 +2931,11 @@ func (s *Server) resourcesExceededError() {
}
s.rerrMu.Unlock()
// If we are meta leader we should relinguish that here.
// If we are meta leader we should relinquish that here.
if didAlert {
if js := s.getJetStream(); js != nil {
js.mu.RLock()
if cc := js.cluster; cc != nil && cc.isLeader() {
if cc := js.cluster; cc != nil && cc.meta != nil {
cc.meta.StepDown()
}
js.mu.RUnlock()
@ -2970,3 +3033,14 @@ func fixCfgMirrorWithDedupWindow(cfg *StreamConfig) {
cfg.Duplicates = 0
}
}
func (s *Server) handleWritePermissionError() {
//TODO Check if we should add s.jetStreamOOSPending in condition
if s.JetStreamEnabled() {
s.Errorf("File system permission denied while writing, disabling JetStream")
go s.DisableJetStream()
//TODO Send respective advisory if needed, same as in handleOutOfSpace
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -95,6 +95,9 @@ const (
// JSConsumerEmptyFilter consumer filter in FilterSubjects cannot be empty
JSConsumerEmptyFilter ErrorIdentifier = 10139
// JSConsumerEmptyGroupName Group name cannot be an empty string
JSConsumerEmptyGroupName ErrorIdentifier = 10161
// JSConsumerEphemeralWithDurableInSubjectErr consumer expected to be ephemeral but detected a durable name set in subject
JSConsumerEphemeralWithDurableInSubjectErr ErrorIdentifier = 10019
@ -119,9 +122,15 @@ const (
// JSConsumerInvalidDeliverSubject invalid push consumer deliver subject
JSConsumerInvalidDeliverSubject ErrorIdentifier = 10112
// JSConsumerInvalidGroupNameErr Valid priority group name must match A-Z, a-z, 0-9, -_/=)+ and may not exceed 16 characters
JSConsumerInvalidGroupNameErr ErrorIdentifier = 10162
// JSConsumerInvalidPolicyErrF Generic delivery policy error ({err})
JSConsumerInvalidPolicyErrF ErrorIdentifier = 10094
// JSConsumerInvalidPriorityGroupErr Provided priority group does not exist for this consumer
JSConsumerInvalidPriorityGroupErr ErrorIdentifier = 10160
// JSConsumerInvalidSamplingErrF failed to parse consumer sampling configuration: {err}
JSConsumerInvalidSamplingErrF ErrorIdentifier = 10095
@ -173,10 +182,13 @@ const (
// JSConsumerOverlappingSubjectFilters consumer subject filters cannot overlap
JSConsumerOverlappingSubjectFilters ErrorIdentifier = 10138
// JSConsumerPriorityPolicyWithoutGroup Setting PriorityPolicy requires at least one PriorityGroup to be set
JSConsumerPriorityPolicyWithoutGroup ErrorIdentifier = 10159
// JSConsumerPullNotDurableErr consumer in pull mode requires a durable name
JSConsumerPullNotDurableErr ErrorIdentifier = 10085
// JSConsumerPullRequiresAckErr consumer in pull mode requires ack policy
// JSConsumerPullRequiresAckErr consumer in pull mode requires explicit ack policy on workqueue stream
JSConsumerPullRequiresAckErr ErrorIdentifier = 10084
// JSConsumerPullWithRateLimitErr consumer in pull mode can not have rate limit set
@ -218,7 +230,7 @@ const (
// JSInsufficientResourcesErr insufficient resources
JSInsufficientResourcesErr ErrorIdentifier = 10023
// JSInvalidJSONErr invalid JSON
// JSInvalidJSONErr invalid JSON: {err}
JSInvalidJSONErr ErrorIdentifier = 10025
// JSMaximumConsumersLimitErr maximum consumers limit reached
@ -230,15 +242,24 @@ const (
// JSMemoryResourcesExceededErr insufficient memory resources available
JSMemoryResourcesExceededErr ErrorIdentifier = 10028
// JSMessageTTLDisabledErr per-message TTL is disabled
JSMessageTTLDisabledErr ErrorIdentifier = 10166
// JSMessageTTLInvalidErr invalid per-message TTL
JSMessageTTLInvalidErr ErrorIdentifier = 10165
// JSMirrorConsumerSetupFailedErrF generic mirror consumer setup failure string ({err})
JSMirrorConsumerSetupFailedErrF ErrorIdentifier = 10029
// JSMirrorInvalidStreamName mirrored stream name is invalid
JSMirrorInvalidStreamName ErrorIdentifier = 10142
// JSMirrorInvalidSubjectFilter mirror subject filter is invalid
// JSMirrorInvalidSubjectFilter mirror transform source: {err}
JSMirrorInvalidSubjectFilter ErrorIdentifier = 10151
// JSMirrorInvalidTransformDestination mirror transform: {err}
JSMirrorInvalidTransformDestination ErrorIdentifier = 10154
// JSMirrorMaxMessageSizeTooBigErr stream mirror must have max message size >= source
JSMirrorMaxMessageSizeTooBigErr ErrorIdentifier = 10030
@ -281,6 +302,9 @@ const (
// JSNotEnabledForAccountErr JetStream not enabled for account
JSNotEnabledForAccountErr ErrorIdentifier = 10039
// JSPedanticErrF pedantic mode: {err}
JSPedanticErrF ErrorIdentifier = 10157
// JSPeerRemapErr peer remap failed
JSPeerRemapErr ErrorIdentifier = 10075
@ -308,10 +332,10 @@ const (
// JSSourceInvalidStreamName sourced stream name is invalid
JSSourceInvalidStreamName ErrorIdentifier = 10141
// JSSourceInvalidSubjectFilter source subject filter is invalid
// JSSourceInvalidSubjectFilter source transform source: {err}
JSSourceInvalidSubjectFilter ErrorIdentifier = 10145
// JSSourceInvalidTransformDestination source transform destination is invalid
// JSSourceInvalidTransformDestination source transform: {err}
JSSourceInvalidTransformDestination ErrorIdentifier = 10146
// JSSourceMaxMessageSizeTooBigErr stream source must have max message size >= target
@ -335,6 +359,12 @@ const (
// JSStreamDeleteErrF General stream deletion error string ({err})
JSStreamDeleteErrF ErrorIdentifier = 10050
// JSStreamDuplicateMessageConflict duplicate message id is in process
JSStreamDuplicateMessageConflict ErrorIdentifier = 10158
// JSStreamExpectedLastSeqPerSubjectNotReady expected last sequence per subject temporarily unavailable
JSStreamExpectedLastSeqPerSubjectNotReady ErrorIdentifier = 10163
// JSStreamExternalApiOverlapErrF stream external api prefix {prefix} must not overlap with {subject}
JSStreamExternalApiOverlapErrF ErrorIdentifier = 10021
@ -446,12 +476,24 @@ const (
// JSStreamTemplateNotFoundErr template not found
JSStreamTemplateNotFoundErr ErrorIdentifier = 10068
// JSStreamTooManyRequests too many requests
JSStreamTooManyRequests ErrorIdentifier = 10167
// JSStreamTransformInvalidDestination stream transform: {err}
JSStreamTransformInvalidDestination ErrorIdentifier = 10156
// JSStreamTransformInvalidSource stream transform source: {err}
JSStreamTransformInvalidSource ErrorIdentifier = 10155
// JSStreamUpdateErrF Generic stream update error string ({err})
JSStreamUpdateErrF ErrorIdentifier = 10069
// JSStreamWrongLastMsgIDErrF wrong last msg ID: {id}
JSStreamWrongLastMsgIDErrF ErrorIdentifier = 10070
// JSStreamWrongLastSequenceConstantErr wrong last sequence
JSStreamWrongLastSequenceConstantErr ErrorIdentifier = 10164
// JSStreamWrongLastSequenceErrF wrong last sequence: {seq}
JSStreamWrongLastSequenceErrF ErrorIdentifier = 10071
@ -494,6 +536,7 @@ var (
JSConsumerDurableNameNotMatchSubjectErr: {Code: 400, ErrCode: 10017, Description: "consumer name in subject does not match durable name in request"},
JSConsumerDurableNameNotSetErr: {Code: 400, ErrCode: 10018, Description: "consumer expected to be durable but a durable name was not set"},
JSConsumerEmptyFilter: {Code: 400, ErrCode: 10139, Description: "consumer filter in FilterSubjects cannot be empty"},
JSConsumerEmptyGroupName: {Code: 400, ErrCode: 10161, Description: "Group name cannot be an empty string"},
JSConsumerEphemeralWithDurableInSubjectErr: {Code: 400, ErrCode: 10019, Description: "consumer expected to be ephemeral but detected a durable name set in subject"},
JSConsumerEphemeralWithDurableNameErr: {Code: 400, ErrCode: 10020, Description: "consumer expected to be ephemeral but a durable name was set in request"},
JSConsumerExistingActiveErr: {Code: 400, ErrCode: 10105, Description: "consumer already exists and is still active"},
@ -502,7 +545,9 @@ var (
JSConsumerHBRequiresPushErr: {Code: 400, ErrCode: 10088, Description: "consumer idle heartbeat requires a push based consumer"},
JSConsumerInactiveThresholdExcess: {Code: 400, ErrCode: 10153, Description: "consumer inactive threshold exceeds system limit of {limit}"},
JSConsumerInvalidDeliverSubject: {Code: 400, ErrCode: 10112, Description: "invalid push consumer deliver subject"},
JSConsumerInvalidGroupNameErr: {Code: 400, ErrCode: 10162, Description: "Valid priority group name must match A-Z, a-z, 0-9, -_/=)+ and may not exceed 16 characters"},
JSConsumerInvalidPolicyErrF: {Code: 400, ErrCode: 10094, Description: "{err}"},
JSConsumerInvalidPriorityGroupErr: {Code: 400, ErrCode: 10160, Description: "Provided priority group does not exist for this consumer"},
JSConsumerInvalidSamplingErrF: {Code: 400, ErrCode: 10095, Description: "failed to parse consumer sampling configuration: {err}"},
JSConsumerMaxDeliverBackoffErr: {Code: 400, ErrCode: 10116, Description: "max deliver is required to be > length of backoff values"},
JSConsumerMaxPendingAckExcessErrF: {Code: 400, ErrCode: 10121, Description: "consumer max ack pending exceeds system limit of {limit}"},
@ -520,8 +565,9 @@ var (
JSConsumerOfflineErr: {Code: 500, ErrCode: 10119, Description: "consumer is offline"},
JSConsumerOnMappedErr: {Code: 400, ErrCode: 10092, Description: "consumer direct on a mapped consumer"},
JSConsumerOverlappingSubjectFilters: {Code: 400, ErrCode: 10138, Description: "consumer subject filters cannot overlap"},
JSConsumerPriorityPolicyWithoutGroup: {Code: 400, ErrCode: 10159, Description: "Setting PriorityPolicy requires at least one PriorityGroup to be set"},
JSConsumerPullNotDurableErr: {Code: 400, ErrCode: 10085, Description: "consumer in pull mode requires a durable name"},
JSConsumerPullRequiresAckErr: {Code: 400, ErrCode: 10084, Description: "consumer in pull mode requires ack policy"},
JSConsumerPullRequiresAckErr: {Code: 400, ErrCode: 10084, Description: "consumer in pull mode requires explicit ack policy on workqueue stream"},
JSConsumerPullWithRateLimitErr: {Code: 400, ErrCode: 10086, Description: "consumer in pull mode can not have rate limit set"},
JSConsumerPushMaxWaitingErr: {Code: 400, ErrCode: 10080, Description: "consumer in push mode can not set max waiting"},
JSConsumerReplacementWithDifferentNameErr: {Code: 400, ErrCode: 10106, Description: "consumer replacement durable config not the same"},
@ -535,13 +581,16 @@ var (
JSConsumerWQRequiresExplicitAckErr: {Code: 400, ErrCode: 10098, Description: "workqueue stream requires explicit ack"},
JSConsumerWithFlowControlNeedsHeartbeats: {Code: 400, ErrCode: 10108, Description: "consumer with flow control also needs heartbeats"},
JSInsufficientResourcesErr: {Code: 503, ErrCode: 10023, Description: "insufficient resources"},
JSInvalidJSONErr: {Code: 400, ErrCode: 10025, Description: "invalid JSON"},
JSInvalidJSONErr: {Code: 400, ErrCode: 10025, Description: "invalid JSON: {err}"},
JSMaximumConsumersLimitErr: {Code: 400, ErrCode: 10026, Description: "maximum consumers limit reached"},
JSMaximumStreamsLimitErr: {Code: 400, ErrCode: 10027, Description: "maximum number of streams reached"},
JSMemoryResourcesExceededErr: {Code: 500, ErrCode: 10028, Description: "insufficient memory resources available"},
JSMessageTTLDisabledErr: {Code: 400, ErrCode: 10166, Description: "per-message TTL is disabled"},
JSMessageTTLInvalidErr: {Code: 400, ErrCode: 10165, Description: "invalid per-message TTL"},
JSMirrorConsumerSetupFailedErrF: {Code: 500, ErrCode: 10029, Description: "{err}"},
JSMirrorInvalidStreamName: {Code: 400, ErrCode: 10142, Description: "mirrored stream name is invalid"},
JSMirrorInvalidSubjectFilter: {Code: 400, ErrCode: 10151, Description: "mirror subject filter is invalid"},
JSMirrorInvalidSubjectFilter: {Code: 400, ErrCode: 10151, Description: "mirror transform source: {err}"},
JSMirrorInvalidTransformDestination: {Code: 400, ErrCode: 10154, Description: "mirror transform: {err}"},
JSMirrorMaxMessageSizeTooBigErr: {Code: 400, ErrCode: 10030, Description: "stream mirror must have max message size >= source"},
JSMirrorMultipleFiltersNotAllowed: {Code: 400, ErrCode: 10150, Description: "mirror with multiple subject transforms cannot also have a single subject filter"},
JSMirrorOverlappingSubjectFilters: {Code: 400, ErrCode: 10152, Description: "mirror subject filters can not overlap"},
@ -556,6 +605,7 @@ var (
JSNotEmptyRequestErr: {Code: 400, ErrCode: 10038, Description: "expected an empty request payload"},
JSNotEnabledErr: {Code: 503, ErrCode: 10076, Description: "JetStream not enabled"},
JSNotEnabledForAccountErr: {Code: 503, ErrCode: 10039, Description: "JetStream not enabled for account"},
JSPedanticErrF: {Code: 400, ErrCode: 10157, Description: "pedantic mode: {err}"},
JSPeerRemapErr: {Code: 503, ErrCode: 10075, Description: "peer remap failed"},
JSRaftGeneralErrF: {Code: 500, ErrCode: 10041, Description: "{err}"},
JSReplicasCountCannotBeNegative: {Code: 400, ErrCode: 10133, Description: "replicas count cannot be negative"},
@ -565,8 +615,8 @@ var (
JSSourceConsumerSetupFailedErrF: {Code: 500, ErrCode: 10045, Description: "{err}"},
JSSourceDuplicateDetected: {Code: 400, ErrCode: 10140, Description: "duplicate source configuration detected"},
JSSourceInvalidStreamName: {Code: 400, ErrCode: 10141, Description: "sourced stream name is invalid"},
JSSourceInvalidSubjectFilter: {Code: 400, ErrCode: 10145, Description: "source subject filter is invalid"},
JSSourceInvalidTransformDestination: {Code: 400, ErrCode: 10146, Description: "source transform destination is invalid"},
JSSourceInvalidSubjectFilter: {Code: 400, ErrCode: 10145, Description: "source transform source: {err}"},
JSSourceInvalidTransformDestination: {Code: 400, ErrCode: 10146, Description: "source transform: {err}"},
JSSourceMaxMessageSizeTooBigErr: {Code: 400, ErrCode: 10046, Description: "stream source must have max message size >= target"},
JSSourceMultipleFiltersNotAllowed: {Code: 400, ErrCode: 10144, Description: "source with multiple subject transforms cannot also have a single subject filter"},
JSSourceOverlappingSubjectFilters: {Code: 400, ErrCode: 10147, Description: "source filters can not overlap"},
@ -574,6 +624,8 @@ var (
JSStreamAssignmentErrF: {Code: 500, ErrCode: 10048, Description: "{err}"},
JSStreamCreateErrF: {Code: 500, ErrCode: 10049, Description: "{err}"},
JSStreamDeleteErrF: {Code: 500, ErrCode: 10050, Description: "{err}"},
JSStreamDuplicateMessageConflict: {Code: 409, ErrCode: 10158, Description: "duplicate message id is in process"},
JSStreamExpectedLastSeqPerSubjectNotReady: {Code: 503, ErrCode: 10163, Description: "expected last sequence per subject temporarily unavailable"},
JSStreamExternalApiOverlapErrF: {Code: 400, ErrCode: 10021, Description: "stream external api prefix {prefix} must not overlap with {subject}"},
JSStreamExternalDelPrefixOverlapsErrF: {Code: 400, ErrCode: 10022, Description: "stream external delivery prefix {prefix} overlaps with stream subject {subject}"},
JSStreamGeneralErrorF: {Code: 500, ErrCode: 10051, Description: "{err}"},
@ -611,8 +663,12 @@ var (
JSStreamTemplateCreateErrF: {Code: 500, ErrCode: 10066, Description: "{err}"},
JSStreamTemplateDeleteErrF: {Code: 500, ErrCode: 10067, Description: "{err}"},
JSStreamTemplateNotFoundErr: {Code: 404, ErrCode: 10068, Description: "template not found"},
JSStreamTooManyRequests: {Code: 429, ErrCode: 10167, Description: "too many requests"},
JSStreamTransformInvalidDestination: {Code: 400, ErrCode: 10156, Description: "stream transform: {err}"},
JSStreamTransformInvalidSource: {Code: 400, ErrCode: 10155, Description: "stream transform source: {err}"},
JSStreamUpdateErrF: {Code: 500, ErrCode: 10069, Description: "{err}"},
JSStreamWrongLastMsgIDErrF: {Code: 400, ErrCode: 10070, Description: "wrong last msg ID: {id}"},
JSStreamWrongLastSequenceConstantErr: {Code: 400, ErrCode: 10164, Description: "wrong last sequence"},
JSStreamWrongLastSequenceErrF: {Code: 400, ErrCode: 10071, Description: "wrong last sequence: {seq}"},
JSTempStorageFailedErr: {Code: 500, ErrCode: 10072, Description: "JetStream unable to open temp storage for restore"},
JSTemplateNameNotMatchSubjectErr: {Code: 400, ErrCode: 10073, Description: "template name in subject does not match request"},
@ -959,6 +1015,16 @@ func NewJSConsumerEmptyFilterError(opts ...ErrorOption) *ApiError {
return ApiErrors[JSConsumerEmptyFilter]
}
// NewJSConsumerEmptyGroupNameError creates a new JSConsumerEmptyGroupName error: "Group name cannot be an empty string"
func NewJSConsumerEmptyGroupNameError(opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
if ae, ok := eopts.err.(*ApiError); ok {
return ae
}
return ApiErrors[JSConsumerEmptyGroupName]
}
// NewJSConsumerEphemeralWithDurableInSubjectError creates a new JSConsumerEphemeralWithDurableInSubjectErr error: "consumer expected to be ephemeral but detected a durable name set in subject"
func NewJSConsumerEphemeralWithDurableInSubjectError(opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
@ -1045,6 +1111,16 @@ func NewJSConsumerInvalidDeliverSubjectError(opts ...ErrorOption) *ApiError {
return ApiErrors[JSConsumerInvalidDeliverSubject]
}
// NewJSConsumerInvalidGroupNameError creates a new JSConsumerInvalidGroupNameErr error: "Valid priority group name must match A-Z, a-z, 0-9, -_/=)+ and may not exceed 16 characters"
func NewJSConsumerInvalidGroupNameError(opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
if ae, ok := eopts.err.(*ApiError); ok {
return ae
}
return ApiErrors[JSConsumerInvalidGroupNameErr]
}
// NewJSConsumerInvalidPolicyError creates a new JSConsumerInvalidPolicyErrF error: "{err}"
func NewJSConsumerInvalidPolicyError(err error, opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
@ -1061,6 +1137,16 @@ func NewJSConsumerInvalidPolicyError(err error, opts ...ErrorOption) *ApiError {
}
}
// NewJSConsumerInvalidPriorityGroupError creates a new JSConsumerInvalidPriorityGroupErr error: "Provided priority group does not exist for this consumer"
func NewJSConsumerInvalidPriorityGroupError(opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
if ae, ok := eopts.err.(*ApiError); ok {
return ae
}
return ApiErrors[JSConsumerInvalidPriorityGroupErr]
}
// NewJSConsumerInvalidSamplingError creates a new JSConsumerInvalidSamplingErrF error: "failed to parse consumer sampling configuration: {err}"
func NewJSConsumerInvalidSamplingError(err error, opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
@ -1261,6 +1347,16 @@ func NewJSConsumerOverlappingSubjectFiltersError(opts ...ErrorOption) *ApiError
return ApiErrors[JSConsumerOverlappingSubjectFilters]
}
// NewJSConsumerPriorityPolicyWithoutGroupError creates a new JSConsumerPriorityPolicyWithoutGroup error: "Setting PriorityPolicy requires at least one PriorityGroup to be set"
func NewJSConsumerPriorityPolicyWithoutGroupError(opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
if ae, ok := eopts.err.(*ApiError); ok {
return ae
}
return ApiErrors[JSConsumerPriorityPolicyWithoutGroup]
}
// NewJSConsumerPullNotDurableError creates a new JSConsumerPullNotDurableErr error: "consumer in pull mode requires a durable name"
func NewJSConsumerPullNotDurableError(opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
@ -1271,7 +1367,7 @@ func NewJSConsumerPullNotDurableError(opts ...ErrorOption) *ApiError {
return ApiErrors[JSConsumerPullNotDurableErr]
}
// NewJSConsumerPullRequiresAckError creates a new JSConsumerPullRequiresAckErr error: "consumer in pull mode requires ack policy"
// NewJSConsumerPullRequiresAckError creates a new JSConsumerPullRequiresAckErr error: "consumer in pull mode requires explicit ack policy on workqueue stream"
func NewJSConsumerPullRequiresAckError(opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
if ae, ok := eopts.err.(*ApiError); ok {
@ -1417,14 +1513,20 @@ func NewJSInsufficientResourcesError(opts ...ErrorOption) *ApiError {
return ApiErrors[JSInsufficientResourcesErr]
}
// NewJSInvalidJSONError creates a new JSInvalidJSONErr error: "invalid JSON"
func NewJSInvalidJSONError(opts ...ErrorOption) *ApiError {
// NewJSInvalidJSONError creates a new JSInvalidJSONErr error: "invalid JSON: {err}"
func NewJSInvalidJSONError(err error, opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
if ae, ok := eopts.err.(*ApiError); ok {
return ae
}
return ApiErrors[JSInvalidJSONErr]
e := ApiErrors[JSInvalidJSONErr]
args := e.toReplacerArgs([]interface{}{"{err}", err})
return &ApiError{
Code: e.Code,
ErrCode: e.ErrCode,
Description: strings.NewReplacer(args...).Replace(e.Description),
}
}
// NewJSMaximumConsumersLimitError creates a new JSMaximumConsumersLimitErr error: "maximum consumers limit reached"
@ -1457,6 +1559,26 @@ func NewJSMemoryResourcesExceededError(opts ...ErrorOption) *ApiError {
return ApiErrors[JSMemoryResourcesExceededErr]
}
// NewJSMessageTTLDisabledError creates a new JSMessageTTLDisabledErr error: "per-message TTL is disabled"
func NewJSMessageTTLDisabledError(opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
if ae, ok := eopts.err.(*ApiError); ok {
return ae
}
return ApiErrors[JSMessageTTLDisabledErr]
}
// NewJSMessageTTLInvalidError creates a new JSMessageTTLInvalidErr error: "invalid per-message TTL"
func NewJSMessageTTLInvalidError(opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
if ae, ok := eopts.err.(*ApiError); ok {
return ae
}
return ApiErrors[JSMessageTTLInvalidErr]
}
// NewJSMirrorConsumerSetupFailedError creates a new JSMirrorConsumerSetupFailedErrF error: "{err}"
func NewJSMirrorConsumerSetupFailedError(err error, opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
@ -1483,14 +1605,36 @@ func NewJSMirrorInvalidStreamNameError(opts ...ErrorOption) *ApiError {
return ApiErrors[JSMirrorInvalidStreamName]
}
// NewJSMirrorInvalidSubjectFilterError creates a new JSMirrorInvalidSubjectFilter error: "mirror subject filter is invalid"
func NewJSMirrorInvalidSubjectFilterError(opts ...ErrorOption) *ApiError {
// NewJSMirrorInvalidSubjectFilterError creates a new JSMirrorInvalidSubjectFilter error: "mirror transform source: {err}"
func NewJSMirrorInvalidSubjectFilterError(err error, opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
if ae, ok := eopts.err.(*ApiError); ok {
return ae
}
return ApiErrors[JSMirrorInvalidSubjectFilter]
e := ApiErrors[JSMirrorInvalidSubjectFilter]
args := e.toReplacerArgs([]interface{}{"{err}", err})
return &ApiError{
Code: e.Code,
ErrCode: e.ErrCode,
Description: strings.NewReplacer(args...).Replace(e.Description),
}
}
// NewJSMirrorInvalidTransformDestinationError creates a new JSMirrorInvalidTransformDestination error: "mirror transform: {err}"
func NewJSMirrorInvalidTransformDestinationError(err error, opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
if ae, ok := eopts.err.(*ApiError); ok {
return ae
}
e := ApiErrors[JSMirrorInvalidTransformDestination]
args := e.toReplacerArgs([]interface{}{"{err}", err})
return &ApiError{
Code: e.Code,
ErrCode: e.ErrCode,
Description: strings.NewReplacer(args...).Replace(e.Description),
}
}
// NewJSMirrorMaxMessageSizeTooBigError creates a new JSMirrorMaxMessageSizeTooBigErr error: "stream mirror must have max message size >= source"
@ -1633,6 +1777,22 @@ func NewJSNotEnabledForAccountError(opts ...ErrorOption) *ApiError {
return ApiErrors[JSNotEnabledForAccountErr]
}
// NewJSPedanticError creates a new JSPedanticErrF error: "pedantic mode: {err}"
func NewJSPedanticError(err error, opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
if ae, ok := eopts.err.(*ApiError); ok {
return ae
}
e := ApiErrors[JSPedanticErrF]
args := e.toReplacerArgs([]interface{}{"{err}", err})
return &ApiError{
Code: e.Code,
ErrCode: e.ErrCode,
Description: strings.NewReplacer(args...).Replace(e.Description),
}
}
// NewJSPeerRemapError creates a new JSPeerRemapErr error: "peer remap failed"
func NewJSPeerRemapError(opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
@ -1747,24 +1907,36 @@ func NewJSSourceInvalidStreamNameError(opts ...ErrorOption) *ApiError {
return ApiErrors[JSSourceInvalidStreamName]
}
// NewJSSourceInvalidSubjectFilterError creates a new JSSourceInvalidSubjectFilter error: "source subject filter is invalid"
func NewJSSourceInvalidSubjectFilterError(opts ...ErrorOption) *ApiError {
// NewJSSourceInvalidSubjectFilterError creates a new JSSourceInvalidSubjectFilter error: "source transform source: {err}"
func NewJSSourceInvalidSubjectFilterError(err error, opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
if ae, ok := eopts.err.(*ApiError); ok {
return ae
}
return ApiErrors[JSSourceInvalidSubjectFilter]
e := ApiErrors[JSSourceInvalidSubjectFilter]
args := e.toReplacerArgs([]interface{}{"{err}", err})
return &ApiError{
Code: e.Code,
ErrCode: e.ErrCode,
Description: strings.NewReplacer(args...).Replace(e.Description),
}
}
// NewJSSourceInvalidTransformDestinationError creates a new JSSourceInvalidTransformDestination error: "source transform destination is invalid"
func NewJSSourceInvalidTransformDestinationError(opts ...ErrorOption) *ApiError {
// NewJSSourceInvalidTransformDestinationError creates a new JSSourceInvalidTransformDestination error: "source transform: {err}"
func NewJSSourceInvalidTransformDestinationError(err error, opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
if ae, ok := eopts.err.(*ApiError); ok {
return ae
}
return ApiErrors[JSSourceInvalidTransformDestination]
e := ApiErrors[JSSourceInvalidTransformDestination]
args := e.toReplacerArgs([]interface{}{"{err}", err})
return &ApiError{
Code: e.Code,
ErrCode: e.ErrCode,
Description: strings.NewReplacer(args...).Replace(e.Description),
}
}
// NewJSSourceMaxMessageSizeTooBigError creates a new JSSourceMaxMessageSizeTooBigErr error: "stream source must have max message size >= target"
@ -1855,6 +2027,26 @@ func NewJSStreamDeleteError(err error, opts ...ErrorOption) *ApiError {
}
}
// NewJSStreamDuplicateMessageConflictError creates a new JSStreamDuplicateMessageConflict error: "duplicate message id is in process"
func NewJSStreamDuplicateMessageConflictError(opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
if ae, ok := eopts.err.(*ApiError); ok {
return ae
}
return ApiErrors[JSStreamDuplicateMessageConflict]
}
// NewJSStreamExpectedLastSeqPerSubjectNotReadyError creates a new JSStreamExpectedLastSeqPerSubjectNotReady error: "expected last sequence per subject temporarily unavailable"
func NewJSStreamExpectedLastSeqPerSubjectNotReadyError(opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
if ae, ok := eopts.err.(*ApiError); ok {
return ae
}
return ApiErrors[JSStreamExpectedLastSeqPerSubjectNotReady]
}
// NewJSStreamExternalApiOverlapError creates a new JSStreamExternalApiOverlapErrF error: "stream external api prefix {prefix} must not overlap with {subject}"
func NewJSStreamExternalApiOverlapError(prefix interface{}, subject interface{}, opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
@ -2315,6 +2507,48 @@ func NewJSStreamTemplateNotFoundError(opts ...ErrorOption) *ApiError {
return ApiErrors[JSStreamTemplateNotFoundErr]
}
// NewJSStreamTooManyRequestsError creates a new JSStreamTooManyRequests error: "too many requests"
func NewJSStreamTooManyRequestsError(opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
if ae, ok := eopts.err.(*ApiError); ok {
return ae
}
return ApiErrors[JSStreamTooManyRequests]
}
// NewJSStreamTransformInvalidDestinationError creates a new JSStreamTransformInvalidDestination error: "stream transform: {err}"
func NewJSStreamTransformInvalidDestinationError(err error, opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
if ae, ok := eopts.err.(*ApiError); ok {
return ae
}
e := ApiErrors[JSStreamTransformInvalidDestination]
args := e.toReplacerArgs([]interface{}{"{err}", err})
return &ApiError{
Code: e.Code,
ErrCode: e.ErrCode,
Description: strings.NewReplacer(args...).Replace(e.Description),
}
}
// NewJSStreamTransformInvalidSourceError creates a new JSStreamTransformInvalidSource error: "stream transform source: {err}"
func NewJSStreamTransformInvalidSourceError(err error, opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
if ae, ok := eopts.err.(*ApiError); ok {
return ae
}
e := ApiErrors[JSStreamTransformInvalidSource]
args := e.toReplacerArgs([]interface{}{"{err}", err})
return &ApiError{
Code: e.Code,
ErrCode: e.ErrCode,
Description: strings.NewReplacer(args...).Replace(e.Description),
}
}
// NewJSStreamUpdateError creates a new JSStreamUpdateErrF error: "{err}"
func NewJSStreamUpdateError(err error, opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
@ -2347,6 +2581,16 @@ func NewJSStreamWrongLastMsgIDError(id interface{}, opts ...ErrorOption) *ApiErr
}
}
// NewJSStreamWrongLastSequenceConstantError creates a new JSStreamWrongLastSequenceConstantErr error: "wrong last sequence"
func NewJSStreamWrongLastSequenceConstantError(opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
if ae, ok := eopts.err.(*ApiError); ok {
return ae
}
return ApiErrors[JSStreamWrongLastSequenceConstantErr]
}
// NewJSStreamWrongLastSequenceError creates a new JSStreamWrongLastSequenceErrF error: "wrong last sequence: {seq}"
func NewJSStreamWrongLastSequenceError(seq uint64, opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)

Some files were not shown because too many files have changed in this diff Show More